shap/shap

BUG: Output 0 of BackwardHookFunctionBackward is a view and is being modified inplace. This view was created inside a custom Function (or because an input was returned as-is) and the autograd logic to handle view+inplace would override the custom backward associated with the custom Function, leading to incorrect gradients. This behavior is forbidden. You can fix this by cloning the output of the custom Function.

bedanar opened this issue · 3 comments

Issue Description

I am getting the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[490], line 15
     13 print(background.shape)
     14 # Get SHAP values
---> 15 shap_values = explainer.shap_values(test_input)
     17 # Reshape shap values and images for plotting
     18 shap_numpy = list(np.array(shap_values).transpose(0,1,3,4,2))

File /opt/homebrew/lib/python3.11/site-packages/shap/explainers/_deep/__init__.py:125, in DeepExplainer.shap_values(self, X, ranked_outputs, output_rank_order, check_additivity)
     91 def shap_values(self, X, ranked_outputs=None, output_rank_order='max', check_additivity=True):
     92     """ Return approximate SHAP values for the model applied to the data given by X.
     93 
     94     Parameters
   (...)
    123         were chosen as "top".
    124     """
--> 125     return self.explainer.shap_values(X, ranked_outputs, output_rank_order, check_additivity=check_additivity)

File /opt/homebrew/lib/python3.11/site-packages/shap/explainers/_deep/deep_pytorch.py:191, in PyTorchDeep.shap_values(self, X, ranked_outputs, output_rank_order, check_additivity)
    189 # run attribution computation graph
    190 feature_ind = model_output_ranks[j, i]
--> 191 sample_phis = self.gradient(feature_ind, joint_x)
    192 # assign the attributions to the right part of the output arrays
    193 if self.interim:

File /opt/homebrew/lib/python3.11/site-packages/shap/explainers/_deep/deep_pytorch.py:107, in PyTorchDeep.gradient(self, idx, inputs)
    105 self.model.zero_grad()
    106 X = [x.requires_grad_() for x in inputs]
--> 107 outputs = self.model(*X)
    108 selected = [val for val in outputs[:, idx]]
    109 grads = []

File /opt/homebrew/lib/python3.11/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/homebrew/lib/python3.11/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /opt/homebrew/lib/python3.11/site-packages/timm/models/resnet.py:541, in ResNet.forward(self, x)
    540 def forward(self, x):
--> 541     x = self.forward_features(x)
    542     x = self.forward_head(x)
    543     return x

File /opt/homebrew/lib/python3.11/site-packages/timm/models/resnet.py:528, in ResNet.forward_features(self, x)
    526     x = checkpoint_seq([self.layer1, self.layer2, self.layer3, self.layer4], x, flatten=True)
    527 else:
--> 528     x = self.layer1(x)
    529     x = self.layer2(x)
    530     x = self.layer3(x)

File /opt/homebrew/lib/python3.11/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/homebrew/lib/python3.11/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /opt/homebrew/lib/python3.11/site-packages/torch/nn/modules/container.py:215, in Sequential.forward(self, input)
    213 def forward(self, input):
    214     for module in self:
--> 215         input = module(input)
    216     return input

File /opt/homebrew/lib/python3.11/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/homebrew/lib/python3.11/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /opt/homebrew/lib/python3.11/site-packages/timm/models/resnet.py:115, in BasicBlock.forward(self, x)
    113 if self.downsample is not None:
    114     shortcut = self.downsample(shortcut)
--> 115 x += shortcut
    116 x = self.act2(x)
    118 return x

RuntimeError: Output 0 of BackwardHookFunctionBackward is a view and is being modified inplace. This view was created inside a custom Function (or because an input was returned as-is) and the autograd logic to handle view+inplace would override the custom backward associated with the custom Function, leading to incorrect gradients. This behavior is forbidden. You can fix this by cloning the output of the custom Function.

I am using shap to interpret an image classification model:

import numpy as np
import torch
import shap
from PIL import Image
from torch.utils.data import DataLoader

shap_loader = DataLoader(train_dataset, batch_size=100, shuffle=True)
background, _ = next(iter(shap_loader))
background = background.to(cfg.device)

# Create SHAP explainer
explainer = shap.DeepExplainer(model, background)

# Load test images and apply the same transforms as during training
test_images = [Image.open(path) for path in ex_paths]
test_input = [TRANSFORMS(img) for img in test_images]
test_input = torch.stack(test_input).to(cfg.device)

# Get SHAP values
shap_values = explainer.shap_values(test_input)

# Reshape shap values and images for plotting
shap_numpy = list(np.array(shap_values).transpose(0, 1, 3, 4, 2))
test_numpy = np.array([np.array(img) for img in test_images])

shap.image_plot(shap_numpy, test_numpy, show=False)

The call shap_values = explainer.shap_values(test_input) raises the error above.

I have tried to set inplace=False on the ReLU layers with the following function, but the error is still raised:

import torch.nn as nn

def ReLU_inplace_to_False(module):
    # Recursively walk the child modules and disable in-place ReLU
    for layer in module._modules.values():
        if isinstance(layer, nn.ReLU):
            layer.inplace = False
        ReLU_inplace_to_False(layer)

I want to get an interpretation of the model. I am using a retrained resnet18 model from timm:

import timm

cfg.n_classes = 2
cfg.backbone = 'resnet18'

model = timm.create_model(cfg.backbone,
                          pretrained=True,
                          num_classes=cfg.n_classes)
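
For completeness, I apply the ReLU workaround before building the explainer, roughly in this order (a sketch of my call order; the .eval()/device handling shown here is an assumption about my own setup, not part of shap's API):

ReLU_inplace_to_False(model)                # disable in-place ReLU everywhere in the model
model = model.to(cfg.device).eval()         # same device as background/test_input (assumption)

explainer = shap.DeepExplainer(model, background)
shap_values = explainer.shap_values(test_input)   # still raises the same RuntimeError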

Minimal Reproducible Example

import shap

Traceback

No response

Expected Behavior

No response

Bug report checklist

  • I have checked that this issue has not already been reported.
  • I have confirmed this bug exists on the latest release of shap.
  • I have confirmed this bug exists on the master branch of shap.
  • I'd be interested in making a PR to fix this bug

Installed Versions

shap.__version__ = '0.44.0'

I am getting the same issue on shap version 0.43.0.

Check this issue: #3466
There is also a problem with the in-place x += shortcut operation in the resnet forward function, not just nn.ReLU(inplace=True).
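
To illustrate: the traceback shows BackwardHookFunctionBackward, i.e. a full backward hook is registered on the module, which makes the tensor the module returns a view created inside a custom Function; any in-place op on it, like x += shortcut, is then rejected by autograd. A minimal sketch of the same failure mode without shap (plain torch; the nn.Linear stand-in and the no-op hook are just for illustration, not shap's actual hooks):

import torch
import torch.nn as nn

lin = nn.Linear(4, 4)
# A full backward hook wraps the module output in BackwardHookFunction,
# so the tensor returned by lin(x) is a view of that Function's output.
lin.register_full_backward_hook(lambda module, grad_input, grad_output: None)

x = torch.randn(2, 4, requires_grad=True)
out = lin(x)
out += 1.0  # in-place update of the hooked output -> the same "view ... is being modified inplace" RuntimeError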

Thanks for the bug report. This looks like a duplicate, so let's track this one in #3466.