Do we need to add `.detach()` after `var` in `INN.BatchNorm1d`?
Zhangyanbo opened this issue · 2 comments
Zhangyanbo commented
In `INN.BatchNorm1d`, the forward function is:
def forward(self, x, log_p=0, log_det_J=0):
    if self.compute_p:
        if not self.training:
            # if in self.eval()
            var = self.running_var  # [dim]
        else:
            # if in training
            # TODO: Do we need to add .detach() after var?
            var = torch.var(x, dim=0, unbiased=False)  # [dim]
        x = super(BatchNorm1d, self).forward(x)
        log_det = -0.5 * torch.log(var + self.eps)
        log_det = torch.sum(log_det, dim=-1)
        return x, log_p, log_det_J + log_det
    else:
        return super(BatchNorm1d, self).forward(x)
Do we need `var` to carry gradient information? It seems we are not training `BatchNorm1d` itself, but rather the modules before it. Are there any references on this?
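To see what `.detach()` would change, here is a minimal sketch of my own (not from the repository; the layer sizes and the `1e-5` epsilon are arbitrary): since `torch.var(x, dim=0)` depends on `x`, the log-determinant term backpropagates into whatever module produced `x`, and detaching `var` removes exactly that path.

```python
import torch

# Minimal sketch (not from the repo): does the log-det term send gradients
# back to an upstream module when var is, or is not, detached?
torch.manual_seed(0)

for detach in (False, True):
    layer = torch.nn.Linear(3, 3)                  # upstream module being trained
    x = layer(torch.randn(5, 3))
    var = torch.var(x, dim=0, unbiased=False)      # batch statistics, as above
    if detach:
        var = var.detach()                         # cut the gradient path through var
    log_det = torch.sum(-0.5 * torch.log(var + 1e-5), dim=-1)
    if log_det.requires_grad:
        log_det.backward()
        print(f"detach={detach}: upstream grad norm {layer.weight.grad.norm().item():.4f}")
    else:
        print(f"detach={detach}: log_det carries no gradient to upstream modules")
```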
Zhangyanbo commented
Compare to `nn.BatchNorm1d`:
x = torch.randn((5, 3))
bn = nn.BatchNorm1d(3, affine=False)
bn(x)
The output is:
tensor([[-1.6941, 0.2933, -0.2451],
[-0.1313, -0.2711, 1.4740],
[ 0.2754, -0.2282, 0.4445],
[ 0.1287, -1.4409, -0.0721],
[ 1.4213, 1.6469, -1.6014]])
So, if we do not require `affine` in `bn`, we don't need gradients for `BatchNorm` itself.
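A quick way to check this (my own sketch, not part of the issue): with `affine=False` the layer exposes no learnable parameters, yet gradients still flow through the normalization to whatever produced the input.

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(3, affine=False)
print(list(bn.parameters()))            # [] -> nothing inside BatchNorm is trained

x = torch.randn(5, 3, requires_grad=True)
bn(x).sum().backward()
print(x.grad is not None)               # True -> modules before bn still get gradients
```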
Zhangyanbo commented
Experiments show that if we add `.detach()`, the training loss will not decrease, while if we do not add `.detach()`, training works. So, in the latest version, I added a parameter `requires_grad: bool` to `INN.BatchNorm1d`.
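For reference, one possible shape of that change; this is a sketch under my own naming (the `_var_requires_grad` attribute and the `compute_p` constructor argument are assumptions), not the exact code in the repository. The flag only decides whether `var` is detached before the log-determinant is computed.

```python
import torch
import torch.nn as nn

class BatchNorm1d(nn.BatchNorm1d):
    """Sketch of an invertible BatchNorm1d with a requires_grad switch.

    Illustration only; the actual INN.BatchNorm1d may differ in details.
    """

    def __init__(self, dim, requires_grad: bool = True, compute_p: bool = True, **kwargs):
        super().__init__(dim, **kwargs)
        self._var_requires_grad = requires_grad   # hypothetical storage name
        self.compute_p = compute_p

    def forward(self, x, log_p=0, log_det_J=0):
        if not self.compute_p:
            return super().forward(x)

        if self.training:
            var = torch.var(x, dim=0, unbiased=False)   # batch statistics
            if not self._var_requires_grad:
                var = var.detach()                      # block gradients through log-det
        else:
            var = self.running_var                      # running statistics in eval()

        x = super().forward(x)
        log_det = torch.sum(-0.5 * torch.log(var + self.eps), dim=-1)
        return x, log_p, log_det_J + log_det
```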