mit-ll-responsible-ai/responsible-ai-toolbox

Strange computational graph issue with `gradient_ascent` and `LightningModule`

Closed this issue · 6 comments

jgbos commented

First, here's a simple example of running `gradient_ascent` that works without error:

from functools import partial
import torch as tr
from torchvision import models
from rai_toolbox.optim import L2ProjectedOptim
from rai_toolbox.perturbations.solvers import gradient_ascent

model = models.resnet18()
data = tr.rand(10, 3, 100, 100, dtype=tr.float)
target = tr.randint(0, 2, size=(10,))
pert = partial(
    gradient_ascent, optimizer=L2ProjectedOptim, epsilon=1.0, steps=1, lr=1.0
)

# run gradient ascent
pert(model=model, data=data, target=target)

Now set up and run the same thing using `Trainer.predict`:

import pytorch_lightning as pl

class Lit(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = model
        self.pert = pert

    def predict_step(self, batch, *args, **kwargs):
        data, target = batch
        data, _ = self.pert(model=self.model, data=data, target=target)  # returns (perturbed data, losses)
        logits = self.model(data)
        return logits.sum()

trainer = pl.Trainer()
trainer.predict(
    Lit(),
    datamodule=pl.LightningDataModule.from_datasets(
        predict_dataset=tr.utils.data.TensorDataset(data, target),
        batch_size=1,
        num_workers=0,
    ),
)

Here we get the following error:

...
/tmp/ipykernel_74682/1909129363.py in predict_step(self, batch, *args, **kwargs)
     27     def predict_step(self, batch, *args, **kwargs):
     28         data, target = batch
---> 29         data = self.pert(model=self, data=data, target=target)
     30         logits = self.model(data)
     31         return logits.sum()

~/projects/raiden/rai_toolbox/src/rai_toolbox/perturbations/solvers.py in gradient_ascent(model, data, target, optimizer, steps, perturbation_model, targeted, use_best, criterion, reduction_fn, **optim_kwargs)
    277             # Update the perturbation
    278             optim.zero_grad(set_to_none=True)
--> 279             loss.backward()
    280             optim.step()
    281 

~/.conda/envs/rai_md/lib/python3.8/site-packages/torch/_tensor.py in backward(self, gradient, retain_graph, create_graph, inputs)
    394                 create_graph=create_graph,
    395                 inputs=inputs)
--> 396         torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
    397 
    398     def register_hook(self, hook):

~/.conda/envs/rai_md/lib/python3.8/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    171     # some Python versions print out the first line of a multi-line function
    172     # calls in the traceback and some print out the last line
--> 173     Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    174         tensors, grad_tensors_, retain_graph, create_graph, inputs,
    175         allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

If I drop into the debugger, everything seems to be set up correctly, except that `pmodel(data)` does not return a tensor with a `grad_fn`!

#
# pdb at `loss.backward()` line
#
> pmodel.delta.requires_grad
True

> tr.is_grad_enabled()
True

> pmodel.delta + data
... # tensor output without `grad_fn`

# try reinitializing
> perturbation_model(data)(data)
... # tensor output WITH `grad_fn`
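At this breakpoint, `tr.is_grad_enabled()` alone cannot separate grad mode from inference mode. The helper below is a hypothetical debugging aid (not part of rai-toolbox or PL) that reports both. It uses the standard `torch.is_inference_mode_enabled()` and `Tensor.is_inference()` calls. The nested contexts are an assumption meant to mimic a caller that has turned on inference mode around code that calls `enable_grad`.

```python
import torch as tr


def describe_grad_state(t: tr.Tensor) -> dict:
    """Hypothetical debugging helper: summarize the autograd state around `t`."""
    return {
        # Thread-local grad mode, as toggled by no_grad / enable_grad
        "grad_enabled": tr.is_grad_enabled(),
        # Thread-local inference mode, as toggled by torch.inference_mode
        "inference_mode": tr.is_inference_mode_enabled(),
        # True if `t` was allocated while inference mode was on
        "is_inference_tensor": t.is_inference(),
        "requires_grad": t.requires_grad,
        "has_grad_fn": t.grad_fn is not None,
    }


x = tr.tensor(1.0, requires_grad=True)

# Outer inference_mode (set by a caller) plus an inner enable_grad
with tr.inference_mode():
    with tr.enable_grad():
        state = describe_grad_state(x * 2)

print(state)
```

Grad mode reads as enabled here, just as in the pdb session above. Even so, the result is an inference tensor with no `grad_fn`.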

I have no idea how to debug this and find out what is wrong.

@rsokl do you get this error in your environment?

jgbos commented

Looking through the PL issues, this may be specific to PL 1.7 rather than rai-toolbox. I haven't found an issue with a solution yet, though.

rsokl commented

My initial reaction is: isn't predict expected to run in no-grad mode? I would expect that we would use a fully-trained/fitted perturbation model in this context rather than run gradient ascent.

rsokl commented

That being said... we do explicitly enable grad within `gradient_ascent`:

with frozen(*to_freeze), evaluating(*packed_model), tr.enable_grad():
    for _ in range(steps):
        # Calculate the gradient of loss
        xadv = pmodel(data)
        logits = model(xadv)
        losses = criterion(logits, target)
        loss = reduction_fn(losses)

So PL must be doing something really weird, such as forcing no-grad mode in a global context.
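This is a quick sanity check on that hypothesis. An ordinary global no-grad switch would not cause this failure on its own, because `tr.enable_grad()` re-enables grad mode regardless of how it was turned off. The snippet uses only standard torch calls (`tr.set_grad_enabled`, `tr.enable_grad`) on toy tensors:

```python
import torch as tr

x = tr.tensor(1.0, requires_grad=True)
y = tr.tensor(2.0, requires_grad=True)

prev = tr.is_grad_enabled()
tr.set_grad_enabled(False)  # globally disable grad mode for this thread
try:
    no_grad_result = x * y
    with tr.enable_grad():  # the same override gradient_ascent applies
        enabled_result = x * y
finally:
    tr.set_grad_enabled(prev)  # restore the previous grad mode

print(no_grad_result.requires_grad, enabled_result.requires_grad)  # False True
```

Whatever PL enables must therefore be stronger than plain grad mode.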

rsokl commented

I can reproduce what you see on my end using pytorch-lightning 1.7.2.

Here is a further simplified version that demonstrates that PL disables all gradient-tracking within the context of predict:

import pytorch_lightning as pl
import torch as tr


class Lit(pl.LightningModule):
    def __init__(self):
        super().__init__()

    def predict_step(self, batch, *args, **kwargs):
        x = tr.tensor(1.0, requires_grad=True)
        y = tr.tensor(2.0, requires_grad=True)
        with tr.enable_grad():
            assert x.requires_grad and y.requires_grad
            z = x * y
            assert z.requires_grad


trainer = pl.Trainer()
trainer.predict(
    Lit(),
    datamodule=pl.LightningDataModule.from_datasets(
        predict_dataset=tr.utils.data.TensorDataset(tr.tensor([1.0]), tr.tensor([1.0])),
        batch_size=1,
        num_workers=0,
    ),
)
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
c:\Users\rsokl\responsible-ai-toolbox\scratch\scratch.ipynb Cell 3 in <cell line: 21>()
     17             assert z.requires_grad
     20 trainer = pl.Trainer()
---> 21 trainer.predict(
     22     Lit(),
     23     datamodule=pl.LightningDataModule.from_datasets(
     24         predict_dataset=tr.utils.data.TensorDataset(tr.tensor([1.0]), tr.tensor([1.0])),
     25         batch_size=1,
     26         num_workers=0,
     27     ),
     28 )

File c:\Users\rsokl\miniconda3\envs\rai\lib\site-packages\pytorch_lightning\trainer\trainer.py:951, in Trainer.predict(self, model, dataloaders, datamodule, return_predictions, ckpt_path)
    926 r"""
    927 Run inference on your data.
    928 This will call the model forward function to compute predictions. Useful to perform distributed
   (...)
    948     Returns a list of dictionaries, one for each provided dataloader containing their respective predictions.
    949 """
    950 self.strategy.model = model or self.lightning_module
--> 951 return self._call_and_handle_interrupt(
    952     self._predict_impl, model, dataloaders, datamodule, return_predictions, ckpt_path
    953 )
...
     15 assert x.requires_grad and y.requires_grad
     16 z = x * y
---> 17 assert z.requires_grad

AssertionError: 

I'll go ahead and close this, given that we know it is not related to the design of the toolbox. But I am happy to keep chatting about it.

jgbos commented

Yeah, I need to figure out how to get around this with PL. It's definitely not good that PL does this, since we have robustness metrics that require the gradient of the model.

jgbos commented

OK, here's the issue and how to override it.

First, torch has a new mode I was unaware of: `torch.inference_mode`. PL uses it in the test and predict stages here.

Examples of the behavior:

import torch as tr

x = tr.tensor(1.0, requires_grad=True)
y = tr.tensor(2.0)

with tr.no_grad():
    assert x.requires_grad
    z = x * y

    # no grad in context
    assert not z.requires_grad

    # enable_grad overrides no_grad
    with tr.enable_grad():
        z = x * y
        assert z.requires_grad


with tr.inference_mode():
    assert x.requires_grad
    z = x * y

    # no grad in context
    assert not z.requires_grad

    # enable_grad cannot override inference_mode
    with tr.enable_grad():
        z = x * y
        assert not z.requires_grad

    # inference_mode can be overridden by a nested inference_mode(mode=False)
    with tr.inference_mode(mode=False):
        z = x * y
        assert z.requires_grad

I don't know whether we want to check for this in `gradient_ascent`; I can certainly override it in a `LightningModule`.
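A minimal sketch of that kind of override follows. It wraps the gradient-based computation in `tr.inference_mode(mode=False)` and `tr.enable_grad()`. It also clones any inputs created under inference mode, because torch does not allow inference tensors to be saved for backward. This is not necessarily what the thread settled on. The names `with_grad_enabled` and `ascent_step` are hypothetical, and a toy linear model with a hand-written one-step ascent stands in for the model and for `gradient_ascent`.

```python
import torch as tr
import torch.nn.functional as F


def with_grad_enabled(fn, *tensors):
    """Hypothetical helper: call `fn(*tensors)` with autograd fully re-enabled,
    even when invoked from inside `torch.inference_mode()`."""
    with tr.inference_mode(mode=False), tr.enable_grad():
        # Inference tensors cannot be saved for backward, so copy them into
        # normal tensors first; normal tensors are passed through unchanged.
        tensors = tuple(t.clone() if t.is_inference() else t for t in tensors)
        return fn(*tensors)


def ascent_step(model, data, target, lr=1.0):
    """Toy stand-in for gradient_ascent: one step on an additive perturbation."""
    delta = tr.zeros_like(data, requires_grad=True)
    loss = F.cross_entropy(model(data + delta), target)
    loss.backward()
    # Step *up* the loss gradient; detach so the result carries no graph
    return (data + lr * delta.grad).detach()


tr.manual_seed(0)
model = tr.nn.Linear(4, 2)

# Simulate Trainer.predict: the batch is created under inference_mode
with tr.inference_mode():
    data = tr.rand(3, 4)
    target = tr.tensor([0, 1, 0])
    xadv = with_grad_enabled(lambda d, t: ascent_step(model, d, t), data, target)
```

Without the wrapper, `loss.backward()` inside `ascent_step` fails as it did in `predict_step`, because nothing computed under inference mode has a `grad_fn`.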