intel-analytics/ipex-llm

RuntimeError: "fused_dropout" not implemented for 'Byte' when running trl ppo finetuning

Opened this issue · 3 comments

Machine: MAX1100
ipex-llm: 2.1.0b20240421
bigdl-core-xe-21: 2.5.0b20240421
bigdl-core-xe-esimd-21: 2.5.0b20240421

Related PR
When trying to run trl PPO finetuning on MAX1100, I got the following error.

(ppo) (base) wangyishuo@7cc25526b7ac:~/ziteng$ python ppo.py --model_name "/mnt/disk1/Llama-2-7b-chat-hf" --dataset_name "HuggingFaceH4/helpful_instructions"
/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: ''If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. Otherwise, there might be something wrong with your environment. Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?
  warn(
/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/datasets/load.py:1461: FutureWarning: The repository for HuggingFaceH4/helpful_instructions contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/HuggingFaceH4/helpful_instructions
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.
  warnings.warn(
2024-04-22 19:34:28,707 - root - INFO - intel_extension_for_pytorch auto imported
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:02<00:00,  1.15s/it]
/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/transformers/generation/configuration_utils.py:392: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.9` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.
  warnings.warn(
/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/transformers/generation/configuration_utils.py:397: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.6` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.
  warnings.warn(
2024-04-22 19:34:31,311 - root - INFO - peft adapter initialised
2024-04-22 19:34:31,315 - ipex_llm.transformers.utils - INFO - Converting the current model to fp4 format......
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.00it/s]
Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at /mnt/disk1/Llama-2-7b-chat-hf and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
0it [00:00, ?it/s]You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
0it [00:02, ?it/s]
Traceback (most recent call last):
  File "/home/wangyishuo/ziteng/ppo.py", line 248, in <module>
    response_tensors = ppo_trainer.generate(
                       ^^^^^^^^^^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/trl/trainer/ppo_trainer.py", line 469, in generate
    response = self._generate_batched(
               ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/trl/trainer/ppo_trainer.py", line 556, in _generate_batched
    generations = unwrapped_model.generate(**padded_inputs, **generation_kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/trl/models/modeling_value_head.py", line 204, in generate
    return self.pretrained_model.generate(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/peft/peft_model.py", line 1190, in generate
    outputs = self.base_model.generate(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/ipex_llm/transformers/lookup.py", line 86, in generate
    return original_generate(self,
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/ipex_llm/transformers/speculative.py", line 103, in generate
    return original_generate(self,
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/transformers/generation/utils.py", line 1520, in generate
    return self.sample(
           ^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/transformers/generation/utils.py", line 2617, in sample
    outputs = self(
              ^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 1183, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 1070, in forward
    layer_outputs = decoder_layer(
                    ^^^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 798, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
                                                          ^^^^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 386, in forward
    query_states = self.q_proj(hidden_states)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/peft/tuners/lora/layer.py", line 509, in forward
    result = result + lora_B(lora_A(dropout(x))) * scaling
                                    ^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torch/nn/modules/dropout.py", line 58, in forward
    return F.dropout(input, self.p, self.training, self.inplace)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wangyishuo/miniconda3/envs/ppo/lib/python3.11/site-packages/torch/nn/functional.py", line 1266, in dropout
    return _VF.dropout_(input, p, training) if inplace else _VF.dropout(input, p, training)
                                                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: "fused_dropout" not implemented for 'Byte'

@leonardozcm Please take a look: is this op not supported by our kernel? Thanks.

Hi, I don't think _VF.dropout is implemented by our kernels. Instead, I suppose this error indicates that the input is in an 8-bit data format ('Byte'), which is not a supported dtype for torch.nn.functional.dropout.
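
As a side note, the dtype restriction can be reproduced independently of ipex-llm; this is plain PyTorch behavior, and only the wording of the error ("fused_dropout" here) varies by backend:

```python
import torch
import torch.nn.functional as F

# Floating-point input: dropout works as expected.
x_float = torch.randn(4, 4)
print(F.dropout(x_float, p=0.1, training=True).dtype)

# 8-bit ('Byte') input: dropout is rejected with a RuntimeError.
x_byte = torch.zeros(4, 4, dtype=torch.uint8)
try:
    F.dropout(x_byte, p=0.1, training=True)
except RuntimeError as e:
    print("dropout on uint8 raised:", e)
```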

@Jasonzzt From the log, PPO also applies PEFT LoRA.
Therefore, as with QLoRA, rather than calling from_pretrained with a LoRA config to obtain a PEFT model directly, we should first load the base model and then use the get_peft_model, prepare_model_for_kbit_training, etc. methods in qlora.py to create the PEFT model, as sketched below. Such a model is built on top of layers with supported operators, like here.
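
A minimal sketch of that flow, adapted from the QLoRA example; the low-bit format, LoRA hyperparameters, and how the resulting model is handed to the TRL PPOTrainer are assumptions and should be matched to ppo.py:

```python
# Sketch only: load the base model in low-bit first, then build the PEFT model
# with ipex-llm's qlora.py helpers instead of passing a LoRA config to from_pretrained.
import torch
from ipex_llm.transformers import AutoModelForCausalLM
from ipex_llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training
from peft import LoraConfig

base_model = AutoModelForCausalLM.from_pretrained(
    "/mnt/disk1/Llama-2-7b-chat-hf",   # model path from the log above
    load_in_low_bit="fp4",             # the log shows an fp4 conversion; pick the format you need
    optimize_model=False,
    torch_dtype=torch.bfloat16,
    modules_to_not_convert=["lm_head"],
)
base_model = base_model.to("xpu")

# Prepare the low-bit model for training, then attach LoRA adapters on top of
# layers that ipex-llm's kernels support.
base_model = prepare_model_for_kbit_training(base_model)
lora_config = LoraConfig(              # illustrative hyperparameters only
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(base_model, lora_config)

# `model` can then be wrapped for TRL (e.g. with
# AutoModelForCausalLMWithValueHead.from_pretrained(model)) instead of letting
# TRL/PEFT create the LoRA layers on an unsupported low-bit module.
```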