pytorch/captum

LayerConductance not working with llama2

ThePuscher opened this issue ยท 0 comments

๐Ÿ› Bug

To Reproduce

Steps to reproduce the behavior:

Run the code below, which is trying to use layer conductance attribution for llama2.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from captum.attr import LayerConductance
import bitsandbytes as bnb

def load_model(model_name, bnb_config):
    n_gpus = torch.cuda.device_count()
    max_memory = "10000MB"

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        # device_map="cpu"
        device_map="auto", # dispatch efficiently the model on the available ressources
        max_memory = {i: max_memory for i in range(n_gpus)},
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=True)

    # Needed for LLaMA tokenizer
    tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer

def create_bnb_config():
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    return bnb_config

model_name = "meta-llama/Llama-2-7b-chat-hf" 
bnb_config = create_bnb_config()
model, tokenizer = load_model(model_name, bnb_config)

layer = model.model.layers[-1]    
input_test = "The president of the USA is named"
inputs = tokenizer(input_test, return_tensors="pt").to("cuda:0")
input_ids = inputs["input_ids"].int()

layer_cond = LayerConductance(model, layer)
llama_att = layer_cond.attribute(input_ids, target=0) # first token

Expected behavior

No error raised.

Environment

Describe the environment used for Captum


 - Captum / PyTorch Version: 0.7.0 / 2.3.1
 - OS: Linux
 - How you installed Captum / PyTorch: pip
 - Python version: 3.10
 - CUDA/cuDNN version: 11.7
 - GPU models and configuration: GA102GL [A40]

Additional context

Stack trace:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[29], line 2
      1 layer_cond = LayerConductance(model, layer)
----> 2 llama_att = layer_cond.attribute(input_ids, target=target)

File ~/venv/lib/python3.10/site-packages/captum/log/__init__.py:42, in log_usage.<locals>._log_usage.<locals>.wrapper(*args, **kwargs)
     40 @wraps(func)
     41 def wrapper(*args, **kwargs):
---> 42     return func(*args, **kwargs)

File ~/venv/lib/python3.10/site-packages/captum/attr/_core/layer/layer_conductance.py:292, in LayerConductance.attribute(self, inputs, baselines, target, additional_forward_args, n_steps, method, internal_batch_size, return_convergence_delta, attribute_to_layer_input)
    277     attrs = _batch_attribution(
    278         self,
    279         num_examples,
   (...)
    288         attribute_to_layer_input=attribute_to_layer_input,
    289     )
    291 else:
--> 292     attrs = self._attribute(
    293         inputs=inputs,
    294         baselines=baselines,
    295         target=target,
    296         additional_forward_args=additional_forward_args,
    297         n_steps=n_steps,
    298         method=method,
    299         attribute_to_layer_input=attribute_to_layer_input,
    300     )
    302 is_layer_tuple = isinstance(attrs, tuple)
    303 attributions = attrs if is_layer_tuple else (attrs,)

File ~/venv/lib/python3.10/site-packages/captum/attr/_core/layer/layer_conductance.py:360, in LayerConductance._attribute(self, inputs, baselines, target, additional_forward_args, n_steps, method, attribute_to_layer_input, step_sizes_and_alphas)
    356 expanded_target = _expand_target(target, n_steps + 1)
    358 # Conductance Gradients - Returns gradient of output with respect to
    359 # hidden layer and hidden layer evaluated at each input.
--> 360 (layer_gradients, layer_evals,) = compute_layer_gradients_and_eval(
    361     forward_fn=self.forward_func,
    362     layer=self.layer,
    363     inputs=scaled_features_tpl,
    364     additional_forward_args=input_additional_args,
    365     target_ind=expanded_target,
    366     device_ids=self.device_ids,
    367     attribute_to_layer_input=attribute_to_layer_input,
    368 )
    370 # Compute differences between consecutive evaluations of layer_eval.
    371 # This approximates the total input gradient of each step multiplied
    372 # by the step size.
    373 grad_diffs = tuple(
    374     layer_eval[num_examples:] - layer_eval[:-num_examples]
    375     for layer_eval in layer_evals
    376 )

