ptrblck/pytorch_misc

how to realize a batchnorm6d layer

hnczwj2008 opened this issue · 2 comments

Hi ptrblck,
could you tell me how to implement a batchnorm layer for an arbitrary number of dimensions, such as BatchNorm4d, BatchNorm6d, …?
Thank you!

The parameters (weight and bias) as well as the buffers (running_mean and running_var) keep their shape [num_features], i.e. one value per input channel (num_features in the code).
Applying these parameters and buffers requires more unsqueezing (indexing with None) so that they broadcast against the higher-dimensional input.
E.g., the normalization for a 4D input (N, C, H, W), as in nn.BatchNorm2d,

input = (input - mean[None, :, None, None]) / (torch.sqrt(var[None, :, None, None] + self.eps))

would have to be changed to

input = (input - mean[None, :, None, None, None, None]) / (torch.sqrt(var[None, :, None, None, None, None] + self.eps))

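for the two additional dimensions of a 6D input (N, C, D1, D2, D3, D4). In general, you need one None for the batch dimension plus one for every dimension after the channel dimension, so an 8D input (BatchNorm6d) needs seven None entries.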
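Rather than hard-coding one None per dimension, the broadcast shape can be derived from input.dim(). Below is a minimal sketch; normalize_nd is a hypothetical helper (not a PyTorch API) that reshapes the [C]-shaped statistics to (1, C, 1, ..., 1) with Tensor.view, so the same code works for any number of trailing dimensions.

```python
import torch


def normalize_nd(input, mean, var, eps=1e-5):
    """Normalize an (N, C, *) tensor with per-channel statistics of shape [C].

    Hypothetical helper for illustration; works for any input.dim() >= 2.
    """
    # (1, C, 1, ..., 1): singleton dims for the batch dim and every trailing dim
    shape = (1, -1) + (1,) * (input.dim() - 2)
    return (input - mean.view(shape)) / torch.sqrt(var.view(shape) + eps)


torch.manual_seed(0)
# A 6D input (N, C, D1, D2, D3, D4) with 3 channels
x = torch.randn(2, 3, 2, 2, 2, 2)
dims = [0] + list(range(2, x.dim()))  # every dim except the channel dim
mean = x.mean(dims)
var = x.var(dims, unbiased=False)
y = normalize_nd(x, mean, var)

# Same computation written with explicit None indexing
y_ref = (x - mean[None, :, None, None, None, None]) / torch.sqrt(
    var[None, :, None, None, None, None] + 1e-5
)
```

Using view with a computed shape keeps the broadcasting logic in one place, so switching from a 6D to an 8D input needs no code change.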

Also, since you cannot cleanly derive from nn.BatchNorm2d (it checks for 4D inputs), you would have to derive from its base class instead, e.g. nn.modules.batchnorm._BatchNorm.
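When deriving from _BatchNorm, it may be enough to override only _check_input_dim and keep the inherited forward, which updates the running statistics and calls F.batch_norm. This is a minimal sketch under the assumption that F.batch_norm accepts inputs of shape (N, C, *) with any number of trailing dimensions, which appears to hold in recent PyTorch releases; verify it on your version, since _BatchNorm is a private class whose internals may change.

```python
import torch
import torch.nn as nn


class BatchNorm6d(nn.modules.batchnorm._BatchNorm):
    """BatchNorm over an 8D input (N, C, D1, ..., D6).

    Only the dimension check is overridden; the inherited forward handles the
    running statistics and (assumed here) F.batch_norm handles the 8D input.
    """

    def _check_input_dim(self, input):
        if input.dim() != 8:
            raise ValueError(f"expected 8D input (got {input.dim()}D input)")


torch.manual_seed(0)
bn = BatchNorm6d(3)
x = torch.randn(2, 3, 2, 2, 2, 2, 2, 2)
y = bn(x)  # training mode: normalizes with the batch statistics

# Manual reference: reduce over every dim except the channel dim
dims = [0, 2, 3, 4, 5, 6, 7]
mean = x.mean(dims)
var = x.var(dims, unbiased=False)
shape = (1, -1, 1, 1, 1, 1, 1, 1)
# weight is initialized to ones and bias to zeros, so no affine step is needed
y_ref = (x - mean.view(shape)) / torch.sqrt(var.view(shape) + bn.eps)
```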

Hi ptrblck,
Here is my BatchNorm6d code based on your source; it works. Thank you for your help.

import torch
import torch.nn as nn


class MyBatchNorm_6d(nn.modules.batchnorm._BatchNorm):
    def __init__(self, num_features, eps=1e-5, momentum=0.1,
                 affine=True, track_running_stats=True):
        super(MyBatchNorm_6d, self).__init__(num_features, eps, momentum, affine, track_running_stats)

    def _check_input_dim(self, input):
        # expects (N, C, D1, D2, D3, D4, D5, D6)
        if input.dim() != 8:
            raise ValueError('expected 8D input (got {}D input)'
                             .format(input.dim()))

    def forward(self, input):
        self._check_input_dim(input)

        exponential_average_factor = 0.0

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        # reduce over every dim except the channel dim (dim 1)
        reduce_dims = [0, 2, 3, 4, 5, 6, 7]
        # (1, C, 1, 1, 1, 1, 1, 1), so [C]-shaped tensors broadcast against the input
        shape = (1, -1) + (1,) * (input.dim() - 2)

        # use batch statistics in training mode, or whenever no running stats are tracked
        # (running_mean and running_var are None if track_running_stats=False)
        if self.training or not self.track_running_stats:
            mean = input.mean(reduce_dims)
            # use biased var for normalization
            var = input.var(reduce_dims, unbiased=False)
            if self.training and self.track_running_stats:
                n = input.numel() / input.size(1)
                with torch.no_grad():
                    # update the buffers in place; running_var uses the unbiased var
                    self.running_mean.mul_(1 - exponential_average_factor).add_(
                        exponential_average_factor * mean)
                    self.running_var.mul_(1 - exponential_average_factor).add_(
                        exponential_average_factor * var * n / (n - 1))
        else:
            mean = self.running_mean
            var = self.running_var

        input = (input - mean.view(shape)) / torch.sqrt(var.view(shape) + self.eps)
        if self.affine:
            input = input * self.weight.view(shape) + self.bias.view(shape)

        return input
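A different technique from the one used in this thread: since the statistics are reduced over every dimension except the channel dimension, the trailing dimensions can be flattened into one and the existing nn.BatchNorm1d (which accepts (N, C, L) inputs) reused. A minimal sketch; BatchNormNd is a hypothetical name, not a PyTorch class.

```python
import torch
import torch.nn as nn


class BatchNormNd(nn.Module):
    """Hypothetical wrapper: BatchNorm over any (N, C, *) input.

    Flattens all trailing dims into one, applies nn.BatchNorm1d on the
    resulting (N, C, L) tensor, and restores the original shape.
    """

    def __init__(self, num_features, **kwargs):
        super().__init__()
        # kwargs (eps, momentum, affine, track_running_stats) go to BatchNorm1d
        self.bn = nn.BatchNorm1d(num_features, **kwargs)

    def forward(self, input):
        n, c = input.shape[:2]
        out = self.bn(input.reshape(n, c, -1))
        return out.reshape(input.shape)


torch.manual_seed(0)
bn = BatchNormNd(3)
x = torch.randn(2, 3, 2, 2, 2, 2, 2, 2)  # 8D input, as for a BatchNorm6d
y = bn(x)

# Manual reference with batch statistics (training mode)
dims = [0] + list(range(2, x.dim()))
mean = x.mean(dims, keepdim=True)
var = x.var(dims, unbiased=False, keepdim=True)
y_ref = (x - mean) / torch.sqrt(var + 1e-5)
```

This avoids re-implementing the running-statistics update, at the cost of a reshape, which copies the data if the input is not contiguous.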