BUG: Output 0 of BackwardHookFunctionBackward is a view and is being modified inplace. This view was created inside a custom Function (or because an input was returned as-is) and the autograd logic to handle view+inplace would override the custom backward associated with the custom Function, leading to incorrect gradients. This behavior is forbidden. You can fix this by cloning the output of the custom Function.
bedanar opened this issue · 3 comments
Issue Description
I am having the following error:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[490], line 15
13 print(background.shape)
14 # Get SHAP values
---> 15 shap_values = explainer.shap_values(test_input)
17 # Reshape shap values and images for plotting
18 shap_numpy = list(np.array(shap_values).transpose(0,1,3,4,2))
File /opt/homebrew/lib/python3.11/site-packages/shap/explainers/_deep/__init__.py:125, in DeepExplainer.shap_values(self, X, ranked_outputs, output_rank_order, check_additivity)
91 def shap_values(self, X, ranked_outputs=None, output_rank_order='max', check_additivity=True):
92 """ Return approximate SHAP values for the model applied to the data given by X.
93
94 Parameters
(...)
123 were chosen as "top".
124 """
--> 125 return self.explainer.shap_values(X, ranked_outputs, output_rank_order, check_additivity=check_additivity)
File /opt/homebrew/lib/python3.11/site-packages/shap/explainers/_deep/deep_pytorch.py:191, in PyTorchDeep.shap_values(self, X, ranked_outputs, output_rank_order, check_additivity)
189 # run attribution computation graph
190 feature_ind = model_output_ranks[j, i]
--> 191 sample_phis = self.gradient(feature_ind, joint_x)
192 # assign the attributions to the right part of the output arrays
193 if self.interim:
File /opt/homebrew/lib/python3.11/site-packages/shap/explainers/_deep/deep_pytorch.py:107, in PyTorchDeep.gradient(self, idx, inputs)
105 self.model.zero_grad()
106 X = [x.requires_grad_() for x in inputs]
--> 107 outputs = self.model(*X)
108 selected = [val for val in outputs[:, idx]]
109 grads = []
File /opt/homebrew/lib/python3.11/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File /opt/homebrew/lib/python3.11/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
File /opt/homebrew/lib/python3.11/site-packages/timm/models/resnet.py:541, in ResNet.forward(self, x)
540 def forward(self, x):
--> 541 x = self.forward_features(x)
542 x = self.forward_head(x)
543 return x
File /opt/homebrew/lib/python3.11/site-packages/timm/models/resnet.py:528, in ResNet.forward_features(self, x)
526 x = checkpoint_seq([self.layer1, self.layer2, self.layer3, self.layer4], x, flatten=True)
527 else:
--> 528 x = self.layer1(x)
529 x = self.layer2(x)
530 x = self.layer3(x)
File /opt/homebrew/lib/python3.11/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File /opt/homebrew/lib/python3.11/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
File /opt/homebrew/lib/python3.11/site-packages/torch/nn/modules/container.py:215, in Sequential.forward(self, input)
213 def forward(self, input):
214 for module in self:
--> 215 input = module(input)
216 return input
File /opt/homebrew/lib/python3.11/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File /opt/homebrew/lib/python3.11/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
File /opt/homebrew/lib/python3.11/site-packages/timm/models/resnet.py:115, in BasicBlock.forward(self, x)
113 if self.downsample is not None:
114 shortcut = self.downsample(shortcut)
--> 115 x += shortcut
116 x = self.act2(x)
118 return x
RuntimeError: Output 0 of BackwardHookFunctionBackward is a view and is being modified inplace. This view was created inside a custom Function (or because an input was returned as-is) and the autograd logic to handle view+inplace would override the custom backward associated with the custom Function, leading to incorrect gradients. This behavior is forbidden. You can fix this by cloning the output of the custom Function.
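The same error can be reproduced without shap or timm. This is a minimal sketch, assuming (from the BackwardHookFunctionBackward name in the message) that the explainer attaches full backward hooks via register_full_backward_hook to the model's modules. BatchNorm1d and the helper name inplace_after_full_backward_hook are hypothetical choices for illustration. It also shows the workaround the message suggests: cloning the hooked output before modifying it in place.

```python
import torch
import torch.nn as nn


def inplace_after_full_backward_hook(clone_first: bool) -> str:
    """Return "ok", or the RuntimeError message if the in-place add fails."""
    bn = nn.BatchNorm1d(4)
    # Assumption: mimics an explainer attaching a (no-op) full backward hook.
    # The hook makes the module pass its output through BackwardHookFunction,
    # which returns the tensor as-is, so autograd treats it as a view.
    bn.register_full_backward_hook(lambda module, grad_in, grad_out: None)

    x = torch.randn(8, 4, requires_grad=True)
    y = bn(x)
    if clone_first:
        y = y.clone()  # workaround suggested by the error message
    try:
        y += 1.0  # in-place op on the hooked output, like `x += shortcut`
    except RuntimeError as err:
        return str(err)
    y.sum().backward()
    return "ok"


print(inplace_after_full_backward_hook(clone_first=False))
print(inplace_after_full_backward_hook(clone_first=True))
```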
I am interpreting an image classification model with shap:
import numpy as np
import shap
import torch
from PIL import Image
from torch.utils.data import DataLoader

# cfg, model, train_dataset, TRANSFORMS and ex_paths are defined elsewhere in the notebook
shap_loader = DataLoader(train_dataset, batch_size=100, shuffle=True)
# Use one batch of training images as the SHAP background
background, _ = next(iter(shap_loader))
background = background.to(cfg.device)
# Create SHAP explainer
explainer = shap.DeepExplainer(model, background)
# Load test images
test_images = [Image.open(path) for path in ex_paths]
# test_images = np.array(test_images)
test_input = [TRANSFORMS(img) for img in test_images]
test_input = torch.stack(test_input).to(cfg.device)
# Get SHAP values
shap_values = explainer.shap_values(test_input)
# Reshape shap values and images for plotting
shap_numpy = list(np.array(shap_values).transpose(0, 1, 3, 4, 2))
test_numpy = np.array([np.array(img) for img in test_images])
shap.image_plot(shap_numpy, test_numpy, show=False)
The call shap_values = explainer.shap_values(test_input) raises this error.
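For reference, the reshaping step above converts channels-first attributions into the channels-last layout used for image plotting. A sketch with dummy arrays, assuming shap_values is a list holding one (N, C, H, W) array per output class (the list-of-arrays return type for this shap version is an assumption, and the sizes below are arbitrary):

```python
import numpy as np

# Hypothetical stand-in for explainer.shap_values(test_input):
# one (N, C, H, W) attribution array per model output (2 classes here).
n_classes, n_images, channels, height, width = 2, 4, 3, 8, 6
shap_values = [np.zeros((n_images, channels, height, width)) for _ in range(n_classes)]

# (n_classes, N, C, H, W) -> (n_classes, N, H, W, C), then split per class
shap_numpy = list(np.array(shap_values).transpose(0, 1, 3, 4, 2))

for per_class in shap_numpy:
    print(per_class.shape)  # (4, 8, 6, 3)
```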
I tried setting inplace=False on all ReLU layers with the following function, but the error persists:
import torch.nn as nn

def ReLU_inplace_to_False(module):
    # Recursively set inplace=False on every nn.ReLU inside module
    for layer in module._modules.values():
        if isinstance(layer, nn.ReLU):
            layer.inplace = False
        ReLU_inplace_to_False(layer)
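An equivalent, slightly more idiomatic sketch of the function above uses Module.modules(), which already walks the module tree recursively. The name relu_inplace_to_false and the toy model are hypothetical. Note that a flag like this only affects ReLU layers; it cannot remove in-place tensor operations written directly in a forward method, such as x += shortcut.

```python
import torch.nn as nn


def relu_inplace_to_false(model: nn.Module) -> int:
    """Set inplace=False on every nn.ReLU in model; return how many changed."""
    changed = 0
    for layer in model.modules():  # modules() recurses into all children
        if isinstance(layer, nn.ReLU) and layer.inplace:
            layer.inplace = False
            changed += 1
    return changed


# Small nested model standing in for the real network
toy = nn.Sequential(
    nn.Linear(4, 4),
    nn.ReLU(inplace=True),
    nn.Sequential(nn.Linear(4, 4), nn.ReLU(inplace=True)),
)
print(relu_inplace_to_false(toy))  # 2
```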
I want to interpret the model's predictions. I am using a retrained resnet18 model from timm:
import timm

cfg.n_classes = 2
cfg.backbone = 'resnet18'
model = timm.create_model(cfg.backbone,
                          pretrained=True,
                          num_classes=cfg.n_classes)
Minimal Reproducible Example
import shap
Traceback
No response
Expected Behavior
No response
Bug report checklist
- I have checked that this issue has not already been reported.
- I have confirmed this bug exists on the latest release of shap.
- I have confirmed this bug exists on the master branch of shap.
- I'd be interested in making a PR to fix this bug
Installed Versions
shap.__version__ = '0.44.0'
I am getting the same issue with shap 0.43.0.
Check this issue: #3466
The in-place x += shortcut operation in the ResNet block's forward function is also a problem, not just nn.ReLU(inplace=True).
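The point about x += shortcut can be checked with a small stand-in for a ResNet basic block. In this sketch, ToyBasicBlock, hook_leaf_modules and run are hypothetical names, and hooking every leaf module with a full backward hook is an assumption about how the explainer instruments the model. The ReLU is already non-in-place, yet the in-place residual add still fails, while the out-of-place x = x + shortcut does not.

```python
import torch
import torch.nn as nn


class ToyBasicBlock(nn.Module):
    """Hypothetical residual block mirroring the pattern in the traceback."""

    def __init__(self, channels: int, inplace_add: bool):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(channels)
        self.act = nn.ReLU(inplace=False)  # ReLU is already non-in-place
        self.inplace_add = inplace_add

    def forward(self, x):
        shortcut = x
        x = self.bn(self.conv(x))
        if self.inplace_add:
            x += shortcut  # in-place on the hooked BatchNorm output
        else:
            x = x + shortcut  # out-of-place: allocates a new tensor
        return self.act(x)


def hook_leaf_modules(model: nn.Module) -> None:
    # Assumption: roughly mimics an explainer that attaches a full backward
    # hook to every leaf module (a module without children).
    for module in model.modules():
        if not list(module.children()):
            module.register_full_backward_hook(lambda m, gin, gout: None)


def run(inplace_add: bool) -> str:
    """Return "ok", or the RuntimeError message raised by forward/backward."""
    block = ToyBasicBlock(channels=3, inplace_add=inplace_add)
    hook_leaf_modules(block)
    x = torch.randn(2, 3, 8, 8, requires_grad=True)
    try:
        block(x).sum().backward()
    except RuntimeError as err:
        return str(err)
    return "ok"


print(run(inplace_add=True))
print(run(inplace_add=False))
```

In a real model, the equivalent change has to be made in the block's forward method, since it is not controlled by a layer flag.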
Thanks for the bug report. This looks like a duplicate so let's track this one on #3466