RuntimeError: "fused_dropout" not implemented for 'Byte' when running trl ppo finetuning
Opened this issue · 3 comments
Machine: MAX1100
ipex-llm: 2.1.0b20240421
bigdl-core-xe-21 2.5.0b20240421
bigdl-core-xe-esimd-21 2.5.0b20240421
Related PR
When trying to run trl PPO finetuning on MAX1100, I got the following error.
(ppo) (base) wangyishuo@7cc25526b7ac:~/ziteng$ python ppo.py --model_name "/mnt/disk1/Llama-2-7b-chat-hf" --dataset_name "HuggingFaceH4/helpful_instructions"
/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: ''If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. Otherwise, there might be something wrong with your environment. Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?
warn(
/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/datasets/load.py:1461: FutureWarning: The repository for HuggingFaceH4/helpful_instructions contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/HuggingFaceH4/helpful_instructions
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.
warnings.warn(
2024-04-22 19:34:28,707 - root - INFO - intel_extension_for_pytorch auto imported
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:02<00:00, 1.15s/it]
/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/transformers/generation/configuration_utils.py:392: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.9` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.
warnings.warn(
/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/transformers/generation/configuration_utils.py:397: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.6` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.
warnings.warn(
2024-04-22 19:34:31,311 - root - INFO - peft adapter initialised
2024-04-22 19:34:31,315 - ipex_llm.transformers.utils - INFO - Converting the current model to fp4 format......
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00, 1.00it/s]
Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at /mnt/disk1/Llama-2-7b-chat-hf and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
0it [00:00, ?it/s]You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
0it [00:02, ?it/s]
Traceback (most recent call last):
File "/home/wangyishuo/ziteng/ppo.py", line 248, in <module>
response_tensors = ppo_trainer.generate(
^^^^^^^^^^^^^^^^^^^^^
File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/trl/trainer/ppo_trainer.py", line 469, in generate
response = self._generate_batched(
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/trl/trainer/ppo_trainer.py", line 556, in _generate_batched
generations = unwrapped_model.generate(**padded_inputs, **generation_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/trl/models/modeling_value_head.py", line 204, in generate
return self.pretrained_model.generate(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/peft/peft_model.py", line 1190, in generate
outputs = self.base_model.generate(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/ipex_llm/transformers/lookup.py", line 86, in generate
return original_generate(self,
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/ipex_llm/transformers/speculative.py", line 103, in generate
return original_generate(self,
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/transformers/generation/utils.py", line 1520, in generate
return self.sample(
^^^^^^^^^^^^
File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/transformers/generation/utils.py", line 2617, in sample
outputs = self(
^^^^^
File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 1183, in forward
outputs = self.model(
^^^^^^^^^^^
File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 1070, in forward
layer_outputs = decoder_layer(
^^^^^^^^^^^^^^
File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 798, in forward
hidden_states, self_attn_weights, present_key_value = self.self_attn(
^^^^^^^^^^^^^^^
File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 386, in forward
query_states = self.q_proj(hidden_states)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/peft/tuners/lora/layer.py", line 509, in forward
result = result + lora_B(lora_A(dropout(x))) * scaling
^^^^^^^^^^
File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torch/nn/modules/dropout.py", line 58, in forward
return F.dropout(input, self.p, self.training, self.inplace)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torch/nn/functional.py", line 1266, in dropout
return _VF.dropout_(input, p, training) if inplace else _VF.dropout(input, p, training)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: "fused_dropout" not implemented for 'Byte'
@leonardozcm Please take a look: is this operator not supported by our kernels? Thanks.
Hi, I don't think `_VF.dropout` is missing from our kernels. Rather, this error indicates that `input` is in an 8-bit data format, which is not a supported dtype for `torch.nn.functional.dropout`.
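For illustration, the dtype restriction can be reproduced on CPU with plain PyTorch (a minimal sketch; the `uint8` tensor here is just a stand-in for a low-bit activation, and the exact error string differs by backend, e.g. `fused_dropout` on XPU):

```python
import torch
import torch.nn.functional as F

# A uint8 ("Byte") tensor standing in for a low-bit quantized activation
x = torch.ones(4, dtype=torch.uint8)

try:
    # Dropout expects a floating-point input when training=True;
    # integer dtypes are rejected with a backend-specific RuntimeError
    F.dropout(x, p=0.1, training=True)
except RuntimeError as e:
    print(f"RuntimeError: {e}")

# Casting to a floating dtype first works as expected
y = F.dropout(x.float(), p=0.1, training=True)
print(y.dtype)  # torch.float32
```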
@Jasonzzt From the log, PPO also applies a PEFT LoRA adapter. Therefore, as with QLoRA, instead of calling `from_pretrained` on a PEFT model with a LoRA config, we should first load the base model and then use `get_peft_model`, `prepare_model_for_kbit_training`, etc. from qlora.py to create the PEFT model. Such a model is built on top of layers with supported operators, like here.
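A sketch of that flow, assuming the ipex-llm QLoRA helpers (module paths and arguments follow the qlora.py finetuning example and may differ across versions; the model path is the one from this report, and the LoRA hyperparameters are illustrative):

```python
import torch
from peft import LoraConfig
from ipex_llm.transformers import AutoModelForCausalLM
from ipex_llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training

# 1) Load the base model in a low-bit format first -- not a PEFT model directly,
#    so the linear layers are the ones backed by ipex-llm's low-bit kernels
model = AutoModelForCausalLM.from_pretrained(
    "/mnt/disk1/Llama-2-7b-chat-hf",
    load_in_low_bit="fp4",
    optimize_model=False,
    torch_dtype=torch.bfloat16,
    modules_to_not_convert=["lm_head"],
).to("xpu")

model = prepare_model_for_kbit_training(model)

# 2) Then attach the LoRA adapter via get_peft_model, as in qlora.py
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
```

The resulting model can then be wrapped for PPO (e.g. by trl's value-head model) instead of building the PEFT model through `from_pretrained` with a LoRA config.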