LlamaFamily/Llama-Chinese

RuntimeError: FlashAttention only supports Ampere GPUs or newer.

540627735 opened this issue · 1 comments

Running the LoRA fine-tuning code on Kaggle produces the following error:
Traceback (most recent call last):
File "/kaggle/working/llamatest/train/sft/finetune_clm_lora.py", line 692, in <module>
main()
File "/kaggle/working/llamatest/train/sft/finetune_clm_lora.py", line 653, in main
train_result = trainer.train(resume_from_checkpoint=checkpoint)
File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 1624, in train
return inner_training_loop(
File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 1961, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs)
File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 2902, in training_step
loss = self.compute_loss(model, inputs)
File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 2925, in compute_loss
outputs = model(**inputs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
ret_val = func(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1852, in forward
loss = self.module(*inputs, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/peft/peft_model.py", line 1083, in forward
return self.base_model(
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/peft/tuners/tuners_utils.py", line 161, in forward
return self.model.forward(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
output = module._old_forward(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1168, in forward
outputs = self.model(
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
output = module._old_forward(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 997, in forward
layer_outputs = self._gradient_checkpointing_func(
File "/opt/conda/lib/python3.10/site-packages/torch/_compile.py", line 24, in inner
return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
return fn(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
return fn(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 451, in checkpoint
return CheckpointFunction.apply(function, preserve, *args)
File "/opt/conda/lib/python3.10/site-packages/torch/autograd/function.py", line 539, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/opt/conda/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 230, in forward
outputs = run_function(*args)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
output = module._old_forward(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 734, in forward
hidden_states, self_attn_weights, present_key_value = self.self_attn(
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
output = module._old_forward(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 487, in forward
attn_output = self._flash_attention_forward(
File "/opt/conda/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 537, in _flash_attention_forward
attn_output_unpad = flash_attn_varlen_func(
File "/opt/conda/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 1059, in flash_attn_varlen_func
return FlashAttnVarlenFunc.apply(
File "/opt/conda/lib/python3.10/site-packages/torch/autograd/function.py", line 539, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/opt/conda/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 576, in forward
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
File "/opt/conda/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 85, in _flash_attn_varlen_forward
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
RuntimeError: FlashAttention only supports Ampere GPUs or newer.
Exception raised from mha_varlen_fwd at /home/runner/work/flash-attention/flash-attention/csrc/flash_attn/flash_api.cpp:519 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x6c (0x7f30afa8051c in /opt/conda/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, char const*) + 0x84 (0x7f30afa35b04 in /opt/conda/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #2: mha_varlen_fwd(at::Tensor&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor>&, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor>&, c10::optional<at::Tensor>&, int, int, float, float, bool, bool, int, int, bool, c10::optional<at::Generator>) + 0x10d7 (0x7f2fff7a46d7 in /opt/conda/lib/python3.10/site-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so)

Wszl commented

Just change use_flash_attn to false in the launch script.
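
For context: the error means the GPU is older than Ampere, which corresponds to CUDA compute capability 8.0. Kaggle's T4 (7.5) and P100 (6.0) both fall below that. Here is a minimal sketch for checking this before launching, so you know whether use_flash_attn has to be false. The helper names `supports_flash_attention` and `detect_capability` are made up for illustration and are not part of the repo, and the `major >= 8` test is only an approximation of the check that flash-attn performs internally.

```python
def supports_flash_attention(capability):
    """Return True if a (major, minor) CUDA compute capability is Ampere (8.0) or newer.

    Assumption: "Ampere or newer" is approximated as major >= 8.
    """
    major, _minor = capability
    return major >= 8


def detect_capability():
    """Return the compute capability of GPU 0, or None if torch/CUDA is unavailable."""
    try:
        import torch
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None
    return torch.cuda.get_device_capability(0)


cap = detect_capability()
# Fall back to False when no GPU is visible or the GPU is pre-Ampere.
use_flash_attn = cap is not None and supports_flash_attention(cap)
print(f"compute capability: {cap}, use_flash_attn={use_flash_attn}")
```

If this prints `use_flash_attn=False`, set the option to false in the launch script as described above.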