tracel-ai/burn

Autodiff: checkpointing strategy

louisfd opened this issue · 4 comments

In autodiff, we should have a checkpointing strategy to reduce memory consumption (see for instance https://www-sop.inria.fr/tropics/papers/DauvergneHascoet06.pdf).

Currently, for most operations run in the forward pass, a state is saved for the backward pass. This state often consists of a few tensors, so these states accumulate and use a lot of memory.

One way to use less memory during the backward pass would be, instead of keeping the state in memory, to recompute the operation's forward pass to re-obtain the state just before computing its backward pass. This means more computation, but less memory consumption.
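To make the store-versus-recompute idea concrete, here is a minimal Python sketch of a single recorded operation. The `Node` class, the softplus op and the `checkpoint` flag are hypothetical illustrations, not Burn's API. Note that a checkpointed node still keeps its inputs, since it needs them to rerun the forward pass; what it drops is the intermediate state (here `sigmoid(x)`).

```python
import math


def softplus_forward(x):
    """Return (output, state); the state sigmoid(x) is what the backward pass needs."""
    out = [math.log1p(math.exp(v)) for v in x]
    state = [1.0 / (1.0 + math.exp(-v)) for v in x]
    return out, state


def softplus_backward(state, grad_out):
    """d softplus(x) / dx = sigmoid(x), so scale the incoming gradient by the state."""
    return [s * g for s, g in zip(state, grad_out)]


class Node:
    """One recorded forward operation (hypothetical, for illustration only)."""

    def __init__(self, forward, backward, inputs, checkpoint):
        self.forward = forward
        self.backward = backward
        self.inputs = inputs  # kept in both modes: needed to recompute
        self.forward_calls = 0
        self.output, state = self._run_forward()
        # The storing strategy keeps the state; checkpointing drops it right away.
        self.saved_state = None if checkpoint else state

    def _run_forward(self):
        self.forward_calls += 1
        return self.forward(*self.inputs)

    def grad(self, grad_out):
        state = self.saved_state
        if state is None:
            # Checkpointing: re-run the forward pass just before the backward pass.
            _, state = self._run_forward()
        return self.backward(state, grad_out)


x = [0.0, 1.0, -2.0]
ones = [1.0, 1.0, 1.0]
stored = Node(softplus_forward, softplus_backward, (x,), checkpoint=False)
recomputed = Node(softplus_forward, softplus_backward, (x,), checkpoint=True)
grad_stored = stored.grad(ones)          # uses the saved state
grad_recomputed = recomputed.grad(ones)  # reruns softplus_forward first
```

Both nodes produce the same gradient; the checkpointed one holds no state between the passes but runs its forward function twice.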

Recomputing instead of storing creates a tradeoff between compute and memory. Some operations, like matrix multiplication, are "compute-bound": the bottleneck is generally the arithmetic itself. Others, such as element-wise multiplication, are "memory-bound": the computation is so simple that moving the data is the bottleneck.
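A rough way to quantify the compute-bound/memory-bound distinction is arithmetic intensity, i.e. FLOPs performed per byte moved. The sketch below uses idealized counts (f32 elements, each input read once and each output written once, caches ignored), so the numbers are illustrative rather than measured on any backend.

```python
def matmul_intensity(n, bytes_per_elem=4):
    """Idealized FLOPs per byte for multiplying two n x n matrices."""
    flops = 2 * n**3                          # n^2 outputs, n multiply-adds each
    bytes_moved = 3 * n**2 * bytes_per_elem   # read A and B once, write C once
    return flops / bytes_moved


def elementwise_mul_intensity(n, bytes_per_elem=4):
    """Idealized FLOPs per byte for multiplying two n x n tensors element-wise."""
    flops = n**2                              # one multiply per element
    bytes_moved = 3 * n**2 * bytes_per_elem   # read both inputs, write the output
    return flops / bytes_moved


for n in (64, 1024):
    print(n, matmul_intensity(n), elementwise_mul_intensity(n))
```

Under these assumptions, matmul intensity grows as n/6 FLOPs per byte while element-wise multiplication stays at 1/12 whatever the size, so on hardware that can perform many FLOPs per byte loaded, the element-wise op spends most of its time waiting on memory.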

For compute-bound operations, it is better to keep the state than to recompute. But for memory-bound operations, we would benefit from recomputing.

Also, if many operations are tagged as memory-bound, this will greatly help kernel fusion with Burn-Fusion, which will be able to fuse kernels transparently during the backward pass.

The current strategy, where every state is saved, would simply become a special case of the new strategy in which every operation is considered compute-bound.
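One possible shape for such a policy is a per-operation classification that the backward-pass bookkeeping consults. In this hypothetical Python sketch, the op names, the `OP_BOUND` table and the helper functions are assumptions for illustration; a real classification would depend on the backend. Treating every op as compute-bound reproduces the current save-everything behavior.

```python
from enum import Enum


class Bound(Enum):
    COMPUTE = "compute"  # recomputing is expensive: keep the state
    MEMORY = "memory"    # recomputing is cheap: drop the state


# Hypothetical per-op classification, for illustration only.
OP_BOUND = {
    "matmul": Bound.COMPUTE,
    "conv2d": Bound.COMPUTE,
    "add": Bound.MEMORY,
    "mul": Bound.MEMORY,
    "exp": Bound.MEMORY,
}


def balanced_policy(op):
    """Use the table; ops it does not know are conservatively treated as compute-bound."""
    return OP_BOUND.get(op, Bound.COMPUTE)


def save_everything_policy(op):
    """The current strategy: every op is compute-bound, so every state is saved."""
    return Bound.COMPUTE


def plan(ops, policy):
    """For each recorded op, decide whether its backward state is stored or recomputed."""
    return [(op, "recompute" if policy(op) is Bound.MEMORY else "store") for op in ops]


graph = ["matmul", "add", "exp", "matmul", "mul"]
print(plan(graph, balanced_policy))
print(plan(graph, save_everything_policy))
```

Defaulting unknown ops to compute-bound keeps the sketch safe: anything not explicitly classified behaves exactly as it does today.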

Hi, I'm wondering how a toggle for this strategy should be added to Burn's autodiff graph. Tapenade, the AD tool behind that paper, seems to have an IR and a pair of directives to control which snippets are checkpointed.

Hi @AuruTus
To be honest, I haven't read the paper. I just thought its figures explain the concept of checkpointing well; I'm not sure whether we should follow what they did or design our own checkpointing strategy.
We plan on tackling this issue in early 2024; for now I haven't given it more thought than what is written above!

If you are discussing checkpointing strategies, it may be worth considering JAX's approach to AD, explained in "You Only Linearize Once", since it sheds some light on what is going on with checkpointing. The idea is to break the vector-Jacobian product into two pieces. I'm going to use Haskell-style type signatures, where -o is linear implication and ! means a variable may be reused (i.e. it is a smooth argument).

  1. A Jacobian-vector product jvp: (!a -o b) -> ((!a,a) -o b) (this is quite easy to implement). If you set this up correctly, you're guaranteed to get nested derivatives that work out of the box.
  2. A linear transpose t: (a -o b) -> (b -o a). Since a linear map is representable as a matrix, this is the same thing as multiplying by the transpose of that matrix. This is generally where the 'tape' comes in handy. Using currying, this can become t_contextual: ((!a, b) -o c) -> ((!a, c) -o b), but realistically you probably want t_contextual as your primitive (that's what we did here).
  3. You can now define vjp = t_contextual . jvp, or vjp f = t_contextual (jvp f). If you set it up this way you're guaranteed to have nested forward/reverse derivatives that work correctly.
  4. (A nice little bonus: if you already have a vjp operator, you can use it to define the linear transpose. So you really only need to write vjps for linear functions, and jvps for all functions.)
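The decomposition in the list above can be illustrated with a small pure-Python sketch for one toy function f(x0, x1) = (x0·x1, sin x0). The function `f`, its hand-written `jvp_f` and the `transpose` helper are hypothetical. The transpose here naively materializes the Jacobian by probing with basis vectors, which costs one call per input dimension and is a swapped-in technique that only illustrates the algebra; JAX instead transposes the linear program itself.

```python
import math


def f(x):
    """Toy nonlinear function R^2 -> R^2: f(x0, x1) = (x0 * x1, sin x0)."""
    return [x[0] * x[1], math.sin(x[0])]


def jvp_f(x):
    """Linearize f at the primal point x: return the linear map v -> J(x) v."""
    def lin(v):
        return [x[1] * v[0] + x[0] * v[1], math.cos(x[0]) * v[0]]
    return lin


def transpose(lin, n_in):
    """Transpose a linear map R^n_in -> R^m by probing it with basis vectors.

    cols[j] = lin(e_j) is column j of the matrix M, so (M^T ct)[j] = dot(cols[j], ct).
    """
    cols = [lin([1.0 if i == j else 0.0 for i in range(n_in)]) for j in range(n_in)]

    def lin_t(ct):
        return [sum(c * t for c, t in zip(cols[j], ct)) for j in range(n_in)]
    return lin_t


def vjp_f(x):
    """vjp = transpose . jvp: reverse mode from forward mode plus a linear transpose."""
    return transpose(jvp_f(x), n_in=len(x))


x = [2.0, 3.0]
pullback = vjp_f(x)
print(pullback([1.0, 0.0]))  # first row of the Jacobian of f at x
```

A quick sanity check for any such transpose is the adjoint identity dot(ct, J v) == dot(J^T ct, v) for arbitrary v and ct.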

Now, I said that you want to treat t_contextual as a primitive, but that's really for the user: that's where all of the logic around checkpointing lives, as you can see in section 6 of 'You Only Linearize Once'. I've mostly changed my focus to animation/simulation these days, but I do know a few Math/CS profs in Canada who are still working on these things.
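Because a contextual transpose receives the primal input (the `!a` argument), it is a natural place to choose between reading residuals saved during linearization and recomputing them from that input. The following hypothetical Python sketch uses the toy function f(x0, x1) = (x0·x1, sin x0) with a hand-written transpose; `linearize` is an illustrative name, and this `t_contextual` is only a toy stand-in for the primitive described above, not JAX's or Burn's code.

```python
import math


def linearize(x):
    """Forward pass of f(x0, x1) = (x0 * x1, sin x0), also returning the residuals
    that the transpose of its linearization needs."""
    y = [x[0] * x[1], math.sin(x[0])]
    residuals = (x[0], x[1], math.cos(x[0]))
    return y, residuals


def t_contextual(x, ct, residuals=None):
    """Apply J(x)^T to the cotangent ct.

    The primal x is an argument, so if no residuals were kept (checkpointing),
    they can be rebuilt from x right here.
    """
    if residuals is None:
        _, residuals = linearize(x)  # recompute instead of reading a stored tape entry
    x0, x1, cos_x0 = residuals
    # J = [[x1, x0], [cos x0, 0]], hence J^T ct = [x1*ct0 + cos x0*ct1, x0*ct0]
    return [x1 * ct[0] + cos_x0 * ct[1], x0 * ct[0]]


x = [2.0, 3.0]
_, saved = linearize(x)
from_tape = t_contextual(x, [1.0, 1.0], saved)  # residuals stored
recomputed = t_contextual(x, [1.0, 1.0])        # residuals recomputed from x
```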

Solved in #1358