ndif-team/nnsight

Can't run llama architectures

Status: closed · 6 comments

I can't run nnsight on Llama models. Tracing fails with RuntimeError: User specified an unsupported autocast device_type 'meta'
MWE:

from nnsight import LanguageModel

model = LanguageModel('Maykeye/TinyLLama-v0', device_map='auto')
prompt = "The french translation for 'hello' is:\n"
# Entering the trace context raises the RuntimeError shown below
with model.trace(prompt) as trace:
    pass

I get the same error with each of the following:

  • using CPU or CUDA instead of auto
  • installing from the GitHub main and dev branches instead of from pip
  • on Colab and on my school's cluster
  • with CroissantLLM (also Llama-based)

Full stack trace:

RuntimeError                              Traceback (most recent call last)
<ipython-input-5-c8af8fe5fc7f> in <cell line: 4>()
      2 model = LanguageModel('Maykeye/TinyLLama-v0',device_map='cuda:0')
      3 prompts = "The french translation for 'hello' is:\n"
----> 4 with model.trace(prompts) as trace:
      5     pass

/usr/local/lib/python3.10/dist-packages/nnsight/models/NNsightModel.py in trace(self, trace, invoker_args, scan, *inputs, **kwargs)
    194 
    195             # Otherwise open an invoker context with the give args.
--> 196             runner.invoke(*inputs, **invoker_args).__enter__()
    197 
    198         # If trace is False, you had to have provided an input.

/usr/local/lib/python3.10/dist-packages/nnsight/contexts/Invoker.py in __enter__(self)
     65             ) as fake_mode:
     66                 with FakeCopyMode(fake_mode):
---> 67                     self.tracer._model._execute(
     68                         *copy.deepcopy(self.inputs),
     69                         **copy.deepcopy(self.tracer._kwargs),

/usr/local/lib/python3.10/dist-packages/nnsight/models/mixins/Generation.py in _execute(self, prepared_inputs, generate, *args, **kwargs)
     19             return self._execute_generate(prepared_inputs, *args, **kwargs)
     20 
---> 21         return self._execute_forward(prepared_inputs, *args, **kwargs)
     22 
     23     def _scan(

/usr/local/lib/python3.10/dist-packages/nnsight/models/LanguageModel.py in _execute_forward(self, prepared_inputs, *args, **kwargs)
    274         device = next(self._model.parameters()).device
    275 
--> 276         return self._model(
    277             *args,
    278             **prepared_inputs.to(device),

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1516             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517         else:
-> 1518             return self._call_impl(*args, **kwargs)
   1519 
   1520     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1566                 args = bw_hook.setup_input_hook(args)
   1567 
-> 1568             result = forward_call(*args, **kwargs)
   1569             if _global_forward_hooks or self._forward_hooks:
   1570                 for hook_id, hook in (

/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py in forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
   1174 
   1175         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1176         outputs = self.model(
   1177             input_ids=input_ids,
   1178             attention_mask=attention_mask,

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1516             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517         else:
-> 1518             return self._call_impl(*args, **kwargs)
   1519 
   1520     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1566                 args = bw_hook.setup_input_hook(args)
   1567 
-> 1568             result = forward_call(*args, **kwargs)
   1569             if _global_forward_hooks or self._forward_hooks:
   1570                 for hook_id, hook in (

/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py in forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
   1017                 )
   1018             else:
-> 1019                 layer_outputs = decoder_layer(
   1020                     hidden_states,
   1021                     attention_mask=causal_mask,

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1516             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517         else:
-> 1518             return self._call_impl(*args, **kwargs)
   1519 
   1520     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1566                 args = bw_hook.setup_input_hook(args)
   1567 
-> 1568             result = forward_call(*args, **kwargs)
   1569             if _global_forward_hooks or self._forward_hooks:
   1570                 for hook_id, hook in (

/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py in forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, **kwargs)
    738 
    739         # Self Attention
--> 740         hidden_states, self_attn_weights, present_key_value = self.self_attn(
    741             hidden_states=hidden_states,
    742             attention_mask=attention_mask,

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1516             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517         else:
-> 1518             return self._call_impl(*args, **kwargs)
   1519 
   1520     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1566                 args = bw_hook.setup_input_hook(args)
   1567 
-> 1568             result = forward_call(*args, **kwargs)
   1569             if _global_forward_hooks or self._forward_hooks:
   1570                 for hook_id, hook in (

/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py in forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, **kwargs)
    359 
    360         past_key_value = getattr(self, "past_key_value", past_key_value)
--> 361         cos, sin = self.rotary_emb(value_states, position_ids)
    362         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
    363 

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1516             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517         else:
-> 1518             return self._call_impl(*args, **kwargs)
   1519 
   1520     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1566                 args = bw_hook.setup_input_hook(args)
   1567 
-> 1568             result = forward_call(*args, **kwargs)
   1569             if _global_forward_hooks or self._forward_hooks:
   1570                 for hook_id, hook in (

/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py in decorate_context(*args, **kwargs)
    113     def decorate_context(*args, **kwargs):
    114         with ctx_factory():
--> 115             return func(*args, **kwargs)
    116 
    117     return decorate_context

/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py in forward(self, x, position_ids, seq_len)
    139         device_type = x.device.type
    140         device_type = device_type if isinstance(device_type, str) else "cpu"
--> 141         with torch.autocast(device_type=device_type, enabled=False):
    142             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
    143             emb = torch.cat((freqs, freqs), dim=-1)

/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py in __init__(self, device_type, dtype, enabled, cache_enabled)
    239             self.fast_dtype = self.custom_device_mod.get_autocast_dtype()
    240         else:
--> 241             raise RuntimeError(
    242                 f"User specified an unsupported autocast device_type '{self.device}'"
    243             )

RuntimeError: User specified an unsupported autocast device_type 'meta'
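The last frames show where this comes from. While nnsight scans the model (the fake_mode / FakeCopyMode path in Invoker.__enter__), the rotary embedding receives a tensor whose device.type is 'meta'. Lines 139-140 of modeling_llama.py only fall back to "cpu" when device.type is not a string, so 'meta' is passed straight to torch.autocast, which rejects it. The pure-Python sketch below mirrors that check and contrasts it with a hypothetical guard (not part of transformers or nnsight) that also maps unsupported types such as 'meta' to "cpu". The set of accepted device types is an illustrative assumption, not torch's real list.

```python
# Assumption: an illustrative subset of device types that torch.autocast accepts.
AUTOCAST_DEVICE_TYPES = {"cpu", "cuda"}


def rotary_device_type(device_type):
    """Mirror of the check in the traceback (modeling_llama.py, lines 139-140):
    only a non-string device type falls back to "cpu"."""
    return device_type if isinstance(device_type, str) else "cpu"


def guarded_device_type(device_type):
    """Hypothetical guard: also fall back to "cpu" for device types that
    autocast does not accept, such as "meta" during nnsight's scan."""
    if isinstance(device_type, str) and device_type in AUTOCAST_DEVICE_TYPES:
        return device_type
    return "cpu"


for dev in ["cuda", "cpu", "meta"]:
    original = rotary_device_type(dev)
    status = "ok" if original in AUTOCAST_DEVICE_TYPES else "RuntimeError in autocast"
    print(f"{dev:>4}: original -> {original!r} ({status}), "
          f"guarded -> {guarded_device_type(dev)!r}")
```

The point of the comparison is that the string check alone cannot catch 'meta', because 'meta' is already a string.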

Running the model directly with Hugging Face transformers works:

from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("Maykeye/TinyLLama-v0")
model = AutoModelForCausalLM.from_pretrained("Maykeye/TinyLLama-v0", device_map="cuda")

input_text = "Write me a poem about Machine Learning."
# The tokenizer returns a BatchEncoding (token IDs plus attention mask); move it to the GPU
inputs = tokenizer(input_text, return_tensors="pt").to("cuda")

outputs = model.generate(**inputs, max_new_tokens=50)
print(tokenizer.decode(outputs[0]))

I get the same error with Code Llama.

@arjunguha Can you try upgrading the transformers package? See the NDIF Discord for more details.

Yup, fixed. It works with a released (non-git) transformers and nnsight > 0.2.
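To check whether an environment matches the combination reported as working here (a released, non-git transformers and nnsight newer than 0.2), the installed versions can be read with the standard library. This is a minimal sketch: version_tuple and newer_than are hypothetical helpers that compare only the leading numeric components and ignore suffixes, so they are not a full PEP 440 comparison. Treating a ".dev" suffix as a sign of a git/source install of transformers is also an assumption.

```python
import re
from importlib.metadata import PackageNotFoundError, version


def version_tuple(text):
    """Leading numeric components of a version string, e.g. "0.2.11" -> (0, 2, 11).
    Simplification: suffixes such as ".dev0" or "+local" are ignored."""
    match = re.match(r"\d+(?:\.\d+)*", text)
    return tuple(int(part) for part in match.group(0).split(".")) if match else ()


def newer_than(text, minimum):
    """True if version `text` is strictly newer than `minimum` (numeric parts only)."""
    return version_tuple(text) > version_tuple(minimum)


for package in ("nnsight", "transformers", "torch"):
    try:
        installed = version(package)
    except PackageNotFoundError:
        print(f"{package}: not installed")
        continue
    note = ""
    if package == "nnsight":
        note = " (ok)" if newer_than(installed, "0.2") else " (needs > 0.2)"
    elif package == "transformers" and "dev" in installed:
        # Assumption: installs from the GitHub main branch report a ".devN" version.
        note = " (looks like a source/dev build)"
    print(f"{package}: {installed}{note}")
```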

Another workaround, suggested by Jaden: dispatch the model on initialization so it's not on the 'meta' device:

from nnsight import LanguageModel

# dispatch=True loads the model onto its device at init instead of leaving it on 'meta'
model = LanguageModel('Maykeye/TinyLLama-v0', device_map='auto', dispatch=True)
prompt = "The french translation for 'hello' is:\n"
with model.trace(prompt) as trace:
    pass

Another workaround (useful, for example, if you want to run the model remotely) is to disable scanning: with model.trace(prompt, scan=False)

Disabling scan doesn't seem like a big deal. The documentation for the scan parameter says:

scan: whether to execute the model using FakeTensor in order to update the potential sizes/dtypes of all modules’ Envoys’ inputs/outputs, as well as to validate that things work correctly. Scanning is not free computation-wise, so you may want to set it to False when running in a loop. When making interventions, you may get shape errors if scan is False, as it validates operations based on shapes; so for looped calls where shapes are consistent, you may want to have scan=True for the first loop. Defaults to True.
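The looping advice in that description (scan on the first call, then skip scanning while shapes stay the same) can be written as a small scheduling rule. The sketch below is plain Python: scan_flags is a hypothetical helper, and it goes one step beyond the quoted docs by assuming that an input with a new shape should be scanned again. The comment in the loop marks where the flag would be passed to model.trace.

```python
def scan_flags(shapes):
    """For each input shape in a loop, decide whether to pass scan=True.

    Scan the first time a shape is seen (so nnsight can record module
    input/output shapes and validate interventions); skip scanning for
    repeats of a shape that has already been scanned."""
    seen = set()
    flags = []
    for shape in shapes:
        flags.append(shape not in seen)
        seen.add(shape)
    return flags


# Example: (batch, sequence length) of prompts processed in a loop (made-up values).
shapes = [(1, 12), (1, 12), (1, 12), (1, 9), (1, 12)]
for shape, scan in zip(shapes, scan_flags(shapes)):
    # With nnsight this is where you would write: with model.trace(prompt, scan=scan): ...
    print(shape, "scan" if scan else "skip")
```

Keying the decision on shape rather than on iteration count keeps the saving from skipping scans while still validating any input whose shape has not been checked yet.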