LayerConductance not working with llama2
ThePuscher opened this issue · 0 comments
ThePuscher commented
🐛 Bug
To Reproduce
Steps to reproduce the behavior:
Run the code below, which tries to use LayerConductance attribution with Llama 2.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from captum.attr import LayerConductance
import bitsandbytes as bnb

def load_model(model_name, bnb_config):
    n_gpus = torch.cuda.device_count()
    max_memory = "10000MB"
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        # device_map="cpu"
        device_map="auto",  # dispatch the model efficiently across the available resources
        max_memory={i: max_memory for i in range(n_gpus)},
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=True)
    # Needed for the LLaMA tokenizer
    tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer

def create_bnb_config():
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    return bnb_config

model_name = "meta-llama/Llama-2-7b-chat-hf"
bnb_config = create_bnb_config()
model, tokenizer = load_model(model_name, bnb_config)

layer = model.model.layers[-1]
input_test = "The president of the USA is named"
inputs = tokenizer(input_test, return_tensors="pt").to("cuda:0")
input_ids = inputs["input_ids"].int()

layer_cond = LayerConductance(model, layer)
llama_att = layer_cond.attribute(input_ids, target=0)  # first token
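The dtype in the failure (see the stack trace under Additional context) hints at the mechanism: LayerConductance, like other path-attribution methods, evaluates the model at points interpolated between a baseline and the input, so the integer input_ids become a float tensor before they reach model.model.embed_tokens. A minimal sketch of that interpolation, as an illustration of the suspected mechanism rather than Captum's exact internals:

import torch

# Illustration of the suspected mechanism (an assumption, not Captum internals):
# path methods evaluate the model at baseline + alpha * (input - baseline).
input_ids = torch.tensor([[1, 450, 6673]])  # integer token IDs
baseline = torch.zeros_like(input_ids)      # default zero baseline
alpha = 0.5                                 # one of the n_steps interpolation points
scaled = baseline + alpha * (input_ids - baseline)
print(scaled.dtype)  # torch.float32 -- F.embedding rejects float indices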
Expected behavior
No error raised.
Environment
Describe the environment used for Captum
- Captum / PyTorch Version: 0.7.0 / 2.3.1
- OS: Linux
- How you installed Captum / PyTorch: pip
- Python version: 3.10
- CUDA/cuDNN version: 11.7
- GPU models and configuration: GA102GL [A40]
Additional context
Stack trace:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[29], line 2
1 layer_cond = LayerConductance(model, layer)
----> 2 llama_att = layer_cond.attribute(input_ids, target=target)
File ~/venv/lib/python3.10/site-packages/captum/log/__init__.py:42, in log_usage.<locals>._log_usage.<locals>.wrapper(*args, **kwargs)
40 @wraps(func)
41 def wrapper(*args, **kwargs):
---> 42 return func(*args, **kwargs)
File ~/venv/lib/python3.10/site-packages/captum/attr/_core/layer/layer_conductance.py:292, in LayerConductance.attribute(self, inputs, baselines, target, additional_forward_args, n_steps, method, internal_batch_size, return_convergence_delta, attribute_to_layer_input)
277 attrs = _batch_attribution(
278 self,
279 num_examples,
(...)
288 attribute_to_layer_input=attribute_to_layer_input,
289 )
291 else:
--> 292 attrs = self._attribute(
293 inputs=inputs,
294 baselines=baselines,
295 target=target,
296 additional_forward_args=additional_forward_args,
297 n_steps=n_steps,
298 method=method,
299 attribute_to_layer_input=attribute_to_layer_input,
300 )
302 is_layer_tuple = isinstance(attrs, tuple)
303 attributions = attrs if is_layer_tuple else (attrs,)
File ~/venv/lib/python3.10/site-packages/captum/attr/_core/layer/layer_conductance.py:360, in LayerConductance._attribute(self, inputs, baselines, target, additional_forward_args, n_steps, method, attribute_to_layer_input, step_sizes_and_alphas)
356 expanded_target = _expand_target(target, n_steps + 1)
358 # Conductance Gradients - Returns gradient of output with respect to
359 # hidden layer and hidden layer evaluated at each input.
--> 360 (layer_gradients, layer_evals,) = compute_layer_gradients_and_eval(
361 forward_fn=self.forward_func,
362 layer=self.layer,
363 inputs=scaled_features_tpl,
364 additional_forward_args=input_additional_args,
365 target_ind=expanded_target,
366 device_ids=self.device_ids,
367 attribute_to_layer_input=attribute_to_layer_input,
368 )
370 # Compute differences between consecutive evaluations of layer_eval.
371 # This approximates the total input gradient of each step multiplied
372 # by the step size.
373 grad_diffs = tuple(
374 layer_eval[num_examples:] - layer_eval[:-num_examples]
375 for layer_eval in layer_evals
376 )
File ~/venv/lib/python3.10/site-packages/captum/_utils/gradient.py:592, in compute_layer_gradients_and_eval(forward_fn, layer, inputs, target_ind, additional_forward_args, gradient_neuron_selector, device_ids, attribute_to_layer_input, output_fn)
541 r"""
542 Computes gradients of the output with respect to a given layer as well
543 as the output evaluation of the layer for an arbitrary forward function
(...)
587 Target layer output for given input.
588 """
589 with torch.autograd.set_grad_enabled(True):
590 # saved_layer is a dictionary mapping device to a tuple of
591 # layer evaluations on that device.
--> 592 saved_layer, output = _forward_layer_distributed_eval(
593 forward_fn,
594 inputs,
595 layer,
596 target_ind=target_ind,
597 additional_forward_args=additional_forward_args,
598 attribute_to_layer_input=attribute_to_layer_input,
599 forward_hook_with_return=True,
600 require_layer_grads=True,
601 )
602 assert output[0].numel() == 1, (
603 "Target not provided when necessary, cannot"
604 " take gradient with respect to multiple outputs."
605 )
607 device_ids = _extract_device_ids(forward_fn, saved_layer, device_ids)
File ~/venv/lib/python3.10/site-packages/captum/_utils/gradient.py:294, in _forward_layer_distributed_eval(forward_fn, inputs, layer, target_ind, additional_forward_args, attribute_to_layer_input, forward_hook_with_return, require_layer_grads)
290 else:
291 all_hooks.append(
292 single_layer.register_forward_hook(hook_wrapper(single_layer))
293 )
--> 294 output = _run_forward(
295 forward_fn,
296 inputs,
297 target=target_ind,
298 additional_forward_args=additional_forward_args,
299 )
300 finally:
301 for hook in all_hooks:
File ~/venv/lib/python3.10/site-packages/captum/_utils/common.py:531, in _run_forward(forward_func, inputs, target, additional_forward_args)
528 inputs = _format_inputs(inputs)
529 additional_forward_args = _format_additional_forward_args(additional_forward_args)
--> 531 output = forward_func(
532 *(*inputs, *additional_forward_args)
533 if additional_forward_args is not None
534 else inputs
535 )
536 return _select_targets(output, target)
File ~/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
File ~/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
1536 # If we don't have any hooks, we want to skip the rest of the logic in
1537 # this function, and just call forward.
1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1543 try:
1544 result = None
File ~/venv/lib/python3.10/site-packages/accelerate/hooks.py:169, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
167 output = module._old_forward(*args, **kwargs)
168 else:
--> 169 output = module._old_forward(*args, **kwargs)
170 return module._hf_hook.post_forward(module, output)
File ~/venv/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:1174, in LlamaForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
1171 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1173 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1174 outputs = self.model(
1175 input_ids=input_ids,
1176 attention_mask=attention_mask,
1177 position_ids=position_ids,
1178 past_key_values=past_key_values,
1179 inputs_embeds=inputs_embeds,
1180 use_cache=use_cache,
1181 output_attentions=output_attentions,
1182 output_hidden_states=output_hidden_states,
1183 return_dict=return_dict,
1184 cache_position=cache_position,
1185 )
1187 hidden_states = outputs[0]
1188 if self.config.pretraining_tp > 1:
File ~/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
File ~/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
1536 # If we don't have any hooks, we want to skip the rest of the logic in
1537 # this function, and just call forward.
1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1543 try:
1544 result = None
File ~/venv/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:931, in LlamaModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
928 use_cache = False
930 if inputs_embeds is None:
--> 931 inputs_embeds = self.embed_tokens(input_ids)
933 return_legacy_cache = False
934 if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
File ~/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
File ~/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
1536 # If we don't have any hooks, we want to skip the rest of the logic in
1537 # this function, and just call forward.
1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1543 try:
1544 result = None
File ~/venv/lib/python3.10/site-packages/accelerate/hooks.py:169, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
167 output = module._old_forward(*args, **kwargs)
168 else:
--> 169 output = module._old_forward(*args, **kwargs)
170 return module._hf_hook.post_forward(module, output)
File ~/venv/lib/python3.10/site-packages/torch/nn/modules/sparse.py:163, in Embedding.forward(self, input)
162 def forward(self, input: Tensor) -> Tensor:
--> 163 return F.embedding(
164 input, self.weight, self.padding_idx, self.max_norm,
165 self.norm_type, self.scale_grad_by_freq, self.sparse)
File ~/venv/lib/python3.10/site-packages/torch/nn/functional.py:2264, in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
2258 # Note [embedding_renorm set_grad_enabled]
2259 # XXX: equivalent to
2260 # with torch.no_grad():
2261 # torch.embedding_renorm_
2262 # remove once script supports set_grad_enabled
2263 _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 2264 return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.cuda.FloatTensor instead (while checking arguments for embedding)
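The error indicates that the scaled (float) inputs are fed directly to the embedding lookup. A possible workaround (a sketch, not a confirmed fix) is to run the attribution over inputs_embeds instead of raw token IDs, so the interpolated tensors are floats by design. The snippet below reuses model, tokenizer, and layer from the reproduction above; forward_from_embeds is a hypothetical wrapper name.

import torch
from captum.attr import LayerConductance

# Hypothetical wrapper (not part of Captum): accept embeddings instead of
# token IDs and return last-position logits, one row per example.
def forward_from_embeds(inputs_embeds, attention_mask):
    out = model(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
    return out.logits[:, -1, :]

inputs = tokenizer("The president of the USA is named", return_tensors="pt").to("cuda:0")
input_embeds = model.get_input_embeddings()(inputs["input_ids"])

layer_cond = LayerConductance(forward_from_embeds, layer)
llama_att = layer_cond.attribute(
    input_embeds,
    additional_forward_args=(inputs["attention_mask"],),
    target=0,  # first vocabulary index, mirroring the original call
)

Alternatively, Captum's LayerIntegratedGradients hooks the target layer and performs the scaling on the layer's activations, so it can usually be called with integer input_ids directly; that may be the simpler route if conductance specifically is not required.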