ActNorm implementation missing division by `std` on the shift parameter
hemildesai commented
Hi,
Thanks for making the video lectures and homework public. I'm really enjoying the course so far. I was going through homework 2 and wanted to compare my work with the solutions. In the hw2 solution, I found the following implementation of ActNorm:
```python
class ActNorm(nn.Module):
    def __init__(self, n_channels):
        super(ActNorm, self).__init__()
        self.log_scale = nn.Parameter(torch.zeros(1, n_channels, 1, 1), requires_grad=True)
        self.shift = nn.Parameter(torch.zeros(1, n_channels, 1, 1), requires_grad=True)
        self.n_channels = n_channels
        self.initialized = False

    def forward(self, x, reverse=False):
        if reverse:
            return (x - self.shift) * torch.exp(-self.log_scale), self.log_scale
        else:
            if not self.initialized:
                self.shift.data = -torch.mean(x, dim=[0, 2, 3], keepdim=True)
                self.log_scale.data = -torch.log(
                    torch.std(x.permute(1, 0, 2, 3).reshape(self.n_channels, -1), dim=1)
                    .reshape(1, self.n_channels, 1, 1))
                self.initialized = True
            result = x * torch.exp(self.log_scale) + self.shift
            return x * torch.exp(self.log_scale) + self.shift, self.log_scale
```
I think the `shift` needs to be divided by the standard deviation (equivalently, multiplied by `exp(log_scale)`) for the activations to actually be normalized, i.e.:

```python
self.shift.data = -(torch.mean(x, dim=[0, 2, 3], keepdim=True) * torch.exp(self.log_scale))
```
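To spell out my reasoning: after initialization the forward pass computes `y = x * exp(log_scale) + shift`. With `exp(log_scale) = 1/std` and `shift = -mean`, that gives `y = x/std - mean`, whose per-channel mean is `mean/std - mean`, which is non-zero in general. For the output to be `(x - mean)/std` with zero mean and unit variance, the shift should be `-mean/std = -mean * exp(log_scale)`. A minimal sketch of the initialization I have in mind (note that `log_scale` has to be set before `shift`, otherwise `exp(log_scale)` is still 1 at that point):

```python
if not self.initialized:
    # set log_scale first, so exp(log_scale) below already equals 1 / std
    self.log_scale.data = -torch.log(
        torch.std(x.permute(1, 0, 2, 3).reshape(self.n_channels, -1), dim=1)
        .reshape(1, self.n_channels, 1, 1))
    # shift by -mean / std so the initial output is (x - mean) / std
    self.shift.data = -(torch.mean(x, dim=[0, 2, 3], keepdim=True)
                        * torch.exp(self.log_scale))
    self.initialized = True
```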
Let me know if I'm missing something.
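For reference, a quick way to sanity-check this would be something like the following (shapes and values are arbitrary, just to look at the per-channel statistics of the output on the initializing batch):

```python
import torch
import torch.nn as nn

layer = ActNorm(n_channels=3)
x = torch.randn(16, 3, 8, 8) * 4.0 + 2.0  # batch with non-trivial mean/std
y, _ = layer(x)
print(y.mean(dim=[0, 2, 3]))  # ~0 per channel with the proposed fix
print(y.std(dim=[0, 2, 3]))   # ~1 per channel either way
```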