Question about the KL divergence in the code
rigorosyangffff opened this issue · 1 comments
rigorosyangffff commented
Hello!
I looked at the code and noticed that the KL penalty added to the token-level reward does not seem to follow the standard definition of KL divergence, which is computed over the two full distributions. The code appears to use only the probability ratio of the single label token. The standard KL divergence is guaranteed to be non-negative, but the current implementation can produce a negative value. Why is that? I also noticed that approx_kl is computed the same way.
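For concreteness, here is a minimal sketch (not code from this repository; all names and shapes are illustrative) of the difference between the full-distribution KL and the single-label-token log-ratio:

```python
# Minimal sketch: full-distribution KL vs. single-label-token log-ratio.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size = 8
policy_logits = torch.randn(vocab_size)   # current policy logits at one position
ref_logits = torch.randn(vocab_size)      # frozen reference policy logits
label = 3                                 # label token id at this position

policy_logprobs = F.log_softmax(policy_logits, dim=-1)
ref_logprobs = F.log_softmax(ref_logits, dim=-1)

# Standard KL(policy || ref): sums over every vocabulary token, always >= 0.
full_kl = (policy_logprobs.exp() * (policy_logprobs - ref_logprobs)).sum()

# Single-token version: only the label token's log-ratio. Its expectation over
# tokens sampled from the policy equals the full KL, but any individual value
# can be negative.
token_kl = policy_logprobs[label] - ref_logprobs[label]

print(full_kl.item(), token_kl.item())
```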
kxzxvbk commented
Same question. This leads to a strange situation: the final KL loss is computed as
kl_penalty = -self.kl_penalty_weight * (logprobs - ref_logprob)
However, ref_logprob does not require grad, so it could just as well be removed from the computation graph. In the current situation, the regularization is closer to "limit the label log-probability and prevent it from becoming too large" than to a normal KL divergence.
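To make that concrete, a minimal sketch (hypothetical values; kl_penalty_weight, logprobs, and ref_logprob are stand-ins for the variables in the line above) showing that, because ref_logprob is detached, the penalty is linear in logprobs:

```python
# Minimal sketch (not repository code) of the penalty's gradient structure.
import torch

kl_penalty_weight = 0.1
logprobs = torch.tensor([-1.2, -0.5], requires_grad=True)  # policy log-probs of label tokens
ref_logprob = torch.tensor([-1.0, -0.9])                    # reference log-probs, no grad

kl_penalty = -kl_penalty_weight * (logprobs - ref_logprob)
kl_penalty.sum().backward()

# Because ref_logprob carries no gradient, the gradient of the penalty is the
# constant -kl_penalty_weight for every token, regardless of how far the policy
# has drifted from the reference.
print(logprobs.grad)  # tensor([-0.1000, -0.1000])
```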