rwth-i6/returnn

Torch gradient_checkpoint_scope _unregister_custom_saved_tensors_hooks error

albertz opened this issue · 4 comments

(It's a multi GPU training (not relevant here) and the stack trace mixes the error from all workers. Trying to disentangle this here. But it might be slightly messed up.)

...
start epoch 1 global train step 0 with effective learning rate 1.0003355932203391e-05 ...
...
Module call stack:
(Model.__call__) (root)
(BatchNorm.__call__) feature_batch_norm
...
  File "/u/zeyer/setups/combined/2021-05-31/recipe/i6_experiments/users/zeyer/experiments/exp2024_04_23_baselines/ctc.py", line 697, in ctc_training
    line: logits, enc, enc_spatial_dim = model(data, in_spatial_dim=data_spatial_dim, collected_outputs=collected_outputs)
    locals:
      logits = <not found>
      enc = <not found>
      enc_spatial_dim = <not found>
      model = <local> <Model>
      data = <local> Tensor{'data', [B?,T|'time'[B?]]}
      in_spatial_dim = <not found>
      data_spatial_dim = <local> Dim{'time'[B?]}
      collected_outputs = <local> {}
  File "/u/zeyer/setups/combined/2021-05-31/recipe/i6_experiments/users/zeyer/experiments/exp2024_04_23_baselines/ctc.py", line 1006, in Model.__call__
    line: source = self.feature_batch_norm(source)
    locals:
      source = <local> Tensor{'mul', [B?,'⌈((-199+time)+-200)/160⌉'[B?],F|F'logmel'(80)]}
      self = <local> <Model>
      self.feature_batch_norm = <local> <ParametrizedBatchNorm>
  File "/u/zeyer/setups/combined/2021-05-31/tools/returnn/returnn/frontend/normalization.py", line 195, in BatchNorm.__call__
    line: update_running_stats = self.running_mean is not None and rf.get_run_ctx().train_flag
    locals:
      update_running_stats = <not found>
      self = <local> <ParametrizedBatchNorm>
      self.running_mean = <local> !RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:2 and cpu!
      rf = <global> <module 'returnn.frontend' from '/u/zeyer/setups/combined/2021-05-31/tools/returnn/returnn/frontend/__init__.py'>
      rf.get_run_ctx = <global> <function get_run_ctx at 0x7c221e3f4d60>
      train_flag = <not found>
  File "/u/zeyer/setups/combined/2021-05-31/tools/returnn/returnn/frontend/parametrize.py", line 188, in _Property.__get__
    locals:
    line: return self.parametrization(self.orig_param)
      main = <local> <function main at 0x7216050ba2a0>
    locals:

...
  File "/u/zeyer/setups/combined/2021-05-31/tools/returnn/returnn/frontend/parametrizations.py", line 87, in WeightNoise.__call__
    line: return rf.cond(rf.get_run_ctx().train_flag, _on_train, lambda: param)
    locals:
      rf = <global> <module 'returnn.frontend' from '/u/zeyer/setups/combined/2021-05-31/tools/returnn/returnn/frontend/__init__.py'>
      rf.cond = <global> <function cond at 0x721604f33b00>
      rf.get_run_ctx = <global> <function get_run_ctx at 0x721604fd0d60>
      train_flag = <not found>
      _on_train = <local> <function WeightNoise.__call__.<locals>._on_train at 0x7214b5f47920>
      param = <local> Tensor{'parameter', [F'logmel'(80)]}
  File "/u/zeyer/setups/combined/2021-05-31/tools/returnn/returnn/frontend/cond.py", line 27, in cond
    line: return true_fn()
    locals:
      true_fn = <local> <function WeightNoise.__call__.<locals>._on_train at 0x7214b5f47920>
  File "/u/zeyer/setups/combined/2021-05-31/tools/returnn/returnn/frontend/parametrizations.py", line 83, in WeightNoise.__call__.<locals>._on_train
    line: with rf.gradient_checkpoint_scope():
    locals:
      rf = <global> <module 'returnn.frontend' from '/u/zeyer/setups/combined/2021-05-31/tools/returnn/returnn/frontend/__init__.py'>
      rf.gradient_checkpoint_scope = <global> <function gradient_checkpoint_scope at 0x721604f53560>
  File "/u/zeyer/setups/combined/2021-05-31/tools/returnn/returnn/torch/util/gradient_checkpoint.py", line 150, in gradient_checkpoint_scope.__enter__
    line: self.saved_tensors_hooks_scope.__enter__()
    locals:
      self = <local> <returnn.torch.util.gradient_checkpoint.gradient_checkpoint_scope object at 0x72162b000ed0>
      self.saved_tensors_hooks_scope = <local> <torch.autograd.graph.saved_tensors_hooks object at 0x7215402ceb90>
      self.saved_tensors_hooks_scope.__enter__ = <local> <bound method _custom_saved_tensors_hooks_enter of <torch.autograd.graph.saved_tensors_hooks object at 0x7215402ceb90>>
  File "/u/zeyer/setups/combined/2021-05-31/tools/returnn/returnn/torch/util/gradient_checkpoint.py", line 544, in _custom_saved_tensors_hooks_enter
    line: _custom_saved_tensors_hooks_call_callbacks()
    locals:
      _custom_saved_tensors_hooks_call_callbacks = <global> <function _custom_saved_tensors_hooks_call_callbacks at 0x7214b5fddb20>
  File "/u/zeyer/setups/combined/2021-05-31/tools/returnn/returnn/torch/util/gradient_checkpoint.py", line 590, in _custom_saved_tensors_hooks_call_callbacks
    line: _custom_saved_tensors_hooks_tls_ctx.callbacks = [
              cb for cb in _custom_saved_tensors_hooks_tls_ctx.callbacks if cb()
          ]
    locals:
      _custom_saved_tensors_hooks_tls_ctx = <global> <_thread._local object at 0x7214b5fec7c0>
      _custom_saved_tensors_hooks_tls_ctx.callbacks = <global> [<returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7214b5f957d0>]
      cb = <not found>
  File "/u/zeyer/setups/combined/2021-05-31/tools/returnn/returnn/torch/util/gradient_checkpoint.py", line 591, in _custom_saved_tensors_hooks_call_callbacks.<locals>.<listcomp>
    line: cb for cb in _custom_saved_tensors_hooks_tls_ctx.callbacks if cb()
    locals:
      cb = <local> <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7214b469cb50>
      _custom_saved_tensors_hooks_tls_ctx = <global> <_thread._local object at 0x7214b5fec7c0>
      _custom_saved_tensors_hooks_tls_ctx.callbacks = <global> [<returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7214b5f957d0>]
  File "/u/zeyer/setups/combined/2021-05-31/tools/returnn/returnn/torch/util/gradient_checkpoint.py", line 490, in _WeakMethod.__call__
    line: return self.func(obj, *args, **kwargs)
    locals:
      self = <local> <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7214b469cb50>
      self.func = <local> <function gradient_checkpoint_scope._custom_saved_tensors_hooks_callback at 0x7214b5fdc4a0>
      obj = <local> <returnn.torch.util.gradient_checkpoint.gradient_checkpoint_scope object at 0x7215402dc210>
      args = <local> ()
      kwargs = <local> {}
  File "/u/zeyer/setups/combined/2021-05-31/tools/returnn/returnn/torch/util/gradient_checkpoint.py", line 233, in gradient_checkpoint_scope._custom_saved_tensors_hooks_callback
    line: self.exit_saved_tensors_hooks_scope()
    locals:
      self = <local> <returnn.torch.util.gradient_checkpoint.gradient_checkpoint_scope object at 0x7215402dc210>
      self.exit_saved_tensors_hooks_scope = <local> <bound method gradient_checkpoint_scope.exit_saved_tensors_hooks_scope of <returnn.torch.util.gradient_checkpoint.gradient_checkpoi
nt_scope object at 0x7215402dc210>>
  File "/u/zeyer/setups/combined/2021-05-31/tools/returnn/returnn/torch/util/gradient_checkpoint.py", line 198, in gradient_checkpoint_scope.exit_saved_tensors_hooks_scope
    line: self.saved_tensors_hooks_scope.__exit__(*self.exit_args)
    locals:
      self = <local> <returnn.torch.util.gradient_checkpoint.gradient_checkpoint_scope object at 0x7215402dc210>
      self.saved_tensors_hooks_scope = <local> <torch.autograd.graph.saved_tensors_hooks object at 0x7215402dff50>
      self.saved_tensors_hooks_scope.__exit__ = <local> <bound method _custom_saved_tensors_hooks_exit of <torch.autograd.graph.saved_tensors_hooks object at 0x7215402dff50>>
      self.exit_args = <local> (None, None, None)
  File "/u/zeyer/setups/combined/2021-05-31/tools/returnn/returnn/torch/util/gradient_checkpoint.py", line 574, in _custom_saved_tensors_hooks_exit
    line: _unregister_custom_saved_tensors_hooks()
    locals:
      _unregister_custom_saved_tensors_hooks = <global> <function _unregister_custom_saved_tensors_hooks at 0x7214b5fdd8a0>
  File "/u/zeyer/setups/combined/2021-05-31/tools/returnn/returnn/torch/util/gradient_checkpoint.py", line 531, in _unregister_custom_saved_tensors_hooks
    line: assert (
              not _custom_saved_tensors_hooks_tls_ctx.stack
              and not _custom_saved_tensors_hooks_tls_ctx.callbacks
              and not _custom_saved_tensors_hooks_tls_ctx.queued_exits
          )
    locals:
      _custom_saved_tensors_hooks_tls_ctx = <global> <_thread._local object at 0x7214b5fec7c0>
      _custom_saved_tensors_hooks_tls_ctx.stack = <global> [<torch.autograd.graph.saved_tensors_hooks object at 0x7214b5f95e10>]
      _custom_saved_tensors_hooks_tls_ctx.callbacks = <global> [<returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7214b5f957d0>]
      _custom_saved_tensors_hooks_tls_ctx.queued_exits = <global> []
AssertionError

Log: /u/zeyer/setups/combined/2021-05-31/alias/ctc/v6-relPosAttDef-aedLoss-bhv20-11gb-f32-bs15k-accgrad1-mgpu4-pavg100-wd1e_2-vn0025-lrlin1e_5_295k-featBN-speedpertV2-spm10k-bpeSample001/train/engine/i6_core.returnn.training.ReturnnTrainingJob.JIk9Gs5ytgna.run.8024155.1

Apart from the crash (which is definitely an issue that needs to be fixed), this is also an example where the code is not really optimally suited for the parametrization, because we don't have any caching: just the check self.running_mean is not None will already trigger one call to the parametrization.

Apart from the crash, installing weight noise on running_mean is a bad idea anyway. I just fixed this now in my setup to not do that anymore. But that's also independent of the issue here.

The crash is maybe related to using multiple gradient_checkpoint_scope instances right after each other. Maybe the previous scope was not yet cleaned up. Our current test does not cover this.

The same now also when excluding any aux vars:

...
  File "/u/zeyer/setups/combined/2021-05-31/tools/returnn/returnn/frontend/encoder/conformer.py", line 170, in ConformerConvSubsample.__call__
    line: x, in_spatial_dims = conv_layer(x, in_spatial_dims=in_spatial_dims)
    locals:
      x = <local> Tensor{'mul', [B?,'⌈((-199+time)+-200)/160⌉'[B?],'dummy-input-feature-dim'(1),F|F'logmel'(80)]}
      in_spatial_dims = <local> [Dim{'⌈((-199+time)+-200)/160⌉'[B?]}, Dim{F'logmel'(80)}]
      conv_layer = <local> <ParametrizedConv2d>
  File "/u/zeyer/setups/combined/2021-05-31/tools/returnn/returnn/frontend/conv.py", line 165, in _Conv.__call__
    line: filter=self.filter,
    locals:
      filter = <builtin> <class 'filter'>
      self = <local> <ParametrizedConv2d>
      self.filter = <local> !RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cpu!
  File "/u/zeyer/setups/combined/2021-05-31/tools/returnn/returnn/frontend/parametrize.py", line 188, in _Property.__get__
    line: return self.parametrization(self.orig_param)
    locals:
      self = <local> <returnn.frontend.parametrize._Property object at 0x7c5d4c6d1350>
      self.parametrization = <local> <returnn.frontend.parametrizations.WeightNoise object at 0x7c5d4dfb8b10>
      self.orig_param = <local> Tensor{'parameter', ['conv1'(32),'dummy-input-feature-dim'(1),'filter-dim0'(3),'filter-dim1'(3)]}
  File "/u/zeyer/setups/combined/2021-05-31/tools/returnn/returnn/frontend/parametrizations.py", line 87, in WeightNoise.__call__
    line: return rf.cond(rf.get_run_ctx().train_flag, _on_train, lambda: param)
    locals:
      rf = <global> <module 'returnn.frontend' from '/u/zeyer/setups/combined/2021-05-31/tools/returnn/returnn/frontend/__init__.py'>
      rf.cond = <global> <function cond at 0x7c5d2806bc40>
      rf.get_run_ctx = <global> <function get_run_ctx at 0x7c5d28108f40>
      train_flag = <not found>
      _on_train = <local> <function WeightNoise.__call__.<locals>._on_train at 0x7c5c3ef2fc40>
      param = <local> Tensor{'parameter', ['conv1'(32),'dummy-input-feature-dim'(1),'filter-dim0'(3),'filter-dim1'(3)]}
  File "/u/zeyer/setups/combined/2021-05-31/tools/returnn/returnn/frontend/cond.py", line 27, in cond
    line: return true_fn()
    locals:
      true_fn = <local> <function WeightNoise.__call__.<locals>._on_train at 0x7c5c3ef2fc40>
  File "/u/zeyer/setups/combined/2021-05-31/tools/returnn/returnn/frontend/parametrizations.py", line 83, in WeightNoise.__call__.<locals>._on_train
    line: with rf.gradient_checkpoint_scope():
    locals:
      rf = <global> <module 'returnn.frontend' from '/u/zeyer/setups/combined/2021-05-31/tools/returnn/returnn/frontend/__init__.py'>
      rf.gradient_checkpoint_scope = <global> <function gradient_checkpoint_scope at 0x7c5d2808b740>
  File "/u/zeyer/setups/combined/2021-05-31/tools/returnn/returnn/torch/util/gradient_checkpoint.py", line 150, in gradient_checkpoint_scope.__enter__
    line: self.saved_tensors_hooks_scope.__enter__()
...

From the exception stack trace, what likely happens:

  • Some previous saved_tensors_hooks_scope is still active. It still has registered the custom saved_tensors_hooks. The self.saved_tensors_hooks_scope.__enter__ calls _custom_saved_tensors_hooks_enter, as you see.
  • In there, the _custom_saved_tensors_hooks_call_callbacks notices that it can exit the scope now. So it does. It first calls exit_saved_tensors_hooks_scope and that calls _custom_saved_tensors_hooks_exit and that also calls _unregister_custom_saved_tensors_hooks.
  • _unregister_custom_saved_tensors_hooks fails now because there are still entries on the stack. This is a bit weird to me. Why are there still entries in the stack when _custom_saved_tensors_hooks_exit was called? It should only call this when the stack is empty:
    if not _custom_saved_tensors_hooks_tls_ctx.stack:
        assert not _custom_saved_tensors_hooks_tls_ctx.queued_exits
        if _custom_saved_tensors_hooks_tls_ctx.active:  # might have been unregistered in the meantime by callbacks
            _unregister_custom_saved_tensors_hooks()
    -> Ah, the stack trace report is likely wrong. The stack trace report is generated after the exception was raised and then not handled, so any further __exit__ handlers would get called, and stack here could have been modified in the meantime.

I now extended the assertion error message to print the real stack, callbacks, queued_exits at that point. This is what I get now:

...
ep 1 train, step 0, ctc_4 40.278, ctc_8 41.216, ctc 40.754, aed_ce 9.333, aed_fer 1.000, num_seqs 40, max_size:time 58169, max_size:out-spatial 17, mem_usage:cuda:0 5.9GB, 8.356 sec/step
ep 1 train, step 0, ctc_4 42.904, ctc_8 43.577, ctc 43.185, aed_ce 9.311, aed_fer 1.000, num_seqs 44, max_size:time 47521, max_size:out-spatial 16, mem_usage:cuda:1 5.4GB, 8.389 sec/step
ep 1 train, step 0, ctc_4 38.827, ctc_8 39.784, ctc 39.285, aed_ce 9.331, aed_fer 1.000, num_seqs 40, max_size:time 59760, max_size:out-spatial 19, mem_usage:cuda:3 6.0GB, 9.188 sec/step
ep 1 train, step 0, ctc_4 36.793, ctc_8 37.697, ctc 37.314, aed_ce 9.321, aed_fer 1.000, num_seqs 41, max_size:time 57729, max_size:out-spatial 18, mem_usage:cuda:2 5.9GB, 8.568 sec/step
AssertionError: _unregister_custom_saved_tensors_hooks: stack [],
 callbacks [<returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7f3f7ed14d50>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7f3f7eb4e190>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7f3f7eb4e810>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7f3f7eb4f3d0>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7f40b653c190>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7f3f7ed17210>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7f3f7eb34a50>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7f40b654c4d0>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7f3f7ed4c990>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7f3f7eb3e090>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7f3f7eadf0d0>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7f3f7eadf210>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7f3f7ead8ed0>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7f3f7ea91ed0>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7f3f7eadb290>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7f3f7eadbc10>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7f40b64fc2d0>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7f3f7eb4cf10>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7f3f7eb35ad0>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7f3f7eaf6050>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7f3f7eaf0890>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7f3f7eaf47d0>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7f3f7eaf2290>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7f3f7ebbbb50>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7f3f7eac5390>, 
<returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7f3f7eac5010>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7f3f7eac5dd0>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7f40b654d850>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7f3f7eac7650>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7f3f7eac4150>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7f3f7eaf7d50>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7f3f7d1f9150>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7f3f7eb3ddd0>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7f3f7d1f9bd0>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7f3f7ef2d190>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7f3f7ecf3910>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7f3f7d1cafd0>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7f3f7ef3a7d0>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7f3f7d1c8110>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7f3f7d1f8c10>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7f3f7d1d8b10>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7f3f7d1d9ad0>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7f3f7d1d91d0>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7f3faa87ce10>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7f3f7eaf19d0>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7f3f7d1fa290>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7f3f7d1fb590>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7f3f7d1c4090>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7f3f7d1c6c10>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7f3f7ed16d10>,
... (many many more)
<returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7fc67d8ae810>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7fc67d985650>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7fc67d9dcd10>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7fc67d93e1d0>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7fc67d956a90>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7fc67d9deb90>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7fc67d9dd450>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7fc67d9f6d50>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7fc67da67dd0>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7fc67d96e4d0>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7fc67d9d8f90>, <returnn.torch.util.gradient_checkpoint._WeakMethod object at 0x7fc67d8e6950>],
 queued_exits [] 

...

...
  File "/u/zeyer/setups/combined/2021-05-31/tools/returnn/returnn/torch/util/gradient_checkpoint.py", line 233, in gradient_checkpoint_scope._custom_saved_tensors_hooks_callback
    line: self.exit_saved_tensors_hooks_scope()
    locals:
      self = <local> <returnn.torch.util.gradient_checkpoint.gradient_checkpoint_scope object at 0x7f4c02af0690>
      self.exit_saved_tensors_hooks_scope = <local> <bound method gradient_checkpoint_scope.exit_saved_tensors_hooks_scope of <returnn.torch.util.gradient_checkpoint.gradient_checkpoi
nt_scope object at 0x7f4c02af0690>>
  File "/u/zeyer/setups/combined/2021-05-31/tools/returnn/returnn/torch/util/gradient_checkpoint.py", line 198, in gradient_checkpoint_scope.exit_saved_tensors_hooks_scope
    line: self.saved_tensors_hooks_scope.__exit__(*self.exit_args)
    locals:
      self = <local> <returnn.torch.util.gradient_checkpoint.gradient_checkpoint_scope object at 0x7f4c02af0690>
      self.saved_tensors_hooks_scope = <local> <torch.autograd.graph.saved_tensors_hooks object at 0x7f4c00f64a10>
      self.saved_tensors_hooks_scope.__exit__ = <local> <bound method _custom_saved_tensors_hooks_exit of <torch.autograd.graph.saved_tensors_hooks object at 0x7f4c00f64a10>>
      self.exit_args = <local> (None, None, None)
  File "/u/zeyer/setups/combined/2021-05-31/tools/returnn/returnn/torch/util/gradient_checkpoint.py", line 579, in _custom_saved_tensors_hooks_exit
    line: _unregister_custom_saved_tensors_hooks()
    locals:
      _unregister_custom_saved_tensors_hooks = <global> <function _unregister_custom_saved_tensors_hooks at 0x7f4c02919c60>
  File "/u/zeyer/setups/combined/2021-05-31/tools/returnn/returnn/torch/util/gradient_checkpoint.py", line 531, in _unregister_custom_saved_tensors_hooks
    line: assert (
              not _custom_saved_tensors_hooks_tls_ctx.stack
              and not _custom_saved_tensors_hooks_tls_ctx.callbacks
              and not _custom_saved_tensors_hooks_tls_ctx.queued_exits
          ), (
              f"_unregister_custom_saved_tensors_hooks:"
              f" stack {_custom_saved_tensors_hooks_tls_ctx.stack},"
              f" callbacks {_custom_saved_tensors_hooks_tls_ctx.callbacks},"
              f" queued_exits {_custom_saved_tensors_hooks_tls_ctx.queued_exits}"
          )

...

Module call stack:
(Model.__call__) (root)
(ConformerEncoder.__call__) encoder
(ConformerConvSubsample.__call__) encoder.input_layer
(_Conv.__call__) encoder.input_layer.conv_layers.0

I reproduced it in a test case:

def test_gradient_checkpoint_scope_twice():
    """
    Test using two :func:`gradient_checkpoint_scope` instances right after each other.

    Reproduces https://github.com/rwth-i6/returnn/issues/1579:
    entering a second gradient_checkpoint_scope while the previous one was not
    yet cleaned up triggered an AssertionError in
    ``_unregister_custom_saved_tensors_hooks``.

    NOTE(review): the exact statement order and the explicit ``del``s below are
    significant — the scope cleanup normally happens via a tensor del hook, so
    object lifetimes are part of what this test exercises.
    """
    shape = (101, 103)

    class _Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.var = torch.nn.Parameter(torch.randn(shape))
            self.input_var = torch.nn.Parameter(torch.randn(shape))
            self.opt = torch.optim.SGD(self.parameters(), lr=0.1)  # not common to have this here but ok for the test

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.get_var() * x

        def get_var(self) -> torch.Tensor:
            # One of the two gradient_checkpoint_scope usages per step.
            with gradient_checkpoint_scope():
                return self.var + torch.randn(shape)

        def get_input(self) -> torch.Tensor:
            # The other gradient_checkpoint_scope usage per step.
            x = self.input_var
            with gradient_checkpoint_scope():
                return x + torch.randn(shape)

        def demo_run(self):
            """One train step: forward (entering both scopes), backward, opt step."""
            self.opt.zero_grad()
            y = self(self.get_input())
            loss = y.sum()  # dummy loss
            # Drop references deliberately: tensor lifetime controls when the
            # scope's del-hook cleanup would run.
            del y  # not needed anymore
            loss.backward()
            del loss  # not needed anymore
            self.opt.step()

    orig_gradient_checkpoint_scope_tensor_del_hook = gradient_checkpoint_scope._tensor_del_hook
    try:
        # Overwrite this here to trigger the case where the tensor del hook will not do the cleanup.
        gradient_checkpoint_scope._tensor_del_hook = lambda self: None

        model = _Model()
        model.demo_run()
        # Second run re-enters the scopes while the previous cleanup was skipped,
        # which is the condition that triggered the reported AssertionError.
        model.demo_run()

    finally:
        # Restore the original hook so other tests are unaffected.
        gradient_checkpoint_scope._tensor_del_hook = orig_gradient_checkpoint_scope_tensor_del_hook