关于代码的一些问题
Closed this issue · 5 comments
①seqxgpt的train函数中,保存checkpoint时是不是应该加一个判断当前test的acc和最佳acc的语句呀,现在的代码按我的理解好像是保存了最后一个epoch的模型?不知道我的理解对不对。
②seqxgpt的train函数中,每个epoch结束都有save,为什么所有epoch结束又有这三行语句呢:
torch.save(self.model.cpu(), save_dir)
saved_model = torch.load(save_dir)
self.model.load_state_dict(saved_model.state_dict())
这里我没有太理解QAQ
问题1:是的,你说的完全正确。我这里并没有过度调参,更多的是通过打印的loss判断是不是收敛了,并没有很关注是不是拿到了最优的checkpoint。事实上只要loss收敛了,即使不是dev最优的,也跟最优的差别很小的。
问题2:你说的完全正确!这里的设计原本是这样的:比如我每个epoch会保存一个,每个epoch存的位置都不在一起;最后一个的模型也存储在last这个文件夹中;特别地,self.model.load_state_dict()是为了保证存储没问题,能从这里load到最新的模型,避免万一中间保存失败。当然这里的你也可以删去。
明白!感谢回答!
请问在train函数的test过程中,有个别text遇到了empty sent label list的情况,这是正常的吗
啊不好意思看到的晚了,我想了想觉得应该是正常的。因为有的句子就一个词或者一个标点,按照GPT2这种,因为它的tokenizer没有bos,所以截断下来就相当于没有logits list了,所以也就成了空的特征序列,自然没法预测了。
这个也是make sense的,因为其实我们也很清楚,一个句子一两个单词或者符号这种,根本没法判断是谁生成的。希望解答到了您的问题~
感谢!