Can't preprocess dataset using meta-llama/Meta-Llama-3.1-8B model
ohmeow opened this issue · 3 comments
ohmeow commented
### Please check that this issue hasn't been reported before.
- [X] I searched previous Bug Reports and didn't find any similar reports.
### Expected Behavior
I expected to have a pre-processed dataset after running `python -m axolotl.cli.preprocess`.
### Current behaviour
I get this error:

```
NotImplementedError: aten::_local_scalar_dense: attempted to run this operator with Meta tensors, but there was no abstract impl or Meta kernel registered. You may have run into this message while using an operator with PT2 compilation APIs (torch.compile/torch.export); in order to use this operator with those APIs you'll need to add an abstract impl. Please see the following doc for next steps: https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit
```

Full trace:

```
Traceback (most recent call last):
File "<frozen runpy>", line 198, in _run_module_as_main
File "<frozen runpy>", line 88, in _run_code
File "/home/wgilliam/development/projects/axolotl/src/axolotl/cli/preprocess.py", line 96, in <module>
fire.Fire(do_cli)
File "/home/wgilliam/mambaforge/envs/axolotl/lib/python3.11/site-packages/fire/core.py", line 143, in Fire
component_trace = _Fire(component, args, parsed_flag_args, context, name)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/wgilliam/mambaforge/envs/axolotl/lib/python3.11/site-packages/fire/core.py", line 477, in _Fire
component, remaining_args = _CallAndUpdateTrace(
^^^^^^^^^^^^^^^^^^^^
File "/home/wgilliam/mambaforge/envs/axolotl/lib/python3.11/site-packages/fire/core.py", line 693, in _CallAndUpdateTrace
component = fn(*varargs, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^
File "/home/wgilliam/development/projects/axolotl/src/axolotl/cli/preprocess.py", line 85, in do_cli
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
File "/home/wgilliam/mambaforge/envs/axolotl/lib/python3.11/site-packages/transformers/models/auto/auto_factory.py", line 564, in from_pretrained
return model_class.from_pretrained(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/wgilliam/mambaforge/envs/axolotl/lib/python3.11/site-packages/transformers/modeling_utils.py", line 3788, in from_pretrained
model = cls(config, *model_args, **model_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/wgilliam/mambaforge/envs/axolotl/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 1068, in __init__
self.model = LlamaModel(config)
^^^^^^^^^^^^^^^^^^
File "/home/wgilliam/mambaforge/envs/axolotl/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 845, in __init__
[LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
File "/home/wgilliam/mambaforge/envs/axolotl/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 845, in <listcomp>
[LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/wgilliam/mambaforge/envs/axolotl/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 632, in __init__
self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/wgilliam/mambaforge/envs/axolotl/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 306, in __init__
self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/wgilliam/mambaforge/envs/axolotl/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 119, in __init__
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/wgilliam/mambaforge/envs/axolotl/lib/python3.11/site-packages/transformers/modeling_rope_utils.py", line 330, in _compute_llama3_parameters
if wavelen < high_freq_wavelen:
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/wgilliam/mambaforge/envs/axolotl/lib/python3.11/site-packages/torch/utils/_device.py", line 78, in __torch_function__
return func(*args, **kwargs)
```
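For context: the crash happens while the model skeleton is being instantiated on the meta device (note the `torch/utils/_device.py` frame); the Llama 3.1 rope-scaling init in `_compute_llama3_parameters` runs a data-dependent `if wavelen < high_freq_wavelen:` check on a meta tensor, which can't be evaluated. Below is a minimal sketch that should hit the same error outside axolotl, assuming a transformers build from before the Llama 3.1 rope fix and access to the gated checkpoint config:

```python
# Minimal repro sketch (assumes a pre-fix transformers release and HF access
# to the gated meta-llama/Meta-Llama-3.1-8B config).
import torch
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("meta-llama/Meta-Llama-3.1-8B")

# Building the model under a meta-device context mirrors the low-memory /
# quantized loading path; the llama3 rope init then tries to read a concrete
# value out of a meta tensor and raises NotImplementedError for
# aten::_local_scalar_dense.
with torch.device("meta"):
    model = AutoModelForCausalLM.from_config(config)
```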
### Steps to reproduce
Run `python -m axolotl.cli.preprocess` with the config below.
### Config yaml
```yaml
base_model: "meta-llama/Meta-Llama-3.1-8B"
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer
seed: 9
data_seed: 9
hub_model_id:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
datasets:
dataset_prepared_path: data/last_run_prepared
val_set_size: 0.05
output_dir: outputs
sequence_len: 3072
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: false
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
adapter: qlora
load_in_8bit: false
load_in_4bit: true
strict: false
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
# NOTE: If you add tokens you will need to save these lora modules
# lora_modules_to_save:
# - embed_tokens
# - lm_head
gradient_accumulation_steps: 4 # 8
micro_batch_size: 2 #1
eval_batch_size: 2
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 3e-5 #0.0002
# max_grad_norm: 1.0
# adam_beta2: 0.95
# adam_epsilon: 0.00001
# save_total_limit: 12
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
# max_steps: 100
# warmup_steps: 100 #10
warmup_ratio: 0.2
evals_per_epoch: 4 #8
# eval_steps: 10
eval_table_size:
eval_max_new_tokens: 512 #128
save_total_limit: 1
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  pad_token: <|end_of_text|>
# NOTE: If you add tokens, in addition to updating the lora config (see above), you'll
# likely need to reduce your overall batch size if training on a GPU poor rig
# tokens:
# - <function-definitions>
# - </function-definitions>
# - <function-thoughts>
# - </function-thoughts>
# - <function-calls>
# - </function-calls>
# - function_call
save_safetensors: true
```
### Possible solution
N/A
### Which Operating Systems are you using?
- [X] Linux
- [ ] macOS
- [ ] Windows
### Python Version
3.11
### axolotl branch-commit
main
### Acknowledgements
- [X] My issue title is concise, descriptive, and in title casing.
- [X] I have searched the existing issues to make sure this bug has not been reported yet.
- [X] I am using the latest version of axolotl.
- [X] I have provided enough information for the maintainers to reproduce and diagnose the issue.
winglian commented
I don't see a dataset in your configuration YAML. Did you redact it? Can you provide some info on the dataset/prompt type you're trying to preprocess?
ohmeow commented
UPDATE: Looks like it's an issue with Transformers, with the recommendation being to install from GitHub directly. See: https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct/discussions/54
Sorry, I'm adding that in dynamically (running in Jupyter). I'm using the template-free format, which works fine with the older Llama 3 ...

```
# axo_config_fpath = "configs/axolotl_configs/llama3-8b-qlora.yaml"
axo_config_fpath = "configs/axolotl_configs/llama3.1-8b-qlora.yaml"

train_data_fpath = "data/train_reviewed_template_free_1000.jsonl"
train_data_config = str(f'[{{"path": "{train_data_fpath}", "type":"input_output"}}]')

python -m axolotl.cli.preprocess {axo_config_fpath} --datasets '{train_data_config}'
```
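For reference, the static YAML equivalent of that `--datasets` override would be roughly the following (the path and type are the values above; filling in the empty `datasets:` block in the config is my assumption):

```yaml
datasets:
  - path: data/train_reviewed_template_free_1000.jsonl
    type: input_output
```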
ohmeow commented
Closing this out. I can verify that a pip install from the transformers main branch provides the necessary fix.
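For anyone else hitting this, installing transformers from source looks roughly like:

```
pip install -U git+https://github.com/huggingface/transformers.git
```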