how to realize a batchnorm6d layer
hnczwj2008 opened this issue · 2 comments
hnczwj2008 commented
Hi ptrblck,
could you tell me how to realize an arbitrary-dimension batchnorm layer, such as batchnorm4d, batchnorm6d, …?
thank you!
ptrblck commented
The parameters (`weight` and `bias`) as well as the buffers (`running_mean` and `running_var`) would keep their shape as the number of input channels (defined by `num_features` in the code).
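For illustration, the built-in 2d layer already stores its parameters and buffers this way (shapes shown for `num_features=16`):

```python
import torch.nn as nn

bn = nn.BatchNorm2d(num_features=16)
print(bn.weight.shape)        # torch.Size([16])
print(bn.bias.shape)          # torch.Size([16])
print(bn.running_mean.shape)  # torch.Size([16])
print(bn.running_var.shape)   # torch.Size([16])
```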
Applying these parameters and buffers would need more unsqueezing to match the new number of dimensions.
E.g.

```python
input = (input - mean[None, :, None, None]) / (torch.sqrt(var[None, :, None, None] + self.eps))
```

would have to be changed to

```python
input = (input - mean[None, :, None, None, None, None]) / (torch.sqrt(var[None, :, None, None, None, None] + self.eps))
```

for the additional dimensions.
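That example goes from a 4D input (batchnorm2d) to a 6D one (batchnorm4d); the 8D input of a batchnorm6d would need two more `None` indices. Instead of hard-coding them, the broadcast shape can also be built from the input rank. A small sketch using a hypothetical `normalize` helper, not part of the original answer:

```python
import torch

def normalize(input, mean, var, eps=1e-5):
    # Build the per-channel broadcast shape [1, C, 1, ..., 1] from the
    # input rank, so the same code works for (N, C, *) inputs of any rank.
    shape = [1, -1] + [1] * (input.dim() - 2)
    return (input - mean.view(shape)) / torch.sqrt(var.view(shape) + eps)
```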
Also, since you cannot simply derive from `nn.BatchNorm2d`, you would have to derive from e.g. `nn.modules.batchnorm._BatchNorm`.
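Since `_BatchNorm.forward` delegates to `F.batch_norm`, which already normalizes over every dimension except the channel dimension, a subclass that only overrides the input-rank check may be all that is needed. A minimal sketch, assuming a reasonably recent PyTorch:

```python
import torch.nn as nn

class BatchNorm6d(nn.modules.batchnorm._BatchNorm):
    # _BatchNorm.forward calls F.batch_norm, which accepts any (N, C, *)
    # input, so only the dimension check has to be specialized.
    def _check_input_dim(self, input):
        if input.dim() != 8:
            raise ValueError('expected 8D input (got {}D input)'
                             .format(input.dim()))
```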
hnczwj2008 commented
Hi ptrblck,
Here is my BatchNorm6d code based on your source, and it works. Thank you for your help.
```python
import torch
import torch.nn as nn


class MyBatchNorm_6d(nn.modules.batchnorm._BatchNorm):
    def __init__(self, num_features, eps=1e-5, momentum=0.1,
                 affine=True, track_running_stats=True):
        super(MyBatchNorm_6d, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)

    def _check_input_dim(self, input):
        # a 6d batchnorm expects (N, C, d1, ..., d6), i.e. an 8D tensor
        if input.dim() != 8:
            raise ValueError('expected 8D input (got {}D input)'
                             .format(input.dim()))

    def forward(self, input):
        self._check_input_dim(input)

        exponential_average_factor = 0.0
        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        # calculate running estimates
        if self.training:
            # reduce over the batch and all six spatial dimensions
            mean = input.mean([0, 2, 3, 4, 5, 6, 7])
            # use biased var in train
            var = input.var([0, 2, 3, 4, 5, 6, 7], unbiased=False)
            n = input.numel() / input.size(1)
            with torch.no_grad():
                self.running_mean = exponential_average_factor * mean \
                    + (1 - exponential_average_factor) * self.running_mean
                # update running_var with unbiased var
                self.running_var = exponential_average_factor * var * n / (n - 1) \
                    + (1 - exponential_average_factor) * self.running_var
        else:
            mean = self.running_mean
            var = self.running_var

        # broadcast the per-channel statistics over all eight dimensions
        input = (input - mean[None, :, None, None, None, None, None, None]) \
            / torch.sqrt(var[None, :, None, None, None, None, None, None] + self.eps)
        if self.affine:
            input = input * self.weight[None, :, None, None, None, None, None, None] \
                + self.bias[None, :, None, None, None, None, None, None]

        return input
```
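A quick smoke test of the class above, with hypothetical shapes just to illustrate usage:

```python
torch.manual_seed(0)
bn = MyBatchNorm_6d(num_features=3)
x = torch.randn(2, 3, 4, 4, 2, 2, 2, 2)  # 8D input: (N=2, C=3, six spatial dims)
out = bn(x)
print(out.shape)                        # torch.Size([2, 3, 4, 4, 2, 2, 2, 2])
print(out.mean([0, 2, 3, 4, 5, 6, 7]))  # per-channel mean is ~0 in training mode
```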