File ~/venv/lib/python3.10/site-packages/captum/_utils/gradient.py:592, in compute_layer_gradients_and_eval(forward_fn, layer, inputs, target_ind, additional_forward_args, gradient_neuron_selector, device_ids, attribute_to_layer_input, output_fn)
    541 r"""
    542 Computes gradients of the output with respect to a given layer as well
    543 as the output evaluation of the layer for an arbitrary forward function
   (...)
    587         Target layer output for given input.
    588 """
    589 with torch.autograd.set_grad_enabled(True):
    590     # saved_layer is a dictionary mapping device to a tuple of
    591     # layer evaluations on that device.
--> 592     saved_layer, output = _forward_layer_distributed_eval(
    593         forward_fn,
    594         inputs,
    595         layer,
    596         target_ind=target_ind,
    597         additional_forward_args=additional_forward_args,
    598         attribute_to_layer_input=attribute_to_layer_input,
    599         forward_hook_with_return=True,
    600         require_layer_grads=True,
    601     )
    602     assert output[0].numel() == 1, (
    603         "Target not provided when necessary, cannot"
    604         " take gradient with respect to multiple outputs."
    605     )
    607     device_ids = _extract_device_ids(forward_fn, saved_layer, device_ids)

File ~/venv/lib/python3.10/site-packages/captum/_utils/gradient.py:294, in _forward_layer_distributed_eval(forward_fn, inputs, layer, target_ind, additional_forward_args, attribute_to_layer_input, forward_hook_with_return, require_layer_grads)
    290         else:
    291             all_hooks.append(
    292                 single_layer.register_forward_hook(hook_wrapper(single_layer))
    293             )
--> 294     output = _run_forward(
    295         forward_fn,
    296         inputs,
    297         target=target_ind,
    298         additional_forward_args=additional_forward_args,
    299     )
    300 finally:
    301     for hook in all_hooks:

File ~/venv/lib/python3.10/site-packages/captum/_utils/common.py:531, in _run_forward(forward_func, inputs, target, additional_forward_args)
    528 inputs = _format_inputs(inputs)
    529 additional_forward_args = _format_additional_forward_args(additional_forward_args)
--> 531 output = forward_func(
    532     *(*inputs, *additional_forward_args)
    533     if additional_forward_args is not None
    534     else inputs
    535 )
    536 return _select_targets(output, target)

File ~/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File ~/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File ~/venv/lib/python3.10/site-packages/accelerate/hooks.py:169, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    167         output = module._old_forward(*args, **kwargs)
    168 else:
--> 169     output = module._old_forward(*args, **kwargs)
    170 return module._hf_hook.post_forward(module, output)

File ~/venv/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:1174, in LlamaForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
   1171 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
   1173 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1174 outputs = self.model(
   1175     input_ids=input_ids,
   1176     attention_mask=attention_mask,
   1177     position_ids=position_ids,
   1178     past_key_values=past_key_values,
   1179     inputs_embeds=inputs_embeds,
   1180     use_cache=use_cache,
   1181     output_attentions=output_attentions,
   1182     output_hidden_states=output_hidden_states,
   1183     return_dict=return_dict,
   1184     cache_position=cache_position,
   1185 )
   1187 hidden_states = outputs[0]
   1188 if self.config.pretraining_tp > 1:

File ~/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File ~/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File ~/venv/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:931, in LlamaModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
    928     use_cache = False
    930 if inputs_embeds is None:
--> 931     inputs_embeds = self.embed_tokens(input_ids)
    933 return_legacy_cache = False
    934 if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)

File ~/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File ~/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File ~/venv/lib/python3.10/site-packages/accelerate/hooks.py:169, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    167         output = module._old_forward(*args, **kwargs)
    168 else:
--> 169     output = module._old_forward(*args, **kwargs)
    170 return module._hf_hook.post_forward(module, output)

File ~/venv/lib/python3.10/site-packages/torch/nn/modules/sparse.py:163, in Embedding.forward(self, input)
    162 def forward(self, input: Tensor) -> Tensor:
--> 163     return F.embedding(
    164         input, self.weight, self.padding_idx, self.max_norm,
    165         self.norm_type, self.scale_grad_by_freq, self.sparse)

File ~/venv/lib/python3.10/site-packages/torch/nn/functional.py:2264, in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
   2258     # Note [embedding_renorm set_grad_enabled]
   2259     # XXX: equivalent to
   2260     # with torch.no_grad():
   2261     #   torch.embedding_renorm_
   2262     # remove once script supports set_grad_enabled
   2263     _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 2264 return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)

RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.cuda.FloatTensor instead (while checking arguments for embedding)