Freeze/frozen should detach non-leaf tensors
rsokl commented
Edit: This isn't actually "fixable" in freeze. All we can do is document the requirement (freeze only works on leaf tensors) and have the utilities that call freeze internally detach non-leaf tensors (see #27).
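A minimal sketch of the workaround described in the edit, assuming a utility receives a mix of leaf and non-leaf tensors before handing them to freeze. `detach_non_leaf` is a hypothetical helper name, not part of rai_toolbox: it passes leaf tensors through unchanged (so their `requires_grad` flag can still be toggled in place) and replaces non-leaf tensors with detached copies.

```python
import torch as tr


def detach_non_leaf(*tensors: tr.Tensor) -> tuple[tr.Tensor, ...]:
    # Hypothetical helper (not a rai_toolbox API).
    # Leaf tensors are returned as-is; non-leaf tensors are replaced by
    # .detach() copies, which share storage but are cut from the graph
    # and are therefore leaves with requires_grad=False.
    return tuple(t if t.is_leaf else t.detach() for t in tensors)


x = tr.arange(2.0, requires_grad=True)  # leaf tensor
y = 2 * x                               # non-leaf: output of an op on x

a, b = detach_non_leaf(x, y)
a.requires_grad_(False)  # allowed: a is the original leaf x
b.requires_grad_(False)  # allowed: a detached tensor is a leaf
```

Note that the caller must use the returned tensors: the detached copy is a new tensor object, so the original non-leaf `y` is left untouched and still attached to the graph.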
```
>>> freeze(2*tr.arange(2.0, requires_grad=True))
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Input In [57], in <cell line: 1>()
----> 1 freeze(2*tr.arange(2.0, requires_grad=True))
File c:\users\ryan soklaski\responsible-ai-toolbox\src\rai_toolbox\_utils\stateful.py:84, in freeze(*items)
79 if param not in param_states:
80 # we need to check to see if we have already encountered a parameter
81 # so that we avoid overwriting its original state during a second
82 # encounter
83 param_states[param] = param.requires_grad
---> 84 param.requires_grad_(False)
86 def restore_state():
87 for p, requires_grad in param_states.items():
RuntimeError: you can only change requires_grad flags of leaf variables. If you want to use a computed variable in a subgraph that doesn't require differentiation use var_no_grad = var.detach().
```
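The failure can be reproduced with PyTorch alone, without freeze: `2*tr.arange(2.0, requires_grad=True)` is the output of an operation on a tensor that requires grad, so it is not a leaf, and calling `requires_grad_(False)` on it (the call freeze makes on line 84) raises the same RuntimeError. The `.detach()` suggested by the error message yields a leaf whose flag can be set. A minimal sketch:

```python
import torch as tr

x = tr.arange(2.0, requires_grad=True)
y = 2 * x  # result of an op on a grad-requiring tensor, so not a leaf

print(x.is_leaf, y.is_leaf)  # True False

error_message = ""
try:
    y.requires_grad_(False)  # same call freeze makes on a non-leaf input
except RuntimeError as e:
    error_message = str(e)
    print("RuntimeError:", error_message)

# Detaching produces a new leaf tensor (requires_grad=False) sharing y's data.
y_no_grad = y.detach()
y_no_grad.requires_grad_(False)  # allowed on a leaf
print(y_no_grad.is_leaf, y_no_grad.requires_grad)  # True False
```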