init_vectors为空, batch-size修改为其他值时程序报错
Isaiah1013 opened this issue · 1 comments
Isaiah1013 commented
作者你好,我使用一个20类的数据集跑程序,发现有三个小问题:
(1)发现init_vetors为空时,else 分支里的num_classes没有定义,具体位置在prompt_indicator.py文件的97行:
(2)在transform_pipelines.py文件的48和62行num_cats为固定值80,这应该随数据集类别数量而变化。
(3)在单块GPU上将batch-size改为其他值时,prompt_indicator.py的52行会报错(TyperError: only integer tensors of a single element can be converted to an index),导致问题出现的原因是这段程序使用的num_classes是一个二维向量(batch_size, 1)
对于问题(1)(2)是在config.py文件添加相应的地方添加一个num_classes变量,对于问题(3),我将程序段中num_classes的二维向量(batch_size, 1)改为(batch_size, )后就没问题了
作者如果看到,可以测试一下