BUG: SHAP DeepExplainer cannot get SHAP values from TorchScript model
DarrelYee opened this issue · 1 comment
Issue Description
`DeepExplainer` currently seems unable to handle a PyTorch model loaded from TorchScript, and will throw `RuntimeError: register_forward_hook is not supported on ScriptModules`. `DeepExplainer` has no problem with the original source `nn.Module`, but throws the error once it's converted to a `ScriptModule`. Wrapping the `ScriptModule` in a fresh `nn.Module` gives the same error.
I'm using 0.42.1, but the problem occurs on the latest release as well.
Minimal Reproducible Example
```python
import torch
import torch.nn as nn
import shap

# Set up a dummy model and convert it to TorchScript
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc(x)

simple_model = SimpleModel()
ts = torch.jit.script(simple_model)
eg_input_data = torch.arange(0, 10, 0.1, dtype=torch.float32).reshape((10, 10))

# Running this gives RuntimeError: register_forward_hook is not supported on ScriptModules
explainer = shap.DeepExplainer(ts, eg_input_data)
print(explainer.shap_values(eg_input_data))
```
```python
# Reconstituting ts as an nn.Module gives the same error
class ReconModel(nn.Module):
    def __init__(self, ts):
        super().__init__()
        self.ts = ts

    def forward(self, x):
        return self.ts(x)

recon_model = ReconModel(ts)
explainer = shap.DeepExplainer(recon_model, torch.arange(0, 1, 0.1, dtype=torch.float32).reshape((-1, 10)))
print(explainer.shap_values(eg_input_data))
```
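A possible workaround (an unverified assumption on my part): `shap.GradientExplainer` may work here, since its PyTorch path for a plain model appears to compute input gradients with autograd rather than registering module hooks. A minimal sketch, reusing `ts` and `eg_input_data` from above; I haven't checked the attributions against `DeepExplainer`:

```python
import shap

# Assumption: GradientExplainer only calls the model and backpropagates to the
# inputs, so a callable ScriptModule may be acceptable here even though hook
# registration is not supported on it.
explainer = shap.GradientExplainer(ts, eg_input_data)
print(explainer.shap_values(eg_input_data))
```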
Traceback
```
RuntimeError                              Traceback (most recent call last)
<ipython-input-43-fd62703e7cd4> in <module>
     28
     29 explainer = shap.DeepExplainer(recon_model, torch.arange(0, 1, 0.1, dtype = torch.float32).reshape((-1,10)))
---> 30 print(explainer.shap_values(eg_input_data))

~/.local/lib/python3.7/site-packages/shap/explainers/_deep/__init__.py in shap_values(self, X, ranked_outputs, output_rank_order, check_additivity)
    122             were chosen as "top".
    123         """
--> 124         return self.explainer.shap_values(X, ranked_outputs, output_rank_order, check_additivity=check_additivity)

~/.local/lib/python3.7/site-packages/shap/explainers/_deep/deep_pytorch.py in shap_values(self, X, ranked_outputs, output_rank_order, check_additivity)
    164
    165         # add the gradient handles
--> 166         handles = self.add_handles(self.model, add_interim_values, deeplift_grad)
    167         if self.interim:
    168             self.add_target_handle(self.layer)

~/.local/lib/python3.7/site-packages/shap/explainers/_deep/deep_pytorch.py in add_handles(self, model, forward_handle, backward_handle)
     77         if model_children:
     78             for child in model_children:
---> 79                 handles_list.extend(self.add_handles(child, forward_handle, backward_handle))
     80         else: # leaves
     81             handles_list.append(model.register_forward_hook(forward_handle))

~/.local/lib/python3.7/site-packages/shap/explainers/_deep/deep_pytorch.py in add_handles(self, model, forward_handle, backward_handle)
     77         if model_children:
     78             for child in model_children:
---> 79                 handles_list.extend(self.add_handles(child, forward_handle, backward_handle))
     80         else: # leaves
     81             handles_list.append(model.register_forward_hook(forward_handle))

~/.local/lib/python3.7/site-packages/shap/explainers/_deep/deep_pytorch.py in add_handles(self, model, forward_handle, backward_handle)
     79                 handles_list.extend(self.add_handles(child, forward_handle, backward_handle))
     80         else: # leaves
---> 81             handles_list.append(model.register_forward_hook(forward_handle))
     82             handles_list.append(model.register_backward_hook(backward_handle))
     83         return handles_list

~/.local/lib/python3.7/site-packages/torch/jit/_script.py in fail(self, *args, **kwargs)
    941     def _make_fail(name):
    942         def fail(self, *args, **kwargs):
--> 943             raise RuntimeError(name + " is not supported on ScriptModules")
    944
    945         return fail

RuntimeError: register_forward_hook is not supported on ScriptModules
```
Expected Behavior
No response
Bug report checklist
- I have checked that this issue has not already been reported.
- I have confirmed this bug exists on the latest release of shap.
- I have confirmed this bug exists on the master branch of shap.
- I'd be interested in making a PR to fix this bug
Installed Versions
0.42.1
Thanks for reporting the issue.
This seems to be an inherent limitation of PyTorch. We need to overwrite the gradients of the model, and we do this via hooks. If this is not possible with TorchScript models (or if there is no known, acceptable workaround), we won't support this.
I'll leave this open for a while, but if there is no further progress, we'll close this issue as NO FIX.
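For context, a simplified sketch of the mechanism (illustrative only, not SHAP's actual `deeplift_grad` hook): `DeepExplainer` walks the module tree and registers forward/backward hooks on leaf modules so it can substitute DeepLIFT-style gradients, and it is exactly that registration call that ScriptModules reject, at least on the PyTorch build shown in the traceback:

```python
import torch
import torch.nn as nn

relu = nn.ReLU()

# A backward hook can rewrite a module's gradients during backprop; this is
# the general mechanism DeepExplainer relies on (its real hook applies the
# DeepLIFT rescale rule, not this toy scaling).
def override_grad(module, grad_input, grad_output):
    return tuple(g * 0.5 if g is not None else None for g in grad_input)

relu.register_full_backward_hook(override_grad)  # works on a plain nn.Module

scripted = torch.jit.script(relu)
# On the PyTorch build from the traceback above, the next line raises:
# RuntimeError: register_forward_hook is not supported on ScriptModules
scripted.register_forward_hook(lambda mod, inp, out: None)
```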
Edit: Two ideas to check:
- Does Captum support TorchScript models?
- We should at least throw an informative error explaining why we don't support this (see the sketch after this list).
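A minimal sketch of what that guard could look like (hypothetical helper name and placement; it would have to run in `DeepExplainer.__init__` before any hooks are registered):

```python
import torch
import torch.nn as nn

def _reject_torchscript(model):
    """Hypothetical guard: fail fast with an explanatory error instead of the
    opaque 'register_forward_hook is not supported' RuntimeError."""
    candidates = [model]
    if isinstance(model, nn.Module) and not isinstance(model, torch.jit.ScriptModule):
        candidates.extend(model.modules())  # also catch wrapped ScriptModules
    for m in candidates:
        if isinstance(m, torch.jit.ScriptModule):
            raise TypeError(
                "DeepExplainer overrides model gradients via forward/backward "
                "hooks, which ScriptModules do not support. Please pass the "
                "original nn.Module instead of a TorchScript model."
            )
```

Since `ScriptModule` subclasses `nn.Module` and registered submodules show up in `.modules()`, this would catch both the direct `ts` case and the `ReconModel` wrapper from the report.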