
Gradient checkpointing

Gradient checkpointing for graph-mode execution in TensorFlow 2.

This is a standalone version extracted from the original implementation in tf-slim.

If using eager execution, use tf.recompute_grad instead.
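As a minimal sketch of the eager-mode alternative, the built-in tf.recompute_grad decorator can wrap any function whose intermediate activations you want recomputed rather than stored (the variable, shapes, and function name below are illustrative, not part of this repository):

```python
import tensorflow as tf

w = tf.Variable(tf.random.normal([4, 4]))

@tf.recompute_grad
def block(x):
    # Intermediate activations here are not kept in memory;
    # they are recomputed during the backward pass.
    return tf.nn.relu(tf.matmul(x, w))

x = tf.random.normal([2, 4])
with tf.GradientTape() as tape:
    y = tf.reduce_sum(block(x))
grad = tape.gradient(y, w)  # gradients flow through the recomputed ops
```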

For more information on recomputing gradients between graph nodes during backpropagation, see the original gradient checkpointing repository.


Tested with tf-nightly==2.2.0.dev20200303 in graph mode on TPU.

Example usage for a model built with a Keras layer call method:

def call(self, x, past):
    @gradient_checkpointing.recompute_grad
    def inner(x):
        # ops whose intermediate activations should be
        # recomputed during the backward pass go here
        y = ...
        return y
    return inner(x)
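A fuller sketch of this pattern as a self-contained Keras layer. Since this repository's module is not importable here, the built-in tf.recompute_grad is used as a stand-in for gradient_checkpointing.recompute_grad (the decorator interface is the same); the layer name and sizes are illustrative:

```python
import tensorflow as tf

class CheckpointedBlock(tf.keras.layers.Layer):
    """Illustrative layer whose inner ops are recomputed during backprop."""

    def __init__(self, units):
        super().__init__()
        self.dense = tf.keras.layers.Dense(units, activation="relu")

    def call(self, x):
        # Stand-in for gradient_checkpointing.recompute_grad:
        # activations of the wrapped ops are recomputed, not stored.
        @tf.recompute_grad
        def inner(x):
            return self.dense(x)

        return inner(x)

layer = CheckpointedBlock(8)
y = layer(tf.ones([2, 4]))  # forward pass behaves as usual
```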

Note: Gradient checkpointing trades compute for memory and can significantly slow down training, since activations are recomputed during the backward pass.