PEFT Prompt Tuning: forward() got an unexpected keyword argument 'inputs_embeds'
mikeleske opened this issue · 2 comments
Has anyone successfully applied PEFT Prompt Tuning to Evo?
I'm using the Hugging Face TRL `SFTTrainer` with the following PEFT config:
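```python
from peft import PromptTuningConfig, TaskType, get_peft_model

peft_config = PromptTuningConfig(
    task_type=TaskType.CAUSAL_LM,
    num_virtual_tokens=128,
    tokenizer_name_or_path='togethercomputer/evo-1-131k-base'
)
peft_model = get_peft_model(model, peft_config)  # model: the Evo base model loaded earlier
peft_model.print_trainable_parameters()
```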
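For reference, this config should leave only the prompt embeddings trainable. PEFT's prompt tuning learns one embedding vector per virtual token (twice as many for encoder-decoder models), so for a causal LM `print_trainable_parameters()` ought to report roughly `num_virtual_tokens × hidden size` trainable parameters. The sketch below only does that arithmetic; the function name is hypothetical, and the hidden size of 4096 for Evo is an assumption to check against `model.config`.

```python
def prompt_tuning_trainable_params(num_virtual_tokens: int, token_dim: int,
                                   num_transformer_submodules: int = 1) -> int:
    """Size of a learned prompt-embedding table (hypothetical helper).

    Causal LMs use one transformer submodule; encoder-decoder models use two,
    which doubles the number of virtual-token vectors.
    """
    total_virtual_tokens = num_virtual_tokens * num_transformer_submodules
    return total_virtual_tokens * token_dim


# Assumed hidden size for Evo; read the real value from model.config.
EVO_HIDDEN_SIZE = 4096
print(prompt_tuning_trainable_params(128, EVO_HIDDEN_SIZE))  # 524288
```

If the reported count is far from this, the prompt encoder probably picked up a different embedding dimension than expected.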
I get the following error:
```
Traceback (most recent call last):
  File "sft.py", line 227, in <module>
    trainer.train()
  File "/home/mleske/anaconda3/envs/py38/lib/python3.8/site-packages/trl/trainer/sft_trainer.py", line 361, in train
    output = super().train(*args, **kwargs)
  File "/home/mleske/anaconda3/envs/py38/lib/python3.8/site-packages/transformers/trainer.py", line 1859, in train
    return inner_training_loop(
  File "/home/mleske/anaconda3/envs/py38/lib/python3.8/site-packages/transformers/trainer.py", line 2203, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/home/mleske/anaconda3/envs/py38/lib/python3.8/site-packages/transformers/trainer.py", line 3138, in training_step
    loss = self.compute_loss(model, inputs)
  File "/home/mleske/anaconda3/envs/py38/lib/python3.8/site-packages/transformers/trainer.py", line 3161, in compute_loss
    outputs = model(**inputs)
  File "/home/mleske/anaconda3/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/mleske/anaconda3/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mleske/anaconda3/envs/py38/lib/python3.8/site-packages/peft/peft_model.py", line 1177, in forward
    return self.base_model(inputs_embeds=inputs_embeds, **kwargs)
  File "/home/mleske/anaconda3/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/mleske/anaconda3/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
TypeError: forward() got an unexpected keyword argument 'inputs_embeds'
```
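The traceback shows PEFT's forward (`peft/peft_model.py`, line 1177) calling `self.base_model(inputs_embeds=inputs_embeds, **kwargs)`: prompt tuning prepends the learned virtual-token embeddings to the input embeddings and passes the base model embeddings rather than token IDs. The `TypeError` therefore suggests that the Evo model's `forward()` has no `inputs_embeds` parameter. A quick way to check is to inspect the signature with the standard library. The sketch below assumes nothing about Evo's actual code; `forward_accepts`, `IdsOnlyModel`, and `EmbedsModel` are hypothetical stand-ins for illustration.

```python
import inspect


def forward_accepts(model, name: str) -> bool:
    """Return True if model.forward can take `name` as a keyword argument,
    either as a named parameter or through a **kwargs catch-all."""
    params = inspect.signature(model.forward).parameters
    if name in params:
        return params[name].kind in (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        )
    return any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values())


# Hypothetical model whose forward only takes token IDs.
class IdsOnlyModel:
    def forward(self, input_ids, attention_mask=None, labels=None):
        return input_ids


# Hypothetical model whose forward also accepts precomputed embeddings.
class EmbedsModel:
    def forward(self, input_ids=None, attention_mask=None, inputs_embeds=None, labels=None):
        return inputs_embeds if inputs_embeds is not None else input_ids


print(forward_accepts(IdsOnlyModel(), "inputs_embeds"))  # False
print(forward_accepts(EmbedsModel(), "inputs_embeds"))   # True

# Passing inputs_embeds to the IDs-only model reproduces the same kind of TypeError.
try:
    IdsOnlyModel().forward([1, 2], inputs_embeds=[[0.0]])
except TypeError as err:
    print(err)
```

Running `forward_accepts(model, "inputs_embeds")` on the loaded Evo model before calling `get_peft_model` would show whether prompt tuning can work without patching the model's forward.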
Any pointers on what I am doing wrong would be greatly appreciated.