ELIFE-ASU/INNLab

Do we need to add `.detach()` after `var` in `INN.BatchNorm1d`?

Zhangyanbo opened this issue · 2 comments

In `INN.BatchNorm1d`, the forward function is:

def forward(self, x, log_p=0, log_det_J=0):
    if self.compute_p:
        if not self.training:
            # if in self.eval()
            var = self.running_var  # [dim]
        else:
            # if in training
            # TODO: Do we need to add .detach() after var?
            var = torch.var(x, dim=0, unbiased=False)  # [dim]

        x = super(BatchNorm1d, self).forward(x)

        log_det = -0.5 * torch.log(var + self.eps)
        log_det = torch.sum(log_det, dim=-1)

        return x, log_p, log_det_J + log_det
    else:
        return super(BatchNorm1d, self).forward(x)

Do we need `var` to carry gradient information? The gradient would not be for training `BatchNorm1d` itself, but for training the modules before it. Are there any references on this?
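
For intuition, here is a minimal, self-contained sketch (not INNLab's code) of what `.detach()` changes: with it, the log-det term becomes a constant to autograd and cannot send any gradient to an upstream module. The `linear` layer below is just a stand-in for "modules before it".

import torch
import torch.nn as nn

linear = nn.Linear(3, 3)  # stand-in for an upstream module
eps = 1e-5

# Without .detach(): the log-det term backpropagates into `linear`.
x = linear(torch.randn(5, 3))
var = torch.var(x, dim=0, unbiased=False)
log_det = torch.sum(-0.5 * torch.log(var + eps), dim=-1)
log_det.backward()
print(linear.weight.grad)  # non-zero: upstream weights receive a signal

# With .detach(): the variance is cut out of the graph, so the
# log-det term contributes no gradient to `linear`.
x2 = linear(torch.randn(5, 3))
var_d = torch.var(x2, dim=0, unbiased=False).detach()
log_det_d = torch.sum(-0.5 * torch.log(var_d + eps), dim=-1)
print(log_det_d.requires_grad)  # False; backward() through it would fail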

Compare with `nn.BatchNorm1d`:

import torch
import torch.nn as nn

x = torch.randn((5, 3))
bn = nn.BatchNorm1d(3, affine=False)

bn(x)

The output is:

tensor([[-1.6941,  0.2933, -0.2451],
        [-0.1313, -0.2711,  1.4740],
        [ 0.2754, -0.2282,  0.4445],
        [ 0.1287, -1.4409, -0.0721],
        [ 1.4213,  1.6469, -1.6014]])

So, if we do not enable affine in bn, BatchNorm itself has no learnable parameters and needs no gradient of its own.
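
A quick check confirms this: with `affine=False` the layer has no learnable weight or bias, so the only question left is the gradient flowing to modules upstream of it.

import torch.nn as nn

bn = nn.BatchNorm1d(3, affine=False)
print(list(bn.parameters()))  # [] -- no learnable weight/bias when affine=False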

Experiments show that if we add .detach(), the training loss does not decrease, while without .detach() it works. So, in the latest version, I added a parameter `requires_grad: bool` to `INN.BatchNorm1d`.
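
For reference, here is one way such a flag could be wired into the forward pass. This is only a sketch of the idea, not necessarily the implementation that landed in INNLab; the constructor signature and the `compute_p` attribute are assumptions carried over from the snippet above.

import torch
import torch.nn as nn

class BatchNorm1d(nn.BatchNorm1d):
    # Sketch only: `requires_grad` controls whether the batch variance
    # stays in the autograd graph for the log-det term.
    def __init__(self, dim, requires_grad=True, compute_p=True, **kwargs):
        super().__init__(dim, **kwargs)
        self.compute_p = compute_p
        self.requires_grad = requires_grad

    def forward(self, x, log_p=0, log_det_J=0):
        if self.compute_p:
            if not self.training:
                var = self.running_var  # [dim]; buffers carry no gradient anyway
            else:
                var = torch.var(x, dim=0, unbiased=False)  # [dim]
                if not self.requires_grad:
                    # Cut the graph: the log-det term then trains
                    # nothing upstream of this layer.
                    var = var.detach()

            x = super(BatchNorm1d, self).forward(x)

            log_det = -0.5 * torch.log(var + self.eps)
            log_det = torch.sum(log_det, dim=-1)

            return x, log_p, log_det_J + log_det
        else:
            return super(BatchNorm1d, self).forward(x)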