shap/shap

BUG: ERROR USING LLAMA-2

MarioRicoIbanez opened this issue · 15 comments

Issue Description

I'm trying to analyze the explainability of SOTA LLMs such as Llama-2, but when I try to use SHAP with these models I get the following error.

Minimal Reproducible Example

from transformers import AutoModelForCausalLM, AutoTokenizer

import shap

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf", use_fast=True)
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf").cuda()

# set model decoder to true
model.config.is_decoder = True
# set text-generation params under task_specific_params
# create the dict first
model.config.task_specific_params = {}
model.config.task_specific_params["text-generation"] = {
    "do_sample": True,
    "max_length": 50,
    "temperature": 0.7,
    "top_k": 50,
    "no_repeat_ngram_size": 2,
    
}

s = ["I enjoy walking with my cute dog"]
explainer = shap.Explainer(model, tokenizer)
shap_values = explainer(s)

Traceback

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[26], line 17
     15 s = ["I enjoy walking with my cute dog"]
     16 explainer = shap.Explainer(model, tokenizer)
---> 17 shap_values = explainer(s)

File /usr/local/lib/python3.8/dist-packages/shap/explainers/_partition.py:129, in PartitionExplainer.__call__(self, max_evals, fixed_context, main_effects, error_bounds, batch_size, outputs, silent, *args)
    125 def __call__(self, *args, max_evals=500, fixed_context=None, main_effects=False, error_bounds=False, batch_size="auto",
    126              outputs=None, silent=False):
    127     """ Explain the output of the model on the given arguments.
    128     """
--> 129     return super().__call__(
    130         *args, max_evals=max_evals, fixed_context=fixed_context, main_effects=main_effects, error_bounds=error_bounds, batch_size=batch_size,
    131         outputs=outputs, silent=silent
    132     )

File /usr/local/lib/python3.8/dist-packages/shap/explainers/_explainer.py:267, in Explainer.__call__(self, max_evals, main_effects, error_bounds, batch_size, outputs, silent, *args, **kwargs)
    265     feature_names = [[] for _ in range(len(args))]
    266 for row_args in show_progress(zip(*args), num_rows, self.__class__.__name__+" explainer", silent):
--> 267     row_result = self.explain_row(
    268         *row_args, max_evals=max_evals, main_effects=main_effects, error_bounds=error_bounds,
    269         batch_size=batch_size, outputs=outputs, silent=silent, **kwargs
    270     )
    271     values.append(row_result.get("values", None))
    272     output_indices.append(row_result.get("output_indices", None))

File /usr/local/lib/python3.8/dist-packages/shap/explainers/_partition.py:154, in PartitionExplainer.explain_row(self, max_evals, main_effects, error_bounds, batch_size, outputs, silent, fixed_context, *row_args)
    152 # if not fixed background or no base value assigned then compute base value for a row
    153 if self._curr_base_value is None or not getattr(self.masker, "fixed_background", False):
--> 154     self._curr_base_value = fm(m00.reshape(1, -1), zero_index=0)[0] # the zero index param tells the masked model what the baseline is
    155 f11 = fm(~m00.reshape(1, -1))[0]
    157 if callable(self.masker.clustering):

File /usr/local/lib/python3.8/dist-packages/shap/utils/_masked_model.py:69, in MaskedModel.__call__(self, masks, zero_index, batch_size)
     66         return self._full_masking_call(full_masks, zero_index=zero_index, batch_size=batch_size)
     68 else:
---> 69     return self._full_masking_call(masks, batch_size=batch_size)

File /usr/local/lib/python3.8/dist-packages/shap/utils/_masked_model.py:146, in MaskedModel._full_masking_call(self, masks, zero_index, batch_size)
    143         all_masked_inputs[i].append(v)
    145 joined_masked_inputs = tuple([np.concatenate(v) for v in all_masked_inputs])
--> 146 outputs = self.model(*joined_masked_inputs)
    147 _assert_output_input_match(joined_masked_inputs, outputs)
    148 all_outputs.append(outputs)

File /usr/local/lib/python3.8/dist-packages/shap/models/_model.py:28, in Model.__call__(self, *args)
     27 def __call__(self, *args):
---> 28     out = self.inner_model(*args)
     29     is_tensor = safe_isinstance(out, "torch.Tensor")
     30     out = out.cpu().detach().numpy() if is_tensor else np.array(out)

File /usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py:1533, in Module._call_impl(self, *args, **kwargs)
   1528 # If we don't have any hooks, we want to skip the rest of the logic in
   1529 # this function, and just call forward.  It's slow for dynamo to guard on the state
   1530 # of all these hook dicts individually, so instead it can guard on 2 bools and we just
   1531 # have to promise to keep them up to date when hooks are added or removed via official means.
   1532 if not self._has_hooks and not _has_global_hooks:
-> 1533     return forward_call(*args, **kwargs)
   1534 # Do not call functions when jit is used
   1535 full_backward_hooks, non_full_backward_hooks = [], []

File /usr/local/lib/python3.8/dist-packages/transformers/models/llama/modeling_llama.py:1183, in LlamaForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)
   1180 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
   1182 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1183 outputs = self.model(
   1184     input_ids=input_ids,
   1185     attention_mask=attention_mask,
   1186     position_ids=position_ids,
   1187     past_key_values=past_key_values,
   1188     inputs_embeds=inputs_embeds,
   1189     use_cache=use_cache,
   1190     output_attentions=output_attentions,
   1191     output_hidden_states=output_hidden_states,
   1192     return_dict=return_dict,
   1193 )
   1195 hidden_states = outputs[0]
   1196 if self.config.pretraining_tp > 1:

File /usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py:1533, in Module._call_impl(self, *args, **kwargs)
   1528 # If we don't have any hooks, we want to skip the rest of the logic in
   1529 # this function, and just call forward.  It's slow for dynamo to guard on the state
   1530 # of all these hook dicts individually, so instead it can guard on 2 bools and we just
   1531 # have to promise to keep them up to date when hooks are added or removed via official means.
   1532 if not self._has_hooks and not _has_global_hooks:
-> 1533     return forward_call(*args, **kwargs)
   1534 # Do not call functions when jit is used
   1535 full_backward_hooks, non_full_backward_hooks = [], []

File /usr/local/lib/python3.8/dist-packages/transformers/models/llama/modeling_llama.py:999, in LlamaModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict)
    997     raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    998 elif input_ids is not None:
--> 999     batch_size, seq_length = input_ids.shape[:2]
   1000 elif inputs_embeds is not None:
   1001     batch_size, seq_length = inputs_embeds.shape[:2]

ValueError: not enough values to unpack (expected 2, got 1)

Expected Behavior

No response

Bug report checklist

  • I have checked that this issue has not already been reported.
  • I have confirmed this bug exists on the latest release of shap.
  • I have confirmed this bug exists on the master branch of shap.
  • I'd be interested in making a PR to fix this bug

Installed Versions

Python 3.8.10
scikit-learn==1.2.2
accelerate==0.21.0
peft==0.4.0
bitsandbytes==0.40.2
transformers==4.37.2
trl==0.4.7
datasets==2.16.1
mlflow==2.10.0
pydantic==1.10.14
typing_extensions==4.9.0

s = ["I enjoy walking with my cute dog"]
gen_dict = dict(
max_new_tokens=100,
num_beams=5,
renormalize_logits=True,
no_repeat_ngram_size=8,
)

model.config.task_specific_params = dict()
model.config.task_specific_params["text-generation"] = gen_dict
shap_model = shap.models.TeacherForcing(model, tokenizer)

masker = shap.maskers.Text(tokenizer, mask_token="...", collapse_mask_token=True)
explainer = shap.Explainer(shap_model, tokenizer)
shap_values = explainer(s)

this solution should solve Your problem with Llama-2, and with Mistral.
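
If it helps, here is a quick way to sanity-check and visualize the result (a minimal sketch assuming the usual shap text plotting API):

import shap

# Inspect the shape of the explanation and render the token-level
# attributions of the input for each generated token.
print(shap_values.shape)
shap.plots.text(shap_values)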

Thanks for the report. I have requested access to Llama-2 now and will post here once there are any updates.

Now it is partially working, but I get a values array of shape (1, 9, 73) where every entry is zero:

.values =
array([[[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]]])

.base_values =
array([[-0.52342548, -2.88410334, -0.36660795, 0.95767114, -0.24581238,
-1.7873968 , -0.03110165, 6.36779252, 0.35044359, 6.19219956,
2.58561948, 7.491578 , 3.81039219, -2.31506661, -0.42209612,
-0.62833234, -0.64307879, -1.80665099, -1.17742844, 0.66211251,
-0.27008128, 6.25576132, 4.52850612, -1.31015818, -0.45799231,
10.00574646, 0.50833428, -0.42837653, 4.29419062, 4.06030859,
-1.15709111, -0.20367953, 5.86984239, 4.13385361, 2.36138941,
-0.08768206, 3.2889124 , 0.68570033, -0.53387673, 0.55577215,
-0.35025047, 3.82609343, -0.75910988, 2.56892822, 2.24339371,
3.08884504, 0.6789584 , 0.73464042, 1.60391795, 2.63059456,
5.00190821, 7.1968913 , 1.43071471, 2.62828756, 3.7208354 ,
11.1741379 , 6.9844757 , 0.30599576, 2.32297348, 0.70061408,
1.50329472, 5.85171772, 0.6600345 , 1.56481051, 4.1168472 ,
6.36192085, -0.81794184, 2.52507464, 4.35465319, -0.32329904,
1.68587773, -1.32010292, 2.59065567]])

.data =
(array(['', 'I', ' enjoy', ' walking', ' with', ' my', ' c', 'ute', ' do

I guess PR #3578 fixes this.

Hmm, I only saw it just now. I'm pretty new to this, and when I try to install from that PR it tells me I need Python 3.9. Maybe I'm doing something wrong, so what is the right way to install from this PR?

I did not test it, but according to Stack Overflow this should work:

pip install https://github.com/costrau/shap/archive/fix-transformers.zip

Not working :((

pip install https://github.com/costrau/shap/archive/fix-transformers.zip
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting https://github.com/costrau/shap/archive/fix-transformers.zip
  Downloading https://github.com/costrau/shap/archive/fix-transformers.zip
     - 159.5 MB 5.7 MB/s 0:00:30
  Installing build dependencies ... done
  Getting requirements to build wheel ... error
  error: subprocess-exited-with-error
  
  × Getting requirements to build wheel did not run successfully.
  │ exit code: 1
  ╰─> [119 lines of output]
      Attempting to build SHAP: with_binary=True, with_cuda=True (Attempt 1)
      NVCC ==>  /usr/local/cuda/bin/nvcc
      Compiling cuda extension, calling nvcc with arguments:
      ['/usr/local/cuda/bin/nvcc', '-allow-unsupported-compiler', 'shap/cext/_cext_gpu.cu', '-lib', '-o', 'build/lib_cext_gpu.a', '-Xcompiler', '-fPIC', '--include-path', '/usr/include/python3.8', '--std', 'c++14', '--expt-extended-lambda', '--expt-relaxed-constexpr', '-gencode=arch=compute_60,code=sm_60', '-gencode=arch=compute_70,code=sm_70', '-gencode=arch=compute_75,code=sm_75', '-gencode=arch=compute_75,code=compute_75', '-gencode=arch=compute_80,code=sm_80']
      Exception occurred during setup, setuptools-scm was unable to detect version for /tmp/pip-req-build-er8m9vcy.
      
      Make sure you're either building from a fully intact git repository or PyPI tarballs. Most other sources (such as GitHub's tarballs, a git checkout without the .git folder) don't contain the necessary metadata and will not work.
      
      For example, if you're using pip, instead of https://github.com/user/proj/archive/master.zip use git+https://github.com/user/proj.git#egg=proj
      WARNING: Could not compile cuda extensions.
      Retrying SHAP build without cuda extension...
      Attempting to build SHAP: with_binary=True, with_cuda=False (Attempt 2)
      Exception occurred during setup, setuptools-scm was unable to detect version for /tmp/pip-req-build-er8m9vcy.
      
      Make sure you're either building from a fully intact git repository or PyPI tarballs. Most other sources (such as GitHub's tarballs, a git checkout without the .git folder) don't contain the necessary metadata and will not work.
      
      For example, if you're using pip, instead of https://github.com/user/proj/archive/master.zip use git+https://github.com/user/proj.git#egg=proj
      WARNING: The C extension could not be compiled, sklearn tree models not supported.
      Retrying SHAP build without binary extension...
      Attempting to build SHAP: with_binary=False, with_cuda=False (Attempt 3)
      Exception occurred during setup, setuptools-scm was unable to detect version for /tmp/pip-req-build-er8m9vcy.
      
      Make sure you're either building from a fully intact git repository or PyPI tarballs. Most other sources (such as GitHub's tarballs, a git checkout without the .git folder) don't contain the necessary metadata and will not work.
      
      For example, if you're using pip, instead of https://github.com/user/proj/archive/master.zip use git+https://github.com/user/proj.git#egg=proj
      ERROR: Failed to build!
      Traceback (most recent call last):
        File "<string>", line 144, in try_run_setup
        File "<string>", line 134, in run_setup
        File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools/__init__.py", line 104, in setup
          return distutils.core.setup(**attrs)
        File "/usr/lib/python3.8/distutils/core.py", line 108, in setup
          _setup_distribution = dist = klass(attrs)
        File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools/dist.py", line 307, in __init__
          _Distribution.__init__(self, dist_attrs)
        File "/usr/lib/python3.8/distutils/dist.py", line 292, in __init__
          self.finalize_options()
        File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools/dist.py", line 658, in finalize_options
          ep(self)
        File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools_scm/_integration/setuptools.py", line 121, in infer_version
          _assign_version(dist, config)
        File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools_scm/_integration/setuptools.py", line 56, in _assign_version
          _version_missing(config)
        File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools_scm/_get_version_impl.py", line 112, in _version_missing
          raise LookupError(
      LookupError: setuptools-scm was unable to detect version for /tmp/pip-req-build-er8m9vcy.
      
      Make sure you're either building from a fully intact git repository or PyPI tarballs. Most other sources (such as GitHub's tarballs, a git checkout without the .git folder) don't contain the necessary metadata and will not work.
      
      For example, if you're using pip, instead of https://github.com/user/proj/archive/master.zip use git+https://github.com/user/proj.git#egg=proj
      
      During handling of the above exception, another exception occurred:
      
      Traceback (most recent call last):
        File "<string>", line 144, in try_run_setup
        File "<string>", line 134, in run_setup
        File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools/__init__.py", line 104, in setup
          return distutils.core.setup(**attrs)
        File "/usr/lib/python3.8/distutils/core.py", line 108, in setup
          _setup_distribution = dist = klass(attrs)
        File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools/dist.py", line 307, in __init__
          _Distribution.__init__(self, dist_attrs)
        File "/usr/lib/python3.8/distutils/dist.py", line 292, in __init__
          self.finalize_options()
        File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools/dist.py", line 658, in finalize_options
          ep(self)
        File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools_scm/_integration/setuptools.py", line 121, in infer_version
          _assign_version(dist, config)
        File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools_scm/_integration/setuptools.py", line 56, in _assign_version
          _version_missing(config)
        File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools_scm/_get_version_impl.py", line 112, in _version_missing
          raise LookupError(
      LookupError: setuptools-scm was unable to detect version for /tmp/pip-req-build-er8m9vcy.
      
      Make sure you're either building from a fully intact git repository or PyPI tarballs. Most other sources (such as GitHub's tarballs, a git checkout without the .git folder) don't contain the necessary metadata and will not work.
      
      For example, if you're using pip, instead of https://github.com/user/proj/archive/master.zip use git+https://github.com/user/proj.git#egg=proj
      
      During handling of the above exception, another exception occurred:
      
      Traceback (most recent call last):
        File "/usr/local/lib/python3.8/dist-packages/pip/_vendor/pyproject_hooks/_in_process/_in_process.py", line 353, in <module>
          main()
        File "/usr/local/lib/python3.8/dist-packages/pip/_vendor/pyproject_hooks/_in_process/_in_process.py", line 335, in main
          json_out['return_val'] = hook(**hook_input['kwargs'])
        File "/usr/local/lib/python3.8/dist-packages/pip/_vendor/pyproject_hooks/_in_process/_in_process.py", line 118, in get_requires_for_build_wheel
          return hook(config_settings)
        File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools/build_meta.py", line 325, in get_requires_for_build_wheel
          return self._get_build_requires(config_settings, requirements=['wheel'])
        File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools/build_meta.py", line 295, in _get_build_requires
          self.run_setup()
        File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools/build_meta.py", line 311, in run_setup
          exec(code, locals())
        File "<string>", line 165, in <module>
        File "<string>", line 160, in try_run_setup
        File "<string>", line 160, in try_run_setup
        File "<string>", line 144, in try_run_setup
        File "<string>", line 134, in run_setup
        File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools/__init__.py", line 104, in setup
          return distutils.core.setup(**attrs)
        File "/usr/lib/python3.8/distutils/core.py", line 108, in setup
          _setup_distribution = dist = klass(attrs)
        File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools/dist.py", line 307, in __init__
          _Distribution.__init__(self, dist_attrs)
        File "/usr/lib/python3.8/distutils/dist.py", line 292, in __init__
          self.finalize_options()
        File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools/dist.py", line 658, in finalize_options
          ep(self)
        File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools_scm/_integration/setuptools.py", line 121, in infer_version
          _assign_version(dist, config)
        File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools_scm/_integration/setuptools.py", line 56, in _assign_version
          _version_missing(config)
        File "/tmp/pip-build-env-b_71bw3z/overlay/lib/python3.8/site-packages/setuptools_scm/_get_version_impl.py", line 112, in _version_missing
          raise LookupError(
      LookupError: setuptools-scm was unable to detect version for /tmp/pip-req-build-er8m9vcy.
      
      Make sure you're either building from a fully intact git repository or PyPI tarballs. Most other sources (such as GitHub's tarballs, a git checkout without the .git folder) don't contain the necessary metadata and will not work.
      
      For example, if you're using pip, instead of https://github.com/user/proj/archive/master.zip use git+https://github.com/user/proj.git#egg=proj
      [end of output]
  
  note: This error originates from a subprocess, and is likely not a problem with pip.
error: subprocess-exited-with-error

× Getting requirements to build wheel did not run successfully.
│ exit code: 1
╰─> See above for output.

note: This error originates from a subprocess, and is likely not a problem with pip.

I moved to Python 3.10 and used the following form of pip install:

pip3 install git+https://github.com/costrau/shap.git@fix-transformers

It finally worked.

BTW, do you know what mechanism shap uses to see the attentions?

Does this resolve your bug? What do you mean by "see attentions"? For transformer models, shap basically just runs inference and extracts the logits from there.
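
For intuition, here is a rough sketch (not shap's exact internals) of the kind of forward pass it relies on, reusing the model and tokenizer loaded earlier in this thread:

import torch

# shap's text models repeatedly run plain inference like this on masked
# variants of the input and read the scores off the returned logits.
inputs = tokenizer(["I enjoy walking with my cute dog"], return_tensors="pt").to(model.device)
with torch.no_grad():
    logits = model(**inputs).logits  # (batch_size, sequence_length, vocab_size)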

Yep, what I wrote resolved my bug.

Yeah, it extracts the logits from there, but it seems to just show the raw logits of the layers, averaged or something. Since Llama-2 has 32 heads and 32 layers, there are 32x32 different attention matrices.

Shap does not provide any information on model internals, only on how the model uses the input to generate the output. If you are interested in internals, maybe you can get some results from the pytorch/captum package. I once read a paper about saliency maps that basically does what you want, but Captum's implementation of that also just seems to give you attributions of the inputs.

Yes, so if shap is representing how the model uses the input, it's probably using the attentions, right?
In any case, thanks for the idea of using the captum package; I didn't know it existed, I'll look into it.

In this case shap uses the model, and the model uses attention layers. But shap has no idea about model internals when you explain transformer models.

FYI, Captum now includes KernelShap for explaining LLMs. Read more here: https://captum.ai/tutorials/Llama2_LLM_Attribution
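
An untested sketch of what that tutorial does, assuming captum >= 0.7 and the Hugging Face model/tokenizer loaded earlier in this thread (check the tutorial for the exact API):

from captum.attr import KernelShap, LLMAttribution, TextTokenInput

# Wrap the model with KernelShap and adapt it for LLM prompt attribution.
llm_attr = LLMAttribution(KernelShap(model), tokenizer)

# Attribute the generated continuation back onto the prompt tokens.
inp = TextTokenInput("I enjoy walking with my cute dog", tokenizer)
attr_result = llm_attr.attribute(inp, gen_args={"max_new_tokens": 20})
attr_result.plot_token_attr(show=True)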

Yes, I ended up using Captum and it works perfectly. Thank you so much, though.