CASIA-IVA-Lab/Obj2Seq

init_vectors为空, batch-size修改为其他值时程序报错

Isaiah1013 opened this issue · 1 comments

作者你好,我使用一个20类的数据集跑程序,发现有三个小问题:
(1)发现init_vetors为空时,else 分支里的num_classes没有定义,具体位置在prompt_indicator.py文件的97行:
capture_20221010204019568

(2)在transform_pipelines.py文件的48和62行num_cats为固定值80,这应该随数据集类别数量而变化。

(3)在单块GPU上将batch-size改为其他值时,prompt_indicator.py的52行会报错(TyperError: only integer tensors of a single element can be converted to an index),导致问题出现的原因是这段程序使用的num_classes是一个二维向量(batch_size, 1)

对于问题(1)(2)是在config.py文件添加相应的地方添加一个num_classes变量,对于问题(3),我将程序段中num_classes的二维向量(batch_size, 1)改为(batch_size, )后就没问题了

作者如果看到,可以测试一下

非常感谢您对我们工作的关注,针对您提出的3处代码问题,我们已经在最近一次提交中进行了处理,您可以在提交详情中查看具体修改内容。