Gradient checkpointing for weight noise etc in PyTorch
albertz opened this issue · 7 comments
For variational noise (weight noise), weight dropout, or similar things, it would be very helpful to have gradient checkpointing that avoids storing the weights twice in memory: once with the noise and once without.
(Also see RF weight dropout discussion: #1518)
See this current code as an example for variational noise in our TF code:
```python
if param_variational_noise and param.dtype.is_floating and isinstance(param, tf.Variable):
    with default_control_flow_ctx():  # make independent from loop/cond
        with reuse_name_scope_of_tensor(param, postfix="_variational_noise", add_tensor_name=True):

            def _apply_var_noise():
                rnd_state = tf_util.StatelessRandomSeed.create(shape=tf_util.get_shape(param))
                with gradient_checkpoint_scope():
                    noise = rnd_state.normal(stddev=param_variational_noise, dtype=param.dtype.base_dtype)
                    return param + noise

            param = self.network.cond_on_train(
                fn_train=_apply_var_noise,
                fn_eval=lambda: param,
            )
```
Specifically, check the code of `gradient_checkpoint_scope` and `prepare_gradient_checkpointing`.
There are also a couple of questions regarding how to implement dropout using gradient checkpointing and a stateless random number generator in RF and/or PT. E.g.:
- How to get stateless random numbers in PT? I think the best (only?) way is to create a new `torch.Generator`? I opened pytorch/pytorch#128126 about this.
- Define a stateless random number API for RF? (Maybe also check JAX.)
- How to do the gradient checkpointing in RF? A new API for that? Currently, `rf.dropout` is directly implemented using other RF methods, and this is nice and clean (you directly see the actual dropout implementation; it's not hidden away), and I think it should stay that way. But this means we need to somehow extend this RF implementation by specifying the gradient checkpointing logic. And the gradient checkpointing should be optional.
  - See our own `gradient_checkpoint_scope` and `prepare_gradient_checkpointing`, and how we implemented variational noise.
  - See `torch.utils.checkpoint` and `tf.recompute_grad` for related functionality. Although I'm not sure that this API really allows for what we want (that the output is not stored in memory). But we could look at the internals of how it is done (specifically in PyTorch), and maybe reuse that functionality to implement what we want.
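The `torch.Generator` route from the first bullet could look roughly like this. It is a minimal sketch, assuming that "an explicit seed per call" is an acceptable notion of stateless; the helper `stateless_normal` is hypothetical and not an existing PyTorch or RF API. A freshly seeded generator yields the same noise every time, so the noise could be recomputed for backprop instead of being stored, and the global RNG is left untouched.

```python
import torch


def stateless_normal(shape, seed: int, std: float = 1.0, dtype=torch.float32, device="cpu"):
    """Hypothetical helper: N(0, std^2) noise that depends only on `seed`.

    A fresh torch.Generator is created per call, so the global RNG state
    is neither read nor advanced.
    """
    gen = torch.Generator(device=device)
    gen.manual_seed(seed)
    return torch.randn(shape, generator=gen, dtype=dtype, device=device) * std


# Same seed -> identical noise, so it can be regenerated instead of stored.
n1 = stateless_normal((3, 4), seed=42, std=0.1)
n2 = stateless_normal((3, 4), seed=42, std=0.1)
```

Creating a generator per call has some overhead; for many small tensors one might instead derive per-tensor seeds from a single base seed.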
I'm reading the code of `_checkpoint_without_reentrant_generator`. It looks like it uses a couple of techniques that are very relevant for what we need:
- There is logic for `preserve_rng_state`. It gets the random state (CPU and device RNGs) and then uses `torch.random.fork_rng` ("Forks the RNG, so that when you return, the RNG is reset to the state that it was previously in.").
- Further, it uses `torch.autograd.graph.saved_tensors_hooks` (tutorial). That is the crucial functionality for controlling which tensors are stored for backprop, and how to reconstruct them.
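The `preserve_rng_state` idea can be illustrated in isolation. This is a minimal sketch, not the actual `checkpoint` code: the function `noisy` is made up for illustration, and only the CPU RNG is handled (hence `devices=[]`). The RNG state is saved before the forward computation, and the computation is later replayed under `torch.random.fork_rng` with that state restored, so the replay draws the same random numbers while the outer RNG stream continues as if the replay never happened.

```python
import torch


def noisy(x: torch.Tensor, std: float) -> torch.Tensor:
    # Draws from the global CPU RNG.
    return x + torch.randn_like(x) * std


x = torch.zeros(5)

# Like preserve_rng_state: remember the RNG state before the forward pass.
state_before_forward = torch.get_rng_state()
y_forward = noisy(x, 0.1)
state_after_forward = torch.get_rng_state()

# Later (e.g. during backward): replay with the saved state. fork_rng restores
# the outer RNG state on exit, so the rest of the program is not affected.
with torch.random.fork_rng(devices=[]):
    torch.set_rng_state(state_before_forward)
    y_replayed = noisy(x, 0.1)
```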
The current `torch.utils.checkpoint` doesn't quite do what we want, as it does not control whether further ops after the `checkpoint` store the tensor (which is what we want to avoid). I think one solution to our problem is to just defer the exit of the `saved_tensors_hooks` to some later point (when the tensor is not used anymore), because having `saved_tensors_hooks` active on the whole model might cause some slowdown. I'm not exactly sure how to detect when it is not used anymore, though. (Can I hook the tensor `__del__`?) Also, the whole logic around using `saved_tensors_hooks` in `checkpoint` seems quite involved and tricky (although the basic idea behind it is simple).
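The basic idea can be shown on a small `y = (a + b) * c` example. This is a minimal sketch, not how `checkpoint` actually implements it: the pack hook replaces `x = a + b` by a recompute recipe, so the autograd graph holds no reference to `x`, and the unpack hook recomputes `x` during backward. Identifying `x` by its `data_ptr()` is a simplifying assumption (it breaks if another saved tensor shares `x`'s memory), and `_Recompute` and `mul_with_recomputed_sum` are hypothetical names.

```python
import torch
from torch.autograd.graph import saved_tensors_hooks


class _Recompute:
    """Marker stored in the autograd graph instead of the tensor itself."""

    def __init__(self, fn):
        self.fn = fn


def mul_with_recomputed_sum(a, b, c):
    """y = (a + b) * c, where autograd keeps only a recipe for x = a + b."""
    x = a + b
    x_ptr = x.data_ptr()  # assumption: no other saved tensor shares x's memory
    recompute_calls = []

    def pack(t):
        if t.data_ptr() == x_ptr:
            return _Recompute(lambda: a + b)  # do not keep x alive
        return t  # everything else is stored as usual

    def unpack(packed):
        if isinstance(packed, _Recompute):
            recompute_calls.append(1)
            return packed.fn()  # recompute x during backward
        return packed

    with saved_tensors_hooks(pack, unpack):
        y = x * c  # mul saves x (for grad of c) and c (for grad of x)
    return y, recompute_calls
```

Note that the hooks here are exited right after `x * c`; any later op that saves `x` outside this scope would store it normally, which is exactly the limitation discussed above.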
> Can I hook the tensor `__del__`?
One solution to this (via):
```python
import weakref

import torch


class _TensorHandle:
    def __init__(self, tensor):
        self.tensor_ref = weakref.ref(tensor)  # not really used here...

    def __del__(self):
        print("out-of-scope")


def func():
    x = torch.zeros(2, 3)
    x._my_handle = _TensorHandle(x)
    del x  # now (or later via GC) the _TensorHandle.__del__ will be called
    ...
```
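A standard-library alternative (not from the discussion above) is `weakref.finalize`, which runs a callback once the object is collected, without attaching a helper object to the tensor. A minimal sketch; `track_out_of_scope` is a hypothetical name. As with the `__del__` approach, the callback fires only once nothing else (e.g. the autograd graph) keeps the tensor alive, and it may run in whichever thread drops the last reference.

```python
import weakref

import torch


def track_out_of_scope(tensor: torch.Tensor, events: list) -> weakref.finalize:
    """Hypothetical helper: append a marker to `events` once `tensor` is collected."""
    return weakref.finalize(tensor, events.append, "out-of-scope")


events = []


def func():
    x = torch.zeros(2, 3)
    track_out_of_scope(x, events)
    del x  # last reference gone: the finalizer runs now (or later via GC)


func()
```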
I wonder a bit about `torch.utils.checkpoint`: Why do I need to specify the inputs for the function? What happens when I forget some? E.g., the example speaks about an LSTM and passing `(activation, hidden)` to it. Another input that is missing here is the parameters. What about them? Do they not need to be specified? What happens when you specify them? What happens when you don't specify them but still use them? If this does not matter, why specify the inputs at all?
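One way to probe the question about parameters that are not passed explicitly is a small experiment. This sketch assumes the non-reentrant variant (`use_reentrant=False`); the reentrant variant behaves differently in some respects (see the `torch.utils.checkpoint` docs). The weight `w` is only captured from the enclosing scope, and the gradients are compared with those of a plain, non-checkpointed run.

```python
import torch
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
w = torch.randn(3, 3, requires_grad=True)  # "parameter": used, but not passed to checkpoint
x = torch.randn(2, 3, requires_grad=True)


def block(inp: torch.Tensor) -> torch.Tensor:
    # w is captured from the enclosing scope, not given as an explicit input.
    return torch.tanh(inp @ w).sum()


# Checkpointed run: only x is passed explicitly.
checkpoint(block, x, use_reentrant=False).backward()
grad_w_ckpt, grad_x_ckpt = w.grad.clone(), x.grad.clone()

# Reference run without checkpointing.
w.grad, x.grad = None, None
block(x).backward()
```

If both gradient pairs match, the captured parameter received its gradient even though it was never listed as an input.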
I also wonder how `torch.utils.checkpoint` is supposed to be used in practice. The outputs of the checkpointed function might be stored elsewhere for backprop anyway, as they are no longer under the `checkpoint`. E.g., consider the example:
```python
y = (a + b) * c
```
And now `checkpoint` around the `a + b`:
```python
a = ...
b = ...
c = ...
x = checkpoint(lambda _a, _b: _a + _b, a, b)
y = x * c
```
It doesn't really make sense to recompute `x = a + b` here, because the result is stored anyway for the backprop of `x * c`. Or is the intended usage for this example actually like this:
```python
a = ...
b = ...
c = ...
y = checkpoint(lambda _a, _b, _c: (_a + _b) * _c, a, b, c)
```
Or if explicit args are not necessary (see my previous comment):
```python
a = ...
b = ...
c = ...
y = checkpoint(lambda: (a + b) * c)
```
Edit: I also asked the question in the PyTorch discussion forum.
I got some quite useful answers which clarify my questions.
I have a very hacky idea: I could somehow delay the `_checkpoint_hook.__exit__` to some later point, when `y._my_del_handler` or so goes out of scope...
Edit: OK, I now have an idea for a complete solution (see the discussion thread) to implement something like our TF `gradient_checkpoint_scope`.
With `gradient_checkpoint_scope`, the example would look like this:
```python
a = ...
b = ...
c = ...
with gradient_checkpoint_scope():
    x = a + b
y = x * c
```
It would not store `x` for backprop, but would recompute `x = a + b` for backprop. Everything else (`a`, `b`, `c`, `y`) will be stored (if needed).
So, how do we implement the same logic in PyTorch, i.e., the same `gradient_checkpoint_scope` API?
Very similar to soulitzer's example of a new way to do automatic checkpointing, I could use `__torch_dispatch__` to record the computation graph and also record all tensors that are created there. If I used `saved_tensors_hooks` on the whole model, I could check whenever some of the saved tensors come from such a checkpoint scope, and if so, recompute them.
If I don't want to have `saved_tensors_hooks` on the whole model, I can solve this too: I can call `saved_tensors_hooks.__enter__` when I enter the checkpoint scope, but not immediately call `saved_tensors_hooks.__exit__` afterwards. Instead, I can hook into `Tensor.__del__` (as described above) for all tensors created in that scope, and once they have all been deleted, I can call `saved_tensors_hooks.__exit__` (and also exit the `__torch_dispatch__` mode). So this stays very local.
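The recording part of this plan could start out roughly as follows. This is a minimal sketch, assuming the private `torch.utils._python_dispatch.TorchDispatchMode` base class (the mode-based form of `__torch_dispatch__`, not a stable public API); `RecordingMode` and `replay` are hypothetical names. It records each aten op with its inputs and outputs and can re-execute them without autograd. The replay is deliberately naive: it reuses the recorded input tensors, whereas a real implementation would substitute the recomputed outputs of earlier recorded ops and would not keep the outputs alive as this sketch does.

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode  # private module


class RecordingMode(TorchDispatchMode):
    """Records every aten op executed inside the mode, with inputs and outputs."""

    def __init__(self):
        super().__init__()
        self.records = []  # (func, args, kwargs, out); keeps tensors alive (sketch only)

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        out = func(*args, **kwargs)
        self.records.append((func, args, kwargs, out))
        return out


def replay(records):
    """Naively re-executes the recorded ops in order, without autograd."""
    outputs = []
    with torch.no_grad():
        for func, args, kwargs, _ in records:
            outputs.append(func(*args, **kwargs))
    return outputs
```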
It gets a bit trickier to handle the RNG state correctly, and maybe also AMP. I guess I need to make sure to store, and later fork and reset, the RNG state correctly, and also to replay the recorded computations in exactly the same order to make sure it's all deterministic. But this is similar to the current Torch `checkpoint` logic.
So, this sounds like a plan. Edit: Implemented in PR #1559.
Btw, I also have an idea for how to write a test case for this: we use code very similar to the example. Once without gradient checkpointing:
```python
x = a + b
y = x * c
```
Once with gradient checkpointing:
```python
with gradient_checkpoint_scope():
    x = a + b
y = x * c
```
`a`, `b`, `c` should have `requires_grad=True`.
And we record the memory consumption for both, using some of the Torch mechanisms (e.g., `torch.profiler` with `profile_memory=True`). Actually, we only need to check the memory consumption before and after the run (counting only Torch tensors), and it should be lower in the second case, in fact by exactly the size of tensor `x`. Alternatively (or additionally), we could take a weakref on `x` and check whether it is still alive in the second case.
Then we should also run backprop to verify that we get the same gradients in both cases.
The test could also be extended to use the random number generator, to see that restoring the RNG state works correctly.
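The RNG part of the proposed test can already be tried against `torch.utils.checkpoint`, as a stand-in for our own `gradient_checkpoint_scope`. A minimal sketch, assuming non-reentrant checkpointing with its default `preserve_rng_state=True`; `check_rng_replay` is a hypothetical name. With an all-ones input and the sum as the loss, the gradient equals the dropout output exactly, so the check only passes if the recomputation in backward reproduced the same dropout mask.

```python
import torch
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint


def check_rng_replay(n: int = 1000, p: float = 0.5) -> bool:
    x = torch.ones(n, requires_grad=True)
    # The dropout mask is drawn from the global RNG in forward and must be
    # reproduced exactly when checkpoint recomputes the function in backward.
    y = checkpoint(lambda t: F.dropout(t, p=p, training=True), x, use_reentrant=False)
    y.sum().backward()
    # For x = 1 and d(loss)/dy = 1, the gradient is mask / (1 - p), i.e. exactly y.
    return torch.equal(x.grad, y.detach())
```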
Note, one problem with deferring `saved_tensors_hooks.__exit__`: if any other (unrelated) `saved_tensors_hooks.__enter__` happens in the meantime, and we then call our `saved_tensors_hooks.__exit__` while the new scope is active, this leads to wrong behavior.
Inside RETURNN, this is likely not much of a problem, as we don't really have any other `saved_tensors_hooks` usage.
Note, one problem with putting any logic in `__del__`: it can potentially run in a different thread. Calling `saved_tensors_hooks.__exit__` there is thus wrong, as this logic is thread-local; it must execute in the same thread.
I'm not sure about the best solution, or indeed about any solution that is not ugly.
I also reported that here: pytorch/pytorch#129867
Currently, the only solution I can think of is to hook into `saved_tensors_hooks` itself (overwriting `__enter__` and `__exit__` with our own code) and add custom logic to handle all cases correctly. Edit: This is what I have now implemented in PR #1559.