patrick-kidger/signatory

Why not use autograd directly?

sdogsq opened this issue · 2 comments

Hi, Patrick,

Thanks a lot for your nice work and detailed code comments! I have a very naive question: why not use PyTorch autograd directly for the backward pass? As far as I can see, the tensor operations in the forward pass are all things like addcmul. I can think of a couple of simple reasons, but I am not sure whether they are correct.

  1. A custom backward pass might be faster than autograd.
  2. It can avoid storing intermediate variables and thus save memory.

I am a newbie when it comes to custom PyTorch functions. I would appreciate it if you could kindly share your thoughts. Thank you again!

The main reason is that it is possible to efficiently reconstruct the forward pass during the backward pass. Doing so means we don't need to hold intermediate values in memory.

This is actually a very special case of the "continuous adjoint method" sometimes used in differential equations (e.g. as popularised for neural ODEs; also see Chapter 5 of https://arxiv.org/abs/2202.02435). In our case, though, the piecewise linear structure means we can recompute things without suffering any numerical truncation error (only floating point error, which usually isn't that bad).
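To make the reconstruction idea concrete, here is a minimal pure-Python sketch. It is not Signatory's actual implementation: it assumes a signature truncated at depth 2 and updated one linear piece at a time via Chen's identity, and the names `chen_step`, `signature_forward` and `reconstruct_intermediates` are hypothetical, made up for this illustration. The point is that each update is undone by applying the same update with the negated increment, so the intermediate values can be rebuilt from the final one instead of being stored.

```python
def chen_step(sig, delta):
    """Multiply a depth-2 signature by exp(delta), i.e. append one linear piece.

    sig is a pair (level1, level2): a vector and a square matrix (nested lists).
    """
    level1, level2 = sig
    d = len(delta)
    new1 = [level1[i] + delta[i] for i in range(d)]
    new2 = [
        [level2[i][j] + level1[i] * delta[j] + 0.5 * delta[i] * delta[j]
         for j in range(d)]
        for i in range(d)
    ]
    return new1, new2


def signature_forward(path):
    """Depth-2 signature of a piecewise linear path, keeping only the final value."""
    d = len(path[0])
    sig = ([0.0] * d, [[0.0] * d for _ in range(d)])
    for prev, cur in zip(path, path[1:]):
        delta = [c - p for p, c in zip(prev, cur)]
        sig = chen_step(sig, delta)
    return sig


def reconstruct_intermediates(path, final_sig):
    """Walk the path backwards, rebuilding each earlier signature from the later one.

    Multiplying by exp(-delta) undoes multiplying by exp(delta), so no
    intermediate value from the forward pass needs to be kept in memory.
    Returns the signatures ordered from the full path down to the empty one.
    """
    sigs = [final_sig]
    sig = final_sig
    for prev, cur in zip(reversed(path[:-1]), reversed(path[1:])):
        neg_delta = [p - c for p, c in zip(prev, cur)]  # negated increment
        sig = chen_step(sig, neg_delta)
        sigs.append(sig)
    return sigs


# Usage: the last reconstructed value should be the trivial (all-zero) signature,
# up to floating point error.
path = [[0.0, 0.0], [1.0, 2.0], [3.0, 1.0], [2.0, 4.0]]
final = signature_forward(path)
recovered = reconstruct_intermediates(path, final)
print(recovered[-1])
```

A real backward pass would compute gradients alongside this reverse walk, using each reconstructed value as soon as it is available and then discarding it.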

Really insightful views! I'll read this paper carefully.

Cheers!