lucidrains/PaLM-rlhf-pytorch

The KL loss calculation seems to have a mistake.

Nightbringers opened this issue · 1 comment

code:
kl_div_loss = masked_kl_div(action_probs, old_action_probs, mask = action_masks) * self.kl_div_loss_weight

I think old_action_probs should be y_true and action_probs should be y_pred, so I think the correct code is:
kl_div_loss = masked_kl_div(old_action_probs, action_probs, mask = action_masks) * self.kl_div_loss_weight
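The order matters because KL divergence is not symmetric. Here is a minimal sketch in plain Python showing that swapping the arguments changes the value. It assumes the first argument is treated as the reference (true) distribution, i.e. KL(p || q) = sum(p * (log p - log q)). The `kl_div` function and the example probabilities are hypothetical stand-ins for illustration, not the repo's `masked_kl_div`.

```python
import math

def kl_div(p, q, eps=1e-20):
    """KL(p || q) for two discrete distributions given as lists of probabilities.

    Hypothetical stand-in: assumes the FIRST argument is the reference
    (true) distribution and the second is the prediction.
    """
    return sum(
        pi * (math.log(max(pi, eps)) - math.log(max(qi, eps)))
        for pi, qi in zip(p, q)
    )

# Made-up distributions over three actions, for illustration only.
old_action_probs = [0.7, 0.2, 0.1]   # policy before the update
action_probs = [0.4, 0.4, 0.2]       # policy after the update

# Old policy as the reference distribution.
kl_old_vs_new = kl_div(old_action_probs, action_probs)

# New policy as the reference distribution (arguments swapped).
kl_new_vs_old = kl_div(action_probs, old_action_probs)

print(kl_old_vs_new, kl_new_vs_old)
```

The two printed values differ, so which distribution goes first decides which quantity is actually being penalized.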

Am I right, or am I misunderstanding something?

No, I think you may be correct. Will make the change! 🙏