RUCAIBox/TextBox

怎么运行 persona chat的数据集

thinkingmanyangyang opened this issue · 6 comments

请问作者,我如何用这份代码,复现所有表格中的数据呢,每个对应的数据集有没有对应的运行命令呢

python run_textbox.py --model=[model_name] --dataset=[dataset_name]
其中[model_name]在https://github.com/RUCAIBox/TextBox/tree/main/textbox/properties/model中,[dataset_name]在https://github.com/RUCAIBox/TextBox/tree/main/textbox/properties/dataset中。
数据集可以在百度网盘下载。

Woeee commented

请问下作者,在使用transformer运行personachat时,source_text的维度是3维,input和target的维度是二维,导致source_text在进行位置编码时报错,这个需要怎么解决,还有想问下,新版本里面dialog任务是不是和translation任务的dataloader之类的modules合并了,下载的原始的personachat.yaml文件中,任务类型就是translation

使用TransformerEncDec运行Persona_Chat时,请在命令行添加--src_multi_sent=False
任务类型我们正在合并统一,新版本有可能会取消

Woeee commented

我尝试了您提供的解决方法,但是好像没有解决这个问题,会在同样的地方报错
File "/home/kun_zhou/yuhui/nlp/model.py", line 185, in forward
source_embeddings = self.source_token_embedder(source_text) + self.position_embedder(source_text).to(self.device)
……
File "/home/kun_zhou/yuhui/nlp/modules.py", line 15, in forward
batch_size, seq_len = input_seq.size()
ValueError: too many values to unpack (expected 2)

输出source_text,input_text,target_text的shape,依次得到
torch.Size([64, 20, 23])
torch.Size([64, 21])
torch.Size([64, 21])

就是source_text的维度不匹配,应该是对source_text处理的PairedSentenceDataLoader出了一些bug
也可能我哪里理解的有一些问题……

你尝试把dataset/Persona_Chat/*.bin删除再试一试

Woeee commented

非常感谢,删除之后这块可以跑通了