RUCAIBox/TextBox

[🐛BUG] Context Tuning bug

Issue closed · 1 comment

Describe the bug
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 744 but got size 768 for tensor number 1 in the list.

To reproduce
Set train_batch_size in gyafc_em.yaml to 16 (I am using an RTX 3090 with 24 GB; setting it to 64 fails with an out-of-GPU-memory error).
Run python run_textbox.py --model=Context_Tuning --dataset=gyafc_em

Logs
15 Apr 16:30 ERROR Traceback (most recent call last):
  File "/root/TextBox/textbox/utils/dashboard.py", line 311, in new_experiment
    yield True
  File "/root/TextBox/textbox/quick_start/experiment.py", line 138, in run
    self._do_train_and_valid()
  File "/root/TextBox/textbox/quick_start/experiment.py", line 113, in _do_train_and_valid
    self.valid_result = self.trainer.fit(train_data, valid_data)
  File "/root/TextBox/textbox/trainer/trainer.py", line 451, in fit
    loss = self._train_epoch(train_data, epoch_idx, valid_data)['loss']
  File "/root/TextBox/textbox/trainer/trainer.py", line 221, in _train_epoch
    loss = self.model(data, epoch_idx=epoch_idx)
  File "/usr/local/miniconda3/envs/TextBox/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/TextBox/textbox/model/abstract_model.py", line 69, in forward
    inputs = self._process_prompt_tuning_input(inputs, batch)
  File "/root/TextBox/textbox/model/context_tuning.py", line 88, in _process_prompt_tuning_input
    inputs_embeds = torch.cat([prompt_embeds[:, 0], inputs_embeds, prompt_embeds[:, 1]], dim=1) # b, pl+l+pl, e
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 744 but got size 768 for tensor number 1 in the list.
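The failing call concatenates `prompt_embeds[:, 0]`, `inputs_embeds` and `prompt_embeds[:, 1]` along `dim=1`, the sequence axis according to the `# b, pl+l+pl, e` comment. `torch.cat` allows the tensors to differ only in that dimension. Every other dimension must match exactly, so the error means the prompt tensor (tensor 0, size 744) and `inputs_embeds` (tensor 1, size 768) disagree in the batch or embedding axis. The sketch below mirrors that rule on plain shape tuples so it runs without PyTorch. The helper name `cat_output_shape` and all example shapes are hypothetical. The report does not say which axis held 744, so the second example simply places it in the embedding axis to reproduce a message of the same form.

```python
from typing import Sequence, Tuple


def cat_output_shape(shapes: Sequence[Tuple[int, ...]], dim: int) -> Tuple[int, ...]:
    """Hypothetical helper: return the shape torch.cat(tensors, dim=dim) would
    produce for tensors with the given shapes, or raise ValueError.

    Mirrors the torch.cat rule: same number of dimensions, and identical sizes
    in every dimension except `dim`.
    """
    if not shapes:
        raise ValueError("expected a non-empty list of shapes")
    ref = shapes[0]
    ndim = len(ref)
    if not -ndim <= dim < ndim:
        raise ValueError(f"dim {dim} out of range for {ndim}-d tensors")
    dim %= ndim  # normalise negative dims, e.g. -2 -> 1 for 3-d tensors

    for i, shape in enumerate(shapes[1:], start=1):
        if len(shape) != ndim:
            raise ValueError(f"tensor {i} has {len(shape)} dims, expected {ndim}")
        for d in range(ndim):
            # Only the concatenation axis may differ between tensors.
            if d != dim and shape[d] != ref[d]:
                raise ValueError(
                    f"Sizes of tensors must match except in dimension {dim}. "
                    f"Expected size {ref[d]} but got size {shape[d]} "
                    f"for tensor number {i} in the list."
                )

    out = list(ref)
    out[dim] = sum(s[dim] for s in shapes)  # lengths add up along `dim`
    return tuple(out)


if __name__ == "__main__":
    # Hypothetical shapes: b=16, pl=10, l=50, e=768.
    print(cat_output_shape([(16, 10, 768), (16, 50, 768), (16, 10, 768)], dim=1))
    # → (16, 70, 768)

    # Hypothetical mismatch: prompt embeddings end up with 744 in the last axis.
    try:
        cat_output_shape([(16, 10, 744), (16, 50, 768), (16, 10, 744)], dim=1)
    except ValueError as err:
        print(err)
```

Comparing `prompt_embeds.shape` with `inputs_embeds.shape` just before the `torch.cat` call is one way to see which axis disagrees when debugging a failure like this.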

We have fixed this issue in the latest PR; please pull the latest version of the repository. Thank you very much for the report.