mit-ll-responsible-ai/responsible-ai-toolbox

Freeze/frozen should detach non-leaf tensors

Closed this issue · 0 comments

rsokl commented

Edit: This actually isn't "fixable" in `freeze`. We can only document this requirement, and have utilities that use `freeze` internally detach non-leaf tensors before freezing them (see #27).
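A minimal sketch of what such a util might do, modeled loosely on the part of `freeze` visible in the traceback below: record each tensor's original `requires_grad` flag once, set it to `False`, and return a callable that restores it. The helper name `freeze_detached` and its return convention are hypothetical (not part of rai_toolbox), and it assumes the items are bare tensors; `freeze(*items)` may accept other kinds of items.

```python
from typing import Callable, Dict, List, Tuple

import torch


def freeze_detached(
    *tensors: torch.Tensor,
) -> Tuple[List[torch.Tensor], Callable[[], None]]:
    """Hypothetical helper (not the rai_toolbox API).

    Detaches non-leaf tensors, disables ``requires_grad`` on every tensor, and
    returns the tensors to use plus a callable that restores the original flags.
    """
    # requires_grad_(False) raises on non-leaf tensors, so swap each one for a
    # detached copy: a leaf that shares storage and has requires_grad=False.
    prepared = [t if t.is_leaf else t.detach() for t in tensors]

    # Keyed by id() so a tensor passed twice keeps its *original* flag,
    # mirroring the "already encountered" guard in freeze().
    states: Dict[int, Tuple[torch.Tensor, bool]] = {}
    for t in prepared:
        if id(t) not in states:
            states[id(t)] = (t, t.requires_grad)
            t.requires_grad_(False)

    def restore_state() -> None:
        # Put every recorded flag back the way it was.
        for t, requires_grad in states.values():
            t.requires_grad_(requires_grad)

    return prepared, restore_state


x = torch.arange(2.0, requires_grad=True)  # leaf tensor
y = 2 * x                                   # non-leaf: produced by an op
(fx, fy), restore = freeze_detached(x, y)
print(fx.requires_grad, fy.requires_grad, fy.is_leaf)  # False False True
restore()
print(x.requires_grad, y.requires_grad)  # True True
```

Note that the caller has to use the returned tensors: the original non-leaf `y` keeps its graph and its `requires_grad=True`. This is one plausible reason the detaching has to happen in the calling util rather than inside `freeze`, which cannot swap out the caller's reference.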

```python
>>> freeze(2*tr.arange(2.0, requires_grad=True))
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Input In [57], in <cell line: 1>()
----> 1 freeze(2*tr.arange(2.0, requires_grad=True))

File c:\users\ryan soklaski\responsible-ai-toolbox\src\rai_toolbox\_utils\stateful.py:84, in freeze(*items)
     79         if param not in param_states:
     80             # we need to check to see if we have already encountered a parameter
     81             # so that we avoid overwriting its original state during a second
     82             # encounter
     83             param_states[param] = param.requires_grad
---> 84             param.requires_grad_(False)
     86 def restore_state():
     87     for p, requires_grad in param_states.items():

RuntimeError: you can only change requires_grad flags of leaf variables. If you want to use a computed variable in a subgraph that doesn't require differentiation use var_no_grad = var.detach().
```
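The failure can be reproduced with plain PyTorch, independent of rai_toolbox: `2 * arange(...)` is the output of an operation (it has a `grad_fn`), so it is not a leaf, and calling `requires_grad_(False)` on it, which is what `freeze` does at line 84, raises the error above. The sketch below (using `torch` rather than the `tr` alias) also shows the `detach()` workaround that the error message itself suggests.

```python
import torch

x = torch.arange(2.0, requires_grad=True)  # leaf tensor
y = 2 * x                                   # non-leaf: has a grad_fn

print(x.is_leaf, y.is_leaf)  # True False

message = ""
try:
    y.requires_grad_(False)  # the same call freeze() makes on each item
except RuntimeError as err:
    message = str(err)
print("RuntimeError:", message)

# Workaround from the error message: detach() returns a new leaf tensor that
# shares y's data but is cut off from the autograd graph.
y_frozen = y.detach()
print(y_frozen.is_leaf, y_frozen.requires_grad)  # True False
```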