shaochenze/PatchTrain

关于交叉熵的具体计算细节

JizhanFang opened this issue · 4 comments

你好,我看了下源代码中关于计算logits和labels的交叉熵的部分,我的理解是,预测的时候,预测出来的表征是patch级别的,也就是说logits中第二个维度是patch数(shape of logits:[batch_size,patch_nums,vocab_size]),然后经过reshape之后,最后是和该patch对应的每个token的labels计算交叉熵,请问我的理解正确吗?换个问法就是:在next patch prediction阶段预测的时候,预测出来的每一个表征都是代表着一个patch的综合表征,而不涉及token-level的表征, 只是在计算交叉熵的时候会拿这个logits来与该patch对应的每个token的label进行计算,请问是这样吗?望尽快回复,感谢!

你好,你的理解完全正确。模型预测的表征都是 patch 级的,所得的 logits 会与 next patch 的所有 token 计算交叉熵损失。

好的,感谢回复。请问你们后续会利用这种预训练方法去尝试训练更大的模型吗?例如7B的模型。还有就是我想请教一下你们是怎么想到把表征直接求平均来压缩成patch的?

我们目前不会再训更大的模型了。我们将文本中的 patch 认定为多个 token 的混合,基于这个理念来设计 patch 模型,所以模型的输入直接是 token 表示的平均,模型输出的分布也和下个 patch 中所有 token 同时计算交叉熵损失,预测它们的混合概率。

好的,感谢您的回复