BUG: Error when using Llama-2
MarioRicoIbanez opened this issue · 15 comments
Issue Description
I'm trying to analyze the explainability of SOTA LLMs such as Llama-2, but when I try to use SHAP with these models I get the following error.
Minimal Reproducible Example
from transformers import AutoModelForCausalLM, AutoTokenizer
import shap
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf", use_fast=True)
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf").cuda()
# set model decoder to true
model.config.is_decoder = True
# set text-generation params under task_specific_params
# create the dict first
model.config.task_specific_params = {}
model.config.task_specific_params["text-generation"] = {
"do_sample": True,
"max_length": 50,
"temperature": 0.7,
"top_k": 50,
"no_repeat_ngram_size": 2,
}
s = ["I enjoy walking with my cute dog"]
explainer = shap.Explainer(model, tokenizer)
shap_values = explainer(s)
Traceback
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[26], line 17
     15 s = ["I enjoy walking with my cute dog"]
     16 explainer = shap.Explainer(model, tokenizer)
---> 17 shap_values = explainer(s)

File /usr/local/lib/python3.8/dist-packages/shap/explainers/_partition.py:129, in PartitionExplainer.__call__(self, max_evals, fixed_context, main_effects, error_bounds, batch_size, outputs, silent, *args)
    125 def __call__(self, *args, max_evals=500, fixed_context=None, main_effects=False, error_bounds=False, batch_size="auto",
    126              outputs=None, silent=False):
    127     """ Explain the output of the model on the given arguments.
    128     """
--> 129     return super().__call__(
    130         *args, max_evals=max_evals, fixed_context=fixed_context, main_effects=main_effects, error_bounds=error_bounds, batch_size=batch_size,
    131         outputs=outputs, silent=silent
    132     )

File /usr/local/lib/python3.8/dist-packages/shap/explainers/_explainer.py:267, in Explainer.__call__(self, max_evals, main_effects, error_bounds, batch_size, outputs, silent, *args, **kwargs)
    265 feature_names = [[] for _ in range(len(args))]
    266 for row_args in show_progress(zip(*args), num_rows, self.__class__.__name__+" explainer", silent):
--> 267     row_result = self.explain_row(
    268         *row_args, max_evals=max_evals, main_effects=main_effects, error_bounds=error_bounds,
    269         batch_size=batch_size, outputs=outputs, silent=silent, **kwargs
    270     )
    271 values.append(row_result.get("values", None))
    272 output_indices.append(row_result.get("output_indices", None))

File /usr/local/lib/python3.8/dist-packages/shap/explainers/_partition.py:154, in PartitionExplainer.explain_row(self, max_evals, main_effects, error_bounds, batch_size, outputs, silent, fixed_context, *row_args)
    152 # if not fixed background or no base value assigned then compute base value for a row
    153 if self._curr_base_value is None or not getattr(self.masker, "fixed_background", False):
--> 154     self._curr_base_value = fm(m00.reshape(1, -1), zero_index=0)[0]  # the zero index param tells the masked model what the baseline is
    155 f11 = fm(~m00.reshape(1, -1))[0]
    157 if callable(self.masker.clustering):

File /usr/local/lib/python3.8/dist-packages/shap/utils/_masked_model.py:69, in MaskedModel.__call__(self, masks, zero_index, batch_size)
     66     return self._full_masking_call(full_masks, zero_index=zero_index, batch_size=batch_size)
     68 else:
---> 69     return self._full_masking_call(masks, batch_size=batch_size)

File /usr/local/lib/python3.8/dist-packages/shap/utils/_masked_model.py:146, in MaskedModel._full_masking_call(self, masks, zero_index, batch_size)
    143     all_masked_inputs[i].append(v)
    145 joined_masked_inputs = tuple([np.concatenate(v) for v in all_masked_inputs])
--> 146 outputs = self.model(*joined_masked_inputs)
    147 _assert_output_input_match(joined_masked_inputs, outputs)
    148 all_outputs.append(outputs)

File /usr/local/lib/python3.8/dist-packages/shap/models/_model.py:28, in Model.__call__(self, *args)
     27 def __call__(self, *args):
---> 28     out = self.inner_model(*args)
     29     is_tensor = safe_isinstance(out, "torch.Tensor")
     30     out = out.cpu().detach().numpy() if is_tensor else np.array(out)

File /usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py:1533, in Module._call_impl(self, *args, **kwargs)
   1528 # If we don't have any hooks, we want to skip the rest of the logic in
   1529 # this function, and just call forward. It's slow for dynamo to guard on the state
   1530 # of all these hook dicts individually, so instead it can guard on 2 bools and we just
   1531 # have to promise to keep them up to date when hooks are added or removed via official means.
   1532 if not self._has_hooks and not _has_global_hooks:
-> 1533     return forward_call(*args, **kwargs)
   1534 # Do not call functions when jit is used
   1535 full_backward_hooks, non_full_backward_hooks = [], []

File /usr/local/lib/python3.8/dist-packages/transformers/models/llama/modeling_llama.py:1183, in LlamaForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)
   1180 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
   1182 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1183 outputs = self.model(
   1184     input_ids=input_ids,
   1185     attention_mask=attention_mask,
   1186     position_ids=position_ids,
   1187     past_key_values=past_key_values,
   1188     inputs_embeds=inputs_embeds,
   1189     use_cache=use_cache,
   1190     output_attentions=output_attentions,
   1191     output_hidden_states=output_hidden_states,
   1192     return_dict=return_dict,
   1193 )
   1195 hidden_states = outputs[0]
   1196 if self.config.pretraining_tp > 1:

File /usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py:1533, in Module._call_impl(self, *args, **kwargs)
   1532 if not self._has_hooks and not _has_global_hooks:
-> 1533     return forward_call(*args, **kwargs)

File /usr/local/lib/python3.8/dist-packages/transformers/models/llama/modeling_llama.py:999, in LlamaModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict)
    997     raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    998 elif input_ids is not None:
--> 999     batch_size, seq_length = input_ids.shape[:2]
   1000 elif inputs_embeds is not None:
   1001     batch_size, seq_length = inputs_embeds.shape[:2]

ValueError: not enough values to unpack (expected 2, got 1)
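For context on where the unpacking fails: SHAP passes the masked inputs to the raw model as a 1-D array of token ids, while `LlamaModel.forward` expects a 2-D (batch, sequence) array. A minimal NumPy sketch of the failing line (the token ids below are made up for illustration):

```python
import numpy as np

# LlamaModel.forward does: batch_size, seq_length = input_ids.shape[:2]
ids_1d = np.array([1, 306, 13389])   # shape (3,): no batch dimension
ids_2d = ids_1d.reshape(1, -1)       # shape (1, 3): a proper batch of one

try:
    batch_size, seq_length = ids_1d.shape[:2]
except ValueError as e:
    print(e)  # not enough values to unpack (expected 2, got 1)

batch_size, seq_length = ids_2d.shape[:2]
print(batch_size, seq_length)  # 1 3
```

This is why wrapping the model so that inputs are tokenized into a batched 2-D array (as in the `TeacherForcing` workaround below) avoids the error.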
Expected Behavior
No response
Bug report checklist
- I have checked that this issue has not already been reported.
- I have confirmed this bug exists on the latest release of shap.
- I have confirmed this bug exists on the master branch of shap.
- I'd be interested in making a PR to fix this bug
Installed Versions
Python 3.8.10 scikit-learn==1.2.2 accelerate==0.21.0 peft==0.4.0 bitsandbytes==0.40.2 transformers==4.37.2 trl==0.4.7 datasets==2.16.1 mlflow==2.10.0 pydantic==1.10.14 typing_extensions==4.9.0
s = ["I enjoy walking with my cute dog"]
gen_dict = dict(
max_new_tokens=100,
num_beams=5,
renormalize_logits=True,
no_repeat_ngram_size=8,
)
model.config.task_specific_params = dict()
model.config.task_specific_params["text-generation"] = gen_dict
shap_model = shap.models.TeacherForcing(model, tokenizer)
masker = shap.maskers.Text(tokenizer, mask_token="...", collapse_mask_token=True)
explainer = shap.Explainer(shap_model, tokenizer)
shap_values = explainer(s)
This solution should solve your problem with Llama-2, and with Mistral as well.
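To illustrate what the masker configuration above does (a hand-rolled sketch of the idea, not shap's actual implementation): masked-out tokens are replaced by the `mask_token` "...", and `collapse_mask_token=True` merges runs of adjacent masked tokens into a single "...":

```python
def apply_text_mask(tokens, mask, mask_token="...", collapse=True):
    """Replace masked-out tokens with mask_token; optionally collapse runs of masks."""
    out = []
    for tok, keep in zip(tokens, mask):
        if keep:
            out.append(tok)
        elif collapse and out and out[-1] == mask_token:
            continue  # extend the previous mask run instead of repeating "..."
        else:
            out.append(mask_token)
    return " ".join(out)

tokens = ["I", "enjoy", "walking", "with", "my", "cute", "dog"]
mask   = [True, False, False, True, True, False, True]
print(apply_text_mask(tokens, mask))  # I ... with my ... dog
```

Collapsing mask runs keeps the perturbed inputs closer to natural text and reduces the number of mask tokens the model has to process.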
Thanks for the report, I have requested access to Llama-2 now. Will post here once there are any updates.
Now it is partially working, but I get a values array that is all zeros:
.values =
array([[[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.]]])
(shape (1, 9, 73), every entry zero)
.base_values =
array([[-0.52342548, -2.88410334, -0.36660795, 0.95767114, -0.24581238,
-1.7873968 , -0.03110165, 6.36779252, 0.35044359, 6.19219956,
2.58561948, 7.491578 , 3.81039219, -2.31506661, -0.42209612,
-0.62833234, -0.64307879, -1.80665099, -1.17742844, 0.66211251,
-0.27008128, 6.25576132, 4.52850612, -1.31015818, -0.45799231,
10.00574646, 0.50833428, -0.42837653, 4.29419062, 4.06030859,
-1.15709111, -0.20367953, 5.86984239, 4.13385361, 2.36138941,
-0.08768206, 3.2889124 , 0.68570033, -0.53387673, 0.55577215,
-0.35025047, 3.82609343, -0.75910988, 2.56892822, 2.24339371,
3.08884504, 0.6789584 , 0.73464042, 1.60391795, 2.63059456,
5.00190821, 7.1968913 , 1.43071471, 2.62828756, 3.7208354 ,
11.1741379 , 6.9844757 , 0.30599576, 2.32297348, 0.70061408,
1.50329472, 5.85171772, 0.6600345 , 1.56481051, 4.1168472 ,
6.36192085, -0.81794184, 2.52507464, 4.35465319, -0.32329904,
1.68587773, -1.32010292, 2.59065567]])
.data =
(array(['', 'I', ' enjoy', ' walking', ' with', ' my', ' c', 'ute', ' do
I guess PR #3578 fixes this.
Hmm, I saw it just now, but I'm pretty new to this, and when I try to install this PR it tells me that I need Python 3.9. Maybe I'm doing something wrong; what is the right way to install this PR?
I did not test it, but according to SO this should work:
pip install https://github.com/costrau/shap/archive/fix-transformers.zip
Not working :((
pip install https://github.com/costrau/shap/archive/fix-transformers.zip
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting https://github.com/costrau/shap/archive/fix-transformers.zip
Downloading https://github.com/costrau/shap/archive/fix-transformers.zip
- 159.5 MB 5.7 MB/s 0:00:30
Installing build dependencies ... done
Getting requirements to build wheel ... error
error: subprocess-exited-with-error
× Getting requirements to build wheel did not run successfully.
│ exit code: 1
╰─> [119 lines of output]
Attempting to build SHAP: with_binary=True, with_cuda=True (Attempt 1)
NVCC ==> /usr/local/cuda/bin/nvcc
Compiling cuda extension, calling nvcc with arguments:
['/usr/local/cuda/bin/nvcc', '-allow-unsupported-compiler', 'shap/cext/_cext_gpu.cu', '-lib', '-o', 'build/lib_cext_gpu.a', '-Xcompiler', '-fPIC', '--include-path', '/usr/include/python3.8', '--std', 'c++14', '--expt-extended-lambda', '--expt-relaxed-constexpr', '-gencode=arch=compute_60,code=sm_60', '-gencode=arch=compute_70,code=sm_70', '-gencode=arch=compute_75,code=sm_75', '-gencode=arch=compute_75,code=compute_75', '-gencode=arch=compute_80,code=sm_80']
Exception occurred during setup, setuptools-scm was unable to detect version for /tmp/pip-req-build-er8m9vcy.
Make sure you're either building from a fully intact git repository or PyPI tarballs. Most other sources (such as GitHub's tarballs, a git checkout without the .git folder) don't contain the necessary metadata and will not work.
For example, if you're using pip, instead of https://github.com/user/proj/archive/master.zip use git+https://github.com/user/proj.git#egg=proj
WARNING: Could not compile cuda extensions.
Retrying SHAP build without cuda extension...
Attempting to build SHAP: with_binary=True, with_cuda=False (Attempt 2)
Exception occurred during setup, setuptools-scm was unable to detect version for /tmp/pip-req-build-er8m9vcy.
Make sure you're either building from a fully intact git repository or PyPI tarballs. Most other sources (such as GitHub's tarballs, a git checkout without the .git folder) don't contain the necessary metadata and will not work.
For example, if you're using pip, instead of https://github.com/user/proj/archive/master.zip use git+https://github.com/user/proj.git#egg=proj
WARNING: The C extension could not be compiled, sklearn tree models not supported.
Retrying SHAP build without binary extension...
Attempting to build SHAP: with_binary=False, with_cuda=False (Attempt 3)
Exception occurred during setup, setuptools-scm was unable to detect version for /tmp/pip-req-build-er8m9vcy.
Make sure you're either building from a fully intact git repository or PyPI tarballs. Most other sources (such as GitHub's tarballs, a git checkout without the .git folder) don't contain the necessary metadata and will not work.
For example, if you're using pip, instead of https://github.com/user/proj/archive/master.zip use git+https://github.com/user/proj.git#egg=proj
ERROR: Failed to build!
Traceback (most recent call last):
File "<string>", line 144, in try_run_setup
File "<string>", line 134, in run_setup
File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools/__init__.py", line 104, in setup
return distutils.core.setup(**attrs)
File "/usr/lib/python3.8/distutils/core.py", line 108, in setup
_setup_distribution = dist = klass(attrs)
File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools/dist.py", line 307, in __init__
_Distribution.__init__(self, dist_attrs)
File "/usr/lib/python3.8/distutils/dist.py", line 292, in __init__
self.finalize_options()
File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools/dist.py", line 658, in finalize_options
ep(self)
File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools_scm/_integration/setuptools.py", line 121, in infer_version
_assign_version(dist, config)
File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools_scm/_integration/setuptools.py", line 56, in _assign_version
_version_missing(config)
File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools_scm/_get_version_impl.py", line 112, in _version_missing
raise LookupError(
LookupError: setuptools-scm was unable to detect version for /tmp/pip-req-build-er8m9vcy.
Make sure you're either building from a fully intact git repository or PyPI tarballs. Most other sources (such as GitHub's tarballs, a git checkout without the .git folder) don't contain the necessary metadata and will not work.
For example, if you're using pip, instead of https://github.com/user/proj/archive/master.zip use git+https://github.com/user/proj.git#egg=proj
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "<string>", line 144, in try_run_setup
File "<string>", line 134, in run_setup
File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools/__init__.py", line 104, in setup
return distutils.core.setup(**attrs)
File "/usr/lib/python3.8/distutils/core.py", line 108, in setup
_setup_distribution = dist = klass(attrs)
File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools/dist.py", line 307, in __init__
_Distribution.__init__(self, dist_attrs)
File "/usr/lib/python3.8/distutils/dist.py", line 292, in __init__
self.finalize_options()
File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools/dist.py", line 658, in finalize_options
ep(self)
File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools_scm/_integration/setuptools.py", line 121, in infer_version
_assign_version(dist, config)
File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools_scm/_integration/setuptools.py", line 56, in _assign_version
_version_missing(config)
File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools_scm/_get_version_impl.py", line 112, in _version_missing
raise LookupError(
LookupError: setuptools-scm was unable to detect version for /tmp/pip-req-build-er8m9vcy.
Make sure you're either building from a fully intact git repository or PyPI tarballs. Most other sources (such as GitHub's tarballs, a git checkout without the .git folder) don't contain the necessary metadata and will not work.
For example, if you're using pip, instead of https://github.com/user/proj/archive/master.zip use git+https://github.com/user/proj.git#egg=proj
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/usr/local/lib/python3.8/dist-packages/pip/_vendor/pyproject_hooks/_in_process/_in_process.py", line 353, in <module>
main()
File "/usr/local/lib/python3.8/dist-packages/pip/_vendor/pyproject_hooks/_in_process/_in_process.py", line 335, in main
json_out['return_val'] = hook(**hook_input['kwargs'])
File "/usr/local/lib/python3.8/dist-packages/pip/_vendor/pyproject_hooks/_in_process/_in_process.py", line 118, in get_requires_for_build_wheel
return hook(config_settings)
File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools/build_meta.py", line 325, in get_requires_for_build_wheel
return self._get_build_requires(config_settings, requirements=['wheel'])
File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools/build_meta.py", line 295, in _get_build_requires
self.run_setup()
File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools/build_meta.py", line 311, in run_setup
exec(code, locals())
File "<string>", line 165, in <module>
File "<string>", line 160, in try_run_setup
File "<string>", line 160, in try_run_setup
File "<string>", line 144, in try_run_setup
File "<string>", line 134, in run_setup
File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools/__init__.py", line 104, in setup
return distutils.core.setup(**attrs)
File "/usr/lib/python3.8/distutils/core.py", line 108, in setup
_setup_distribution = dist = klass(attrs)
File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools/dist.py", line 307, in __init__
_Distribution.__init__(self, dist_attrs)
File "/usr/lib/python3.8/distutils/dist.py", line 292, in __init__
self.finalize_options()
File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools/dist.py", line 658, in finalize_options
ep(self)
File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools_scm/_integration/setuptools.py", line 121, in infer_version
_assign_version(dist, config)
File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools_scm/_integration/setuptools.py", line 56, in _assign_version
_version_missing(config)
File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools_scm/_get_version_impl.py", line 112, in _version_missing
raise LookupError(
LookupError: setuptools-scm was unable to detect version for /tmp/pip-req-build-er8m9vcy.
Make sure you're either building from a fully intact git repository or PyPI tarballs. Most other sources (such as GitHub's tarballs, a git checkout without the .git folder) don't contain the necessary metadata and will not work.
For example, if you're using pip, instead of https://github.com/user/proj/archive/master.zip use git+https://github.com/user/proj.git#egg=proj
[end of output]
note: This error originates from a subprocess, and is likely not a problem with pip.
error: subprocess-exited-with-error
× Getting requirements to build wheel did not run successfully.
│ exit code: 1
╰─> See above for output.
note: This error originates from a subprocess, and is likely not a problem with pip.
I moved to Python 3.10 and used the following form of pip install:
pip3 install git+https://github.com/costrau/shap.git@fix-transformers
It finally worked
BTW, do you know what mechanism shap uses to see the attentions?
Is this resolving your bug? And what do you mean by "see attentions"? For transformer models, shap basically just does inference and extracts the logits from there.
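As a rough sketch of that idea (my own simplification, not shap's actual code): the teacher-forcing wrapper runs a forward pass and scores each generated token by its log-odds under the model's output logits, so attention weights are never inspected:

```python
import numpy as np

def token_log_odds(logits, token_id):
    """Score one target token from a raw logit vector: log(p / (1 - p))."""
    m = logits.max()                                        # shift for numerical stability
    logprobs = logits - (m + np.log(np.sum(np.exp(logits - m))))  # log-softmax
    lp = logprobs[token_id]                                 # log p(token)
    return lp - np.log1p(-np.exp(lp))                       # log-odds of the token

# stand-in logits for a tiny 5-token vocabulary
logits = np.array([2.0, 0.5, -1.0, 0.0, 1.0])
print(token_log_odds(logits, token_id=0))
```

Shapley values are then computed over how these per-token scores change as input tokens are masked; no attention matrix ever enters the calculation.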
Yep, what I wrote resolved my bug.
Yeah, it extracts the logits from there, but then it just shows the raw logits of the layers, doing an average or something. Since Llama has 32 heads and 32 layers, there are 32x32 different attention matrices.
Shap does not provide any information on model internals, just how the model uses the input to generate the output. If you are interested in that, maybe you can get some results from the pytorch/captum package. I once read a paper about saliency maps that basically does what you want, but Captum's implementation of that also just seems to give you the attributions of the inputs.
Yes, so if shap is representing how the model uses the input, it's probably using the attentions, right?
Otherwise, thanks for the idea of using the Captum package; I didn't know it existed. I'll look at it.
In this case shap is using the model, and the model uses attention layers, but shap has no idea about model internals when you explain transformer models.
FYI, Captum now has included KernelSHAP for explaining LLMs. Read more here: https://captum.ai/tutorials/Llama2_LLM_Attribution
Yes, I ended up using Captum and it works perfectly. Thank you so much though.