gpleiss/efficient_densenet_pytorch

The BN running mean&var with torch.utils.checkpoint.checkpoint

ljn114514 opened this issue · 2 comments

How do you deal with the BN running mean/variance? With checkpointing, each BatchNorm layer is computed twice (once during the forward pass and once during recomputation in the backward pass), so the running mean and variance would be updated twice.
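A minimal sketch of how this could be checked, assuming a recent PyTorch version where `torch.utils.checkpoint.checkpoint` accepts the `use_reentrant` argument. The helper name `count_bn_updates` is made up for illustration. It uses the layer's `num_batches_tracked` buffer, which is incremented on every training-mode forward call, as a counter of running-statistics updates.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


def count_bn_updates(use_checkpoint: bool) -> int:
    """Run one forward/backward step and return how many times the
    BatchNorm layer updated its running statistics.

    Hypothetical helper written for this illustration only.
    """
    torch.manual_seed(0)
    bn = nn.BatchNorm1d(4)  # modules start in training mode
    x = torch.randn(8, 4, requires_grad=True)

    if use_checkpoint:
        # The forward pass of `bn` is re-run during backward.
        out = checkpoint(bn, x, use_reentrant=False)
    else:
        out = bn(x)

    out.sum().backward()
    # Incremented once per training-mode forward call of the layer.
    return int(bn.num_batches_tracked.item())


plain_updates = count_bn_updates(use_checkpoint=False)
checkpointed_updates = count_bn_updates(use_checkpoint=True)
print(plain_updates, checkpointed_updates)
```

Under these assumptions the plain pass should report one update and the checkpointed pass two, which matches the concern raised above.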

This is a good point. Ideally, PyTorch's batch norm layers should be smart enough to update the running mean/var appropriately with the checkpointing operation.

If this is not the case, then you should raise an issue with PyTorch, since the checkpointing/batch norm layers are part of their library, not this library.
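To make the effect concrete, here is a small arithmetic sketch of the running-statistics update `running = (1 - momentum) * running + momentum * batch_stat` applied twice with the same batch statistic. The compensation shown, using `m' = 1 - sqrt(1 - m)` so that two updates match one update with `m`, is only one possible workaround and an assumption of this example; it is not something this thread or PyTorch prescribes. The function names are hypothetical.

```python
import math


def ema_update(running: float, batch_stat: float, momentum: float) -> float:
    """One BatchNorm-style running-statistic update."""
    return (1.0 - momentum) * running + momentum * batch_stat


def compensated_momentum(momentum: float) -> float:
    """Momentum m' such that two updates with m' equal one update with m.

    Follows from (1 - m')**2 == 1 - m. Hypothetical helper.
    """
    return 1.0 - math.sqrt(1.0 - momentum)


running, batch_stat, m = 0.0, 1.0, 0.1

# Intended behaviour: a single update per training step.
once = ema_update(running, batch_stat, m)

# Checkpointed behaviour: the same batch statistic is applied twice.
twice = ema_update(ema_update(running, batch_stat, m), batch_stat, m)

# Two updates with the compensated momentum recover the single update.
m_fix = compensated_momentum(m)
twice_fixed = ema_update(ema_update(running, batch_stat, m_fix), batch_stat, m_fix)

print(once, twice, twice_fixed)
```

A double update with momentum `m` behaves like a single update with momentum `2m - m**2`. Note that this compensation only covers the exponential average: the `num_batches_tracked` counter would still advance twice per step, which matters when `momentum=None` (cumulative averaging) is used.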

Ok, thanks for your reply