xrsrke/pipegoose

Gradient Checkpointing

xrsrke opened this issue · 1 comment

xrsrke commented
  • Selectively recompute the forward pass of some operations during the backward pass, instead of storing their activations, to save memory.
  • Replace transformers' gradient checkpointing with pipegoose's gradient checkpointing.

APIs

from pipegoose.utils.checkpointing import Checkpointing

mlp = model.transformer.blocks[0].mlp
mlp = Checkpointing(mlp, parallel_context)

outputs = mlp(inputs)
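
A minimal sketch of how such a Checkpointing wrapper could work internally, assuming it delegates to torch.utils.checkpoint; the internals below (and the unused parallel_context handling) are assumptions for illustration, not pipegoose's confirmed implementation:

from torch import nn
from torch.utils.checkpoint import checkpoint

class Checkpointing(nn.Module):
    # Hypothetical sketch; the real pipegoose implementation may differ.
    def __init__(self, module: nn.Module, parallel_context=None):
        super().__init__()
        self.module = module
        # parallel_context is kept for parity with the proposed API; unused in this sketch.
        self.parallel_context = parallel_context

    def forward(self, *args):
        # Do not store self.module's intermediate activations during the forward
        # pass; recompute them during the backward pass to save memory.
        return checkpoint(self.module, *args)

With a wrapper along these lines, the usage shown in the API example above stays unchanged.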

Reading

Etelis commented

I will do it.
!assign