Strange computational graph issue with `gradient_ascent` and `LightningModule`
Closed this issue · 6 comments
First, here is a simple example of running `gradient_ascent` that works without error:
from functools import partial
import torch as tr
from torchvision import models
from rai_toolbox.optim import L2ProjectedOptim
from rai_toolbox.perturbations.solvers import gradient_ascent
# Inputs for the repro: a torchvision ResNet-18, a random batch of shape
# (10, 3, 100, 100), and ten integer targets drawn from {0, 1}.
model = models.resnet18()
data = tr.rand(10, 3, 100, 100, dtype=tr.float)
target = tr.randint(0, 2, size=(10,))

# Pre-bind the solver settings so that only model/data/target are passed per call.
solver_settings = dict(optimizer=L2ProjectedOptim, epsilon=1.0, steps=1, lr=1.0)
pert = partial(gradient_ascent, **solver_settings)

# run gradient ascent
pert(model=model, data=data, target=target)
Now set up and run the same thing using `Trainer.predict`:
import pytorch_lightning as pl
class Lit(pl.LightningModule):
    """LightningModule that perturbs each batch before running the model on it.

    It reuses the module-level ``model`` and ``pert`` objects, so the same
    solver configuration that works when called directly is exercised here
    through ``Trainer.predict``.
    """

    def __init__(self):
        super().__init__()
        self.model = model
        self.pert = pert

    def predict_step(self, batch, *args, **kwargs):
        """Perturb ``data`` with ``self.pert`` and return the summed logits.

        ``batch`` is unpacked as a ``(data, target)`` pair.
        """
        data, target = batch
        # This is the call that raises the RuntimeError reported in this issue
        # when the step is driven by ``Trainer.predict``.
        data = self.pert(model=self.model, data=data, target=target)
        logits = self.model(data)
        return logits.sum()
# Build the pieces in the same order the one-expression form evaluated them:
# trainer, module, dataset, then datamodule.
trainer = pl.Trainer()
lit_module = Lit()
predict_dataset = tr.utils.data.TensorDataset(data, target)
datamodule = pl.LightningDataModule.from_datasets(
    predict_dataset=predict_dataset,
    batch_size=1,
    num_workers=0,
)

# Run the predict loop over the dataset, one sample per batch.
trainer.predict(lit_module, datamodule=datamodule)
Here we get the following error:
...
/tmp/ipykernel_74682/1909129363.py in predict_step(self, batch, *args, **kwargs)
27 def predict_step(self, batch, *args, **kwargs):
28 data, target = batch
---> 29 data = self.pert(model=self, data=data, target=target)
30 logits = self.model(data)
31 return logits.sum()
~/projects/raiden/rai_toolbox/src/rai_toolbox/perturbations/solvers.py in gradient_ascent(model, data, target, optimizer, steps, perturbation_model, targeted, use_best, criterion, reduction_fn, **optim_kwargs)
277 # Update the perturbation
278 optim.zero_grad(set_to_none=True)
--> 279 loss.backward()
280 optim.step()
281
~/.conda/envs/rai_md/lib/python3.8/site-packages/torch/_tensor.py in backward(self, gradient, retain_graph, create_graph, inputs)
394 create_graph=create_graph,
395 inputs=inputs)
--> 396 torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
397
398 def register_hook(self, hook):
~/.conda/envs/rai_md/lib/python3.8/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
171 # some Python versions print out the first line of a multi-line function
172 # calls in the traceback and some print out the last line
--> 173 Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
174 tensors, grad_tensors_, retain_graph, create_graph, inputs,
175 allow_unreachable=True, accumulate_grad=True) # Calls into the C++ engine to run the backward pass
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
If I enter the debugger, everything seems to be set up correctly, except that `pmodel(data)` does not return a tensor with a `grad_fn`!
#
# pdb at `loss.backward()` line
#
> pmodel.delta.requires_grad
True
> tr.is_grad_enabled()
True
> pmodel.delta + data
... # tensor output without `grad_fn`
# try reinitializing
> perturbation_model(data)(data)
... # tensor output WITH `grad_fn`
I have no idea how to debug this and find out what is wrong.
@rsokl do you get this error in your environment?
Looking through PL issues this may be related to PL 1.7 only and not rai-toolbox. Haven't found an issue with a solution yet though.
My initial reaction is: isn't `predict` expected to run in no-grad mode? I would expect that we would use a fully-trained/fitted perturbation model in this context rather than run gradient ascent.
That being said... we do explicitly enable grad within gradient-ascent:
responsible-ai-toolbox/src/rai_toolbox/perturbations/solvers.py
Lines 269 to 275 in caebefe
So PL must be doing something really weird, like forcing no-grad mode in a global context.
I can reproduce what you see on my end using pytorch-lightning 1.7.2.
Here is a further simplified version that demonstrates that PL disables all gradient-tracking within the context of predict:
import pytorch_lightning as pl
import torch as tr
from rai_toolbox.perturbations import AdditivePerturbation
class Lit(pl.LightningModule):
    """Minimal LightningModule that checks gradient tracking inside ``predict_step``.

    It holds no model; the step only builds two scalar tensors and multiplies
    them, so any loss of gradient tracking comes from the calling context.
    """

    def __init__(self):
        super().__init__()

    def predict_step(self, batch, *args, **kwargs):
        """Multiply two grad-requiring scalars under ``enable_grad`` and assert on the result.

        ``batch`` is ignored.
        """
        x = tr.tensor(1.0, requires_grad=True)
        y = tr.tensor(2.0, requires_grad=True)
        with tr.enable_grad():
            assert x.requires_grad and y.requires_grad
            z = x * y
            # This is the assertion that fails under ``Trainer.predict`` in
            # the traceback reported in this thread.
            assert z.requires_grad
# Build the pieces in the same order the one-expression form evaluated them:
# trainer, module, dataset, then datamodule.
trainer = pl.Trainer()
lit_module = Lit()
# Single-sample placeholder dataset; the predict step does not read the batch.
predict_dataset = tr.utils.data.TensorDataset(tr.tensor([1.0]), tr.tensor([1.0]))
datamodule = pl.LightningDataModule.from_datasets(
    predict_dataset=predict_dataset,
    batch_size=1,
    num_workers=0,
)

trainer.predict(lit_module, datamodule=datamodule)
Output exceeds the size limit. Open the full output data in a text editor
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
c:\Users\rsokl\responsible-ai-toolbox\scratch\scratch.ipynb Cell 3 in <cell line: 21>()
17 assert z.requires_grad
20 trainer = pl.Trainer()
---> 21 trainer.predict(
22 Lit(),
23 datamodule=pl.LightningDataModule.from_datasets(
24 predict_dataset=tr.utils.data.TensorDataset(tr.tensor([1.0]), tr.tensor([1.0])),
25 batch_size=1,
26 num_workers=0,
27 ),
28 )
File c:\Users\rsokl\miniconda3\envs\rai\lib\site-packages\pytorch_lightning\trainer\trainer.py:951, in Trainer.predict(self, model, dataloaders, datamodule, return_predictions, ckpt_path)
926 r"""
927 Run inference on your data.
928 This will call the model forward function to compute predictions. Useful to perform distributed
(...)
948 Returns a list of dictionaries, one for each provided dataloader containing their respective predictions.
949 """
950 self.strategy.model = model or self.lightning_module
--> 951 return self._call_and_handle_interrupt(
952 self._predict_impl, model, dataloaders, datamodule, return_predictions, ckpt_path
953 )
...
15 assert x.requires_grad and y.requires_grad
16 z = x * y
---> 17 assert z.requires_grad
AssertionError:
I'll go ahead and close this given that we know that this is not related to the design of the toolbox. But I am happy to continue to chat about it.
Yeah, I need to figure out how to get around this with PL. Definitely not good that PL does this as we do have metrics for robustness that require the gradient of the model.
OK, here's the issue and how to override.
First, `torch` has a new mode that I was unaware of: `torch.inference_mode`. PL uses this in the `test` and `predict` stages here.
Examples of behaviors:
# Compare how `no_grad` and `inference_mode` interact with nested contexts.
# `x` tracks gradients; `y` does not.
x = tr.tensor(1.0, requires_grad=True)
y = tr.tensor(2.0)

with tr.no_grad():
    assert x.requires_grad
    z = x * y
    # no grad in context
    assert not z.requires_grad

    # no grad is overridden
    with tr.enable_grad():
        z = x * y
        assert z.requires_grad

with tr.inference_mode():
    assert x.requires_grad
    z = x * y
    # no grad in context
    assert not z.requires_grad

    # no grad cannot be overridden by enable grad
    with tr.enable_grad():
        z = x * y
        assert not z.requires_grad

    # no grad can be overridden with inference_mode context
    with tr.inference_mode(mode=False):
        z = x * y
        assert z.requires_grad
I don't know whether we want to check for this in `gradient_ascent`; I can certainly override it in a `LightningModule`.