The KL loss calculation seems to contain a mistake.
Nightbringers opened this issue · 1 comment
Nightbringers commented
The current code:
kl_div_loss = masked_kl_div(action_probs, old_action_probs, mask = action_masks) * self.kl_div_loss_weight
I think old_action_probs should be y (true) and action_probs should be y (pred), so the correct code should be:
kl_div_loss = masked_kl_div(old_action_probs, action_probs, mask = action_masks) * self.kl_div_loss_weight
Am I right, or am I misunderstanding something?
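To show why the argument order matters, here is a minimal sketch in plain Python. The helper kl_div below is hypothetical and written only for illustration; it is not the repository's masked_kl_div, and it assumes the first argument is the reference ("true") distribution, i.e. it computes KL(p || q). The probability values are made up for the example.

```python
import math

def kl_div(p, q, eps=1e-20):
    """Hypothetical helper: KL(p || q) for two discrete distributions.

    p is treated as the reference ("true") distribution and q as the
    approximation ("pred"). eps guards against log(0).
    """
    return sum(
        pi * (math.log(max(pi, eps)) - math.log(max(qi, eps)))
        for pi, qi in zip(p, q)
    )

# Made-up action distributions over three actions.
old_action_probs = [0.7, 0.2, 0.1]
action_probs = [0.4, 0.4, 0.2]

# Old policy as the reference: KL(old || new).
kl_old_new = kl_div(old_action_probs, action_probs)

# Arguments swapped: KL(new || old).
kl_new_old = kl_div(action_probs, old_action_probs)

print("KL(old || new):", kl_old_new)
print("KL(new || old):", kl_new_old)
```

The two printed values differ because KL divergence is not symmetric, so swapping the arguments changes the penalty and its gradient with respect to action_probs.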
lucidrains commented
no i think you may be correct, will make the change! 🙏