MCG-NJU/SparseBEV

Why is there a re-implementation of gradient checkpointing?

zen-d opened this issue · 3 comments

Hi @afterthat97, thanks for your awesome work. I noticed there is a checkpoint.py file that seems to re-implement torch.utils.checkpoint. What is the motivation for that?

Furthermore, could you please tell me which PyTorch version your implementation is based on, so that I can run a quick diff? (Since PyTorch is developing rapidly, the implementation may vary from version to version.) Thanks.

This implementation is exactly the same as the official code in PyTorch 2.0, so if you are using PyTorch 2.0 you can use torch.utils.checkpoint directly.

However, for PyTorch 1.10 users, the official implementation does not support passing use_reentrant=False, so you may need this copy.
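
To make the version handling concrete, here is a minimal sketch (not taken from the repo) of how one might pick between the built-in checkpoint and the copied checkpoint.py based on the installed PyTorch version; the import path `models.checkpoint` is purely an assumption for illustration.

```python
import torch

# Minimal sketch: use the built-in implementation on PyTorch >= 2.0 (per the reply above),
# otherwise fall back to the copied checkpoint.py.
if int(torch.__version__.split(".")[0]) >= 2:
    from torch.utils.checkpoint import checkpoint
else:
    from models.checkpoint import checkpoint  # hypothetical location of the copied file


def run_block(block, x):
    # Recompute `block`'s activations during the backward pass instead of storing them,
    # using the non-reentrant variant so it works together with DDP.
    return checkpoint(block, x, use_reentrant=False)
```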

Hi, thanks for your prompt reply!

Let me try to figure out the overall logic:

  1. We want to use gradient checkpointing to save memory, but the model may have some "unused parameters", so we have to set find_unused_parameters=True in DDP.
  2. The combination in 1 (find_unused_parameters=True together with checkpointing) is not allowed with the default reentrant checkpoint under PyTorch DDP, so we have to pass use_reentrant=False to torch.utils.checkpoint.checkpoint as a fix (see the sketch below this list).
  3. In case the user's PyTorch version is too old to support passing that argument as in 2, a customized checkpoint interface is necessary. That is how checkpoint.py comes in.
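
For illustration only, the sketch below (not from the repo) shows the combination described in the list: a module with a sometimes-unused head trained under DDP with find_unused_parameters=True, while its main branch is checkpointed with use_reentrant=False. The Block module and its names are made up for this example.

```python
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.checkpoint import checkpoint  # or the copied checkpoint.py on older torch

class Block(nn.Module):
    def __init__(self, dim=256):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))
        # This head is only used on some iterations, so its parameters may
        # receive no gradient -- the reason find_unused_parameters=True is needed.
        self.optional_head = nn.Linear(dim, dim)

    def forward(self, x, use_optional=False):
        # Non-reentrant checkpointing: activations of `self.mlp` are recomputed
        # in the backward pass instead of being stored.
        x = checkpoint(self.mlp, x, use_reentrant=False)
        return self.optional_head(x) if use_optional else x

# Inside the usual distributed setup (init_process_group, device placement, ...):
# model = DDP(Block().cuda(), device_ids=[local_rank], find_unused_parameters=True)
```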

Please correct me if there is any mistake in the above description.

Yes, you are correct.

Our checkpoint.py is directly copied from https://pytorch.org/docs/2.0/_modules/torch/utils/checkpoint.html#checkpoint with no modifications.