BN running mean/var with torch.utils.checkpoint.checkpoint
ljn114514 opened this issue · 2 comments
ljn114514 commented
How do you deal with the BN running mean/variance? With checkpointing, the BatchNorm forward is executed twice (once during the forward pass and once during recomputation in the backward pass), so the running mean/var is updated twice.
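For context, a minimal sketch of the behavior being asked about, assuming a recent PyTorch where `torch.utils.checkpoint.checkpoint` accepts `use_reentrant`. With reentrant checkpointing the first forward runs under `no_grad` (running stats still update, since that depends on `module.training`, not grad mode), and the backward-time recomputation runs with grad enabled and updates them again, so `num_batches_tracked` advances twice per step:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)

# BatchNorm layer run through reentrant checkpointing
bn = nn.BatchNorm1d(3)
x = torch.randn(8, 3, requires_grad=True)
out = checkpoint(bn, x, use_reentrant=True)
out.sum().backward()  # backward triggers recomputation -> second stats update

# Identical layer run without checkpointing, for comparison
bn_plain = nn.BatchNorm1d(3)
x2 = x.detach().clone().requires_grad_()
bn_plain(x2).sum().backward()

print(bn.num_batches_tracked.item())        # counted twice per training step
print(bn_plain.num_batches_tracked.item())  # counted once per training step
```

The counter discrepancy also means the running mean/var themselves drift toward the batch statistics at twice the intended rate.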
gpleiss commented
This is a good point. Ideally, PyTorch's batch norm layers should be smart enough to update the running mean/var appropriately with the checkpointing operation.
If that is not the case, please raise an issue with PyTorch, since checkpointing and the batch norm layers are part of their library, not this one.
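One possible workaround, sketched below under a strong assumption: the `CheckpointFriendlyBN` class is a hypothetical helper (not part of PyTorch or this repo) and is only valid when the module is always invoked through reentrant checkpointing, where the first pass runs under `no_grad` and the recomputation pass runs with grad enabled. It uses `torch.is_grad_enabled()` to detect the recomputation pass and skips the running-stat update there by temporarily setting `momentum` to zero and restoring `num_batches_tracked`:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointFriendlyBN(nn.BatchNorm1d):
    """Hypothetical sketch: skip the running-stat update during the
    grad-enabled recomputation pass of reentrant checkpointing.
    Assumes this module is ALWAYS called through checkpoint()."""

    def forward(self, x):
        if self.training and torch.is_grad_enabled():
            # Recomputation pass: normalize with batch stats as usual,
            # but leave running_mean/running_var/num_batches_tracked intact.
            momentum = self.momentum
            counter = self.num_batches_tracked.clone()
            self.momentum = 0.0  # running = (1 - 0) * running + 0 * batch
            try:
                out = super().forward(x)
            finally:
                self.momentum = momentum
                self.num_batches_tracked.copy_(counter)
            return out
        # First (no_grad) pass: ordinary BN forward, stats update once.
        return super().forward(x)

torch.manual_seed(0)
bn = CheckpointFriendlyBN(3)
x = torch.randn(8, 3, requires_grad=True)
checkpoint(bn, x, use_reentrant=True).sum().backward()
print(bn.num_batches_tracked.item())  # stats updated once, not twice
```

Note this only patches the symptom; a plain (non-checkpointed) forward through this module would still update stats once per call, so mixing checkpointed and non-checkpointed use of the same instance would reintroduce the mismatch.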
ljn114514 commented
OK, thanks for your reply.