Gradient Checkpointing
xrsrke opened this issue · 1 comment
xrsrke commented
- Selectively recompute the forward pass of some operations during the backward pass to save memory (a minimal sketch of the idea follows below).
- Replace `transformers`' gradient checkpointing with pipegoose's gradient checkpointing.
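For reference, this is the basic trade-off with stock PyTorch activation checkpointing: the wrapped block's intermediate activations are not stored during the forward pass and are recomputed when backward needs them. The module and shapes below are illustrative only.

```python
import torch
from torch.utils.checkpoint import checkpoint

# An example block whose intermediate activations we don't want to keep in memory.
layer = torch.nn.Sequential(
    torch.nn.Linear(1024, 4096),
    torch.nn.GELU(),
    torch.nn.Linear(4096, 1024),
)

x = torch.randn(8, 1024, requires_grad=True)
out = checkpoint(layer, x, use_reentrant=False)  # forward runs without saving intermediates
out.sum().backward()  # the block's forward is recomputed here, then gradients flow as usual
```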
APIs
```python
from pipegoose.utils.checkpointing import Checkpointing

mlp = model.transformer.blocks[0].mlp
mlp = Checkpointing(mlp, parallel_context)
outputs = mlp(inputs)
```
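A minimal sketch of what such a wrapper could look like, assuming it simply defers to `torch.utils.checkpoint` under the hood (this is not the actual pipegoose implementation; `parallel_context` handling is left as a placeholder):

```python
import torch
from torch.utils.checkpoint import checkpoint

class Checkpointing(torch.nn.Module):
    """Hypothetical wrapper: recompute the wrapped module's forward pass in backward."""

    def __init__(self, module: torch.nn.Module, parallel_context=None):
        super().__init__()
        self.module = module
        # Kept for parity with the proposed API; a real implementation could use it
        # to coordinate RNG state or activation partitioning across ranks.
        self.parallel_context = parallel_context

    def forward(self, *args, **kwargs):
        # Don't store intermediate activations; recompute them during backward.
        return checkpoint(self.module, *args, use_reentrant=False, **kwargs)
```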
Reading
- https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html
- Reducing Activation Recomputation in Large Transformer Models [[link]](https://arxiv.org/abs/2205.05198)
Etelis commented
I will do it.
!assign