`ArcTape`: Add another `Tape` impl?
emchristiansen opened this issue · 4 comments
What do you think of adding another `Tape` impl, one that basically just wraps a global shared `OwnedTape` inside an `Arc<Mutex<_>>` (or whatever the best wrapper type would be)?
I'm trying to avoid the tape-tracking headache described in this thread, and it would be nice if each operation just automatically updated the global gradient tape without any work from the user.
Basically, I have several outputs, and I'm afraid I could introduce a bug by calling `backward` on an output that doesn't own the full tape.
In other words, IIUC the current design is only guaranteed to do the correct thing when all the inputs converge to a single output (when the op dependency graph is a tree).
But my op dependency graph is a DAG with N outputs, so N - 1 of my outputs will have incomplete tapes.
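A rough sketch of the proposal, assuming a toy scalar autodiff rather than dfdx's real tensor types. `ArcTape`, `Var`, `leaf`, `mul`, `add`, and `backward` are hypothetical names. Every op appends to one tape behind `Arc<Mutex<_>>`, so `backward` can start from any of several outputs and still see the whole DAG.

```rust
use std::sync::{Arc, Mutex};

/// One recorded op: (parent node id, local partial derivative) pairs.
struct Node {
    parents: Vec<(usize, f64)>,
}

/// Hypothetical shared tape: every op appends to the same Vec behind Arc<Mutex<_>>.
#[derive(Clone, Default)]
struct ArcTape(Arc<Mutex<Vec<Node>>>);

/// A scalar value that carries a handle to (not ownership of) the shared tape.
struct Var {
    id: usize,
    value: f64,
    tape: ArcTape,
}

impl ArcTape {
    /// Append a node and return its id; parents always have smaller ids.
    fn push(&self, parents: Vec<(usize, f64)>) -> usize {
        let mut nodes = self.0.lock().unwrap();
        nodes.push(Node { parents });
        nodes.len() - 1
    }

    /// Create an input or parameter with no parents.
    fn leaf(&self, value: f64) -> Var {
        Var { id: self.push(Vec::new()), value, tape: self.clone() }
    }

    /// Reverse sweep from any output: returns d(output)/d(node) for every node.
    fn backward(&self, output: &Var) -> Vec<f64> {
        let nodes = self.0.lock().unwrap();
        let mut grads = vec![0.0; nodes.len()];
        grads[output.id] = 1.0;
        for i in (0..nodes.len()).rev() {
            let g = grads[i];
            for &(p, d) in &nodes[i].parents {
                grads[p] += g * d;
            }
        }
        grads
    }
}

impl Var {
    fn mul(&self, other: &Var) -> Var {
        let id = self.tape.push(vec![(self.id, other.value), (other.id, self.value)]);
        Var { id, value: self.value * other.value, tape: self.tape.clone() }
    }

    fn add(&self, other: &Var) -> Var {
        let id = self.tape.push(vec![(self.id, 1.0), (other.id, 1.0)]);
        Var { id, value: self.value + other.value, tape: self.tape.clone() }
    }
}

fn main() {
    let tape = ArcTape::default();
    let x = tape.leaf(2.0);
    let y = tape.leaf(3.0);
    // Two outputs share the same inputs, so the dependency graph is a DAG.
    let f = x.mul(&y); // f = x * y
    let g = x.add(&y); // g = x + y
    let (df, dg) = (tape.backward(&f), tape.backward(&g));
    println!("df/dx = {}, dg/dx = {}", df[x.id], dg[x.id]);
}
```

Nodes are appended in creation order, so walking the tape in reverse is already a valid topological order for the sweep; no separate graph sort is needed.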
I definitely see the headache of tracking the tapes. I think `Arc<Mutex<_>>` makes sense, actually, and might be fairly straightforward. I'll put up a PR in a sec.
Out of curiosity, how do you call `backward` on the N separate outputs? If you combine them all into one loss value, then the final tape will have the correct values.
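Combining works because the gradient is linear. Assuming the N outputs $y_1, \dots, y_N$ all depend on the parameters $\theta$ and are combined by a plain sum:

```latex
L = \sum_{i=1}^{N} y_i(\theta)
\quad\Longrightarrow\quad
\frac{\partial L}{\partial \theta} = \sum_{i=1}^{N} \frac{\partial y_i}{\partial \theta}
```

For example, with $y_1 = xy$ and $y_2 = x + y$ at $x = 2, y = 3$, one backward pass on $y_1 + y_2$ gives $\partial L / \partial x = y + 1 = 4$. That is the sum of the per-output gradients (3 and 1), not each one separately.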
I'm taking a look at your branch now.
FYI, it looks like you duplicated the merge logic; maybe just delegate to the existing logic?
Right now I'm just calling `backward` on one output per forward pass, as part of my tests / debugging to ensure that the gradients are being calculated properly.
It would be nice to be able to call `backward` multiple times (or maybe once on a `Vec` of outputs), so I could get the full set of gradients.
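Both options can be sketched as a seeded reverse sweep (a vector-Jacobian product). This uses a toy, hypothetical `Tape` over scalars with a made-up `backward_many` method, not dfdx's API. Seeding a single output with 1.0 gives that output's gradients; seeding all of them at once gives the gradient of their sum in one pass.

```rust
/// Toy Wengert list over scalars: node i stores (parent id, local partial) pairs.
struct Tape {
    parents: Vec<Vec<(usize, f64)>>,
    values: Vec<f64>,
}

impl Tape {
    fn new() -> Self {
        Tape { parents: Vec::new(), values: Vec::new() }
    }

    fn push(&mut self, value: f64, parents: Vec<(usize, f64)>) -> usize {
        self.parents.push(parents);
        self.values.push(value);
        self.values.len() - 1
    }

    fn leaf(&mut self, value: f64) -> usize {
        self.push(value, Vec::new())
    }

    fn mul(&mut self, a: usize, b: usize) -> usize {
        let (va, vb) = (self.values[a], self.values[b]);
        self.push(va * vb, vec![(a, vb), (b, va)])
    }

    fn add(&mut self, a: usize, b: usize) -> usize {
        let (va, vb) = (self.values[a], self.values[b]);
        self.push(va + vb, vec![(a, 1.0), (b, 1.0)])
    }

    /// One reverse sweep seeded at several outputs at once.
    /// Seeding every output with 1.0 gives the gradient of their sum.
    fn backward_many(&self, seeds: &[(usize, f64)]) -> Vec<f64> {
        let mut grads = vec![0.0; self.values.len()];
        for &(out, seed) in seeds {
            grads[out] += seed;
        }
        for i in (0..self.values.len()).rev() {
            let g = grads[i];
            for &(p, d) in &self.parents[i] {
                grads[p] += g * d;
            }
        }
        grads
    }
}

fn main() {
    let mut t = Tape::new();
    let (x, y) = (t.leaf(2.0), t.leaf(3.0));
    let f = t.mul(x, y); // f = x * y
    let g = t.add(x, y); // g = x + y
    // Full set of per-output gradients: one seeded sweep per output.
    let per_output: Vec<Vec<f64>> = [f, g].iter().map(|&o| t.backward_many(&[(o, 1.0)])).collect();
    // Single sweep over the whole Vec of outputs: gradient of f + g.
    let combined = t.backward_many(&[(f, 1.0), (g, 1.0)]);
    println!("{} {} {}", per_output[0][x], per_output[1][x], combined[x]);
}
```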
Looks good!
Here's the example using "ArcTape", and here's one using it with the module API.
TBH there are still some things that don't make sense to me about the Module API:
- I still don't understand why you tag the model inputs with the tape and not the model parameters, assuming all you care about is updating the parameters. If you tagged the parameters (and not the inputs), then the tape DAG would be the minimal DAG necessary to compute the parameter gradients, and it seems to me that's what you want?
- Also, it seems incorrect not to tag the model parameters. E.g., what if my model takes no inputs and returns a learned value (i.e., it is a learned constant fn)? Without inputs, you wouldn't be able to compute gradients in this scheme.
- I don't understand what `model.alloc_grads` is supposed to do. Is it just doing memory pre-allocation for the model param gradients? If so, that seems like a small win, because those will be a small fraction of the gradients in the tape DAG.
The tape is different from "this parameter requires gradients". The tape is just a gradients holder. What dictates whether a parameter will have gradients is whether it is used in an operation with a tape.
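A toy illustration of "the tape is just a gradients holder", assuming scalar values and hypothetical `Tape`, `Tensor`, `Param`, and `mul_param` types rather than dfdx's. The parameter never owns a tape; it only receives a gradient when it meets an input that carries one.

```rust
use std::collections::HashMap;

/// Hypothetical tape: only a holder for gradients, keyed by parameter name.
#[derive(Default)]
struct Tape {
    grads: HashMap<&'static str, f64>,
}

/// A scalar that may or may not carry a tape; `None` means "not traced".
struct Tensor {
    value: f64,
    tape: Option<Tape>,
}

/// A parameter: just a name and a value. It never owns a tape itself.
struct Param {
    name: &'static str,
    value: f64,
}

/// y = x * w. If `x` carries a tape, the tape moves into `y` and records
/// dy/dw = x. If it does not, nothing is recorded and `w` gets no gradient.
fn mul_param(x: Tensor, w: &Param) -> Tensor {
    let x_value = x.value;
    let tape = x.tape.map(|mut t| {
        *t.grads.entry(w.name).or_insert(0.0) += x_value;
        t
    });
    Tensor { value: x_value * w.value, tape }
}

fn main() {
    let w = Param { name: "w", value: 0.5 };
    let traced = mul_param(Tensor { value: 4.0, tape: Some(Tape::default()) }, &w);
    let untraced = mul_param(Tensor { value: 4.0, tape: None }, &w);
    let dw = traced.tape.as_ref().and_then(|t| t.grads.get("w").copied());
    println!("traced dy/dw = {:?}, untraced has tape: {}", dw, untraced.tape.is_some());
}
```

The same `w` is used in both calls; only the call whose input carries a tape produces a gradient for it.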
> what if my model takes no inputs and returns a learned value (i.e. it is a learned constant fn)? Without inputs you wouldn't be able to compute gradients in this scheme.
It would require a different setup than you would use in PyTorch, but you could always make the `Module` implementation of this return a tape that could be used when computing losses.
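One way to read that suggestion, as a hedged sketch with hypothetical names (`LearnedConstant`, `forward_with_tape`, `squared_error`): since there is no input to carry the tape, the caller passes a fresh tape in and the module returns it attached to its output, so the loss can still be traced back to the parameter. The toy tape only handles a single chain of scalar ops, where the gradient is the product of the local derivatives.

```rust
/// Hypothetical tape for a single chain of scalar ops: it stores each op's
/// local derivative with respect to the previous value in the chain.
#[derive(Default)]
struct Tape {
    chain: Vec<f64>,
}

impl Tape {
    /// For a single chain, d(loss)/d(parameter) is the product of local derivatives.
    fn backward(&self) -> f64 {
        self.chain.iter().product()
    }
}

/// A value together with the tape that produced it.
struct Traced {
    value: f64,
    tape: Tape,
}

/// A "learned constant fn": no inputs, one parameter `c`.
struct LearnedConstant {
    c: f64,
}

impl LearnedConstant {
    /// With no input to carry a tape, the caller passes one in and the
    /// module returns it attached to the output.
    fn forward_with_tape(&self, mut tape: Tape) -> Traced {
        tape.chain.push(1.0); // d(out)/dc = 1
        Traced { value: self.c, tape }
    }
}

fn squared_error(pred: Traced, target: f64) -> Traced {
    let diff = pred.value - target;
    let mut tape = pred.tape;
    tape.chain.push(2.0 * diff); // d(loss)/d(pred)
    Traced { value: diff * diff, tape }
}

fn main() {
    let model = LearnedConstant { c: 3.0 };
    let loss = squared_error(model.forward_with_tape(Tape::default()), 1.0);
    println!("loss = {}, dloss/dc = {}", loss.value, loss.tape.backward());
}
```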
> I don't understand what `model.alloc_grads` is supposed to do. Is it just doing memory pre-allocation for the model param gradients? If so that seems like a small win, because those will be a small fraction of the gradients in the tape DAG.
Yeah, it's pre-allocation, so you don't have to allocate & deallocate gradients for model parameters every iteration. It's actually a big win, because the model parameters are usually much bigger than the intermediate tensors. E.g., a `Linear<n, n>` layer will have gradients of size n^2, but the input & output will be of size n.
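A sketch of the pre-allocation idea, assuming a `Linear<n, n>` layer with an n×n weight and an n-element bias. `LinearGrads`, `alloc`, and `zero` are hypothetical and not dfdx's actual `alloc_grads` implementation. The buffers are allocated once and zeroed in place each iteration, so the same heap allocation is reused.

```rust
/// Hypothetical gradient buffers for one Linear<n, n> layer.
struct LinearGrads {
    weight: Vec<f32>, // n * n entries
    bias: Vec<f32>,   // n entries
}

impl LinearGrads {
    /// Allocate once, up front.
    fn alloc(n: usize) -> Self {
        LinearGrads { weight: vec![0.0; n * n], bias: vec![0.0; n] }
    }

    /// Reset without freeing: keeps the existing heap allocation.
    fn zero(&mut self) {
        self.weight.fill(0.0);
        self.bias.fill(0.0);
    }
}

fn main() {
    let n = 512;
    let mut grads = LinearGrads::alloc(n);
    let ptr_before = grads.weight.as_ptr();
    for _step in 0..3 {
        grads.zero();
        // Stand-in for a backward pass accumulating into the buffers.
        grads.weight[0] += 1.0;
    }
    assert_eq!(grads.weight.as_ptr(), ptr_before); // same buffer reused
    println!(
        "weight grads: {}, bias grads: {}, input/output: {} each",
        grads.weight.len(),
        grads.bias.len(),
        n
    );
}
```

With n = 512 that is 262,144 weight-gradient entries plus 512 for the bias, versus 512 entries for each input or output vector per sample.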