Confusion about KL divergence calculation for human feedback policies
dwyzzy opened this issue · 13 comments
Hi, thanks for the great work.
I also have a question about KL divergence loss.
In papers like Learning to summarize from human feedback, the KL term for human feedback policies seems to be the KL divergence between $\pi^{RL}$ and $\pi^{SFT}$. However, the code

kl_div_loss = masked_kl_div(action_probs, old_action_probs, mask = action_masks) * self.kl_div_loss_weight

seems to be the KL divergence between the new action probs and the old action probs, i.e. between $\pi^{RL}_{new}$ and $\pi^{RL}_{old}$.
Is there something wrong with the code, or have I made a mistake?
Thank you.
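To make the question concrete, here is a minimal numpy sketch of the two quantities. `masked_kl_div` here is my own stand-in written over log-probs, not the repo's actual function, and the shapes and policies are hypothetical:

```python
import numpy as np

rng = np.random.default_rng(0)

def log_softmax(x):
    # numerically stable log-softmax over the last axis
    x = x - x.max(axis=-1, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=-1, keepdims=True))

def masked_kl_div(log_p, log_q, mask):
    # KL(p || q) summed over the vocab, averaged over unmasked positions
    kl = (np.exp(log_p) * (log_p - log_q)).sum(axis=-1)
    return (kl * mask).sum() / mask.sum()

T, V = 4, 6  # sequence length and vocab size (hypothetical)
log_pi_rl  = log_softmax(rng.normal(size=(T, V)))  # current RL policy
log_pi_old = log_softmax(rng.normal(size=(T, V)))  # policy that produced the rollout
log_pi_sft = log_softmax(rng.normal(size=(T, V)))  # frozen SFT model
mask = np.ones(T)

# what the quoted line of code computes: KL(pi_RL || pi_old)
kl_vs_old = masked_kl_div(log_pi_rl, log_pi_old, mask)

# what the paper penalizes: KL(pi_RL || pi_SFT)
kl_vs_sft = masked_kl_div(log_pi_rl, log_pi_sft, mask)
```

The two are different quantities whenever the old rollout policy and the frozen SFT model have diverged.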
my understanding is that pi rl corresponds to the new action probs while pi sft corresponds to the old. let me know through a pull request what you think the necessary changes should be, if you believe differently
@lucidrains I am curious why you use gumbel sampling to collect an action for a given state (prompt), i.e. add gumbel noise and take the argmax index. Why not Categorical? e.g. dist = Categorical(logits); action = dist.sample()
pi[old] and pi are the same model with different parameters, in my view; however, pi[old] updates more slowly than pi. pi[old] is just used for importance sampling in RL (e.g. PPO). In InstructGPT, pi[old] and pi are the SFT model with different model parameters.
> @lucidrains I am curious why you use gumbel sampling to collect an action for a given state (prompt), i.e. add gumbel noise and take the argmax index. Why not Categorical? e.g. dist = Categorical(logits); action = dist.sample()
it is equivalent
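A quick empirical check of that equivalence, sketched in plain numpy: the Gumbel-max trick (add Gumbel(0, 1) noise to the logits and take the argmax) draws from exactly the categorical distribution given by softmax of the logits, which is what `Categorical(logits=...).sample()` does.

```python
import numpy as np

rng = np.random.default_rng(0)
logits = np.array([1.0, 2.0, 0.5])

# categorical probabilities via softmax
probs = np.exp(logits) / np.exp(logits).sum()

# Gumbel-max trick: argmax(logits + g), with g ~ Gumbel(0, 1)
n = 200_000
gumbel = -np.log(-np.log(rng.uniform(size=(n, logits.size))))
samples = np.argmax(logits + gumbel, axis=1)

# empirical frequencies should closely match the softmax probabilities
freqs = np.bincount(samples, minlength=logits.size) / n
```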
> pi[old] and pi are the same model with different parameters, in my view; however, pi[old] updates more slowly than pi. pi[old] is just used for importance sampling in RL (e.g. PPO). In InstructGPT, pi[old] and pi are the SFT model with different model parameters.
yea, and old action prob is sampled from pi[old] and new action prob from pi[rl]. feel free to correct if i'm mistaken. also provide (pseudo)code, as it would be clearer than english
Hi. I think the KL divergence for human feedback policies (i.e. between $\pi^{RL}$ and $\pi^{SFT}$) should be subtracted from the reward rather than computed against the old action probs. The final reward should be

$R(x, y) = r(x, y) - \beta \log \left( \pi^{RL}(y|x) / \pi^{SFT}(y|x) \right)$
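A minimal sketch of that shaped reward, assuming per-token log-probs of the sampled tokens are available; all numbers here are hypothetical:

```python
import numpy as np

beta = 0.02  # KL penalty coefficient (hypothetical value)

# log-probs of the sampled tokens under the RL policy and the frozen SFT model
logp_rl  = np.array([-1.2, -0.8, -2.0])
logp_sft = np.array([-1.0, -1.1, -1.8])

reward_model_score = 0.7  # scalar r(x, y) from the reward model (hypothetical)

# per-token estimate of beta * log(pi_RL / pi_SFT)
kl_penalty = beta * (logp_rl - logp_sft)

# R(x, y): KL penalty on every token, reward-model score added on the final token
rewards = -kl_penalty
rewards[-1] += reward_model_score
```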
@dwyzzy yea, i just noticed that on rereading
do you think it makes a difference whether it is subtracted from the rewards rather than just added as an auxiliary loss?
i am by no means knowledgeable with the RL field
if there are any RL experts in the room, now is the time to shine
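For concreteness, the two options being compared, sketched with hypothetical numbers (not the repo's actual code):

```python
import numpy as np

beta = 0.02
per_token_kl = np.array([0.10, 0.05, 0.20])  # KL(pi_RL || pi_SFT) per token (hypothetical)
raw_rewards  = np.array([0.0, 0.0, 0.7])     # reward model score on the final token (hypothetical)
policy_loss  = 1.5                           # PPO surrogate loss (hypothetical scalar)

# Option A (the papers): subtract the KL penalty from the rewards,
# so it flows through the value function and advantage estimates
shaped_rewards = raw_rewards - beta * per_token_kl

# Option B (auxiliary loss): leave the rewards untouched and add the
# KL term directly to the optimized loss
total_loss = policy_loss + beta * per_token_kl.mean()
```

The difference is where the penalty acts: in option A it is credited per token through the advantages, in option B it is a single extra gradient term on the policy.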
@lucidrains Hi. I agree that these two approaches are similar, in that the KL divergence is used to keep the newest RL policy from deviating too much from the original SFT model. From my point of view, 0.2.0 is closer to these RLHF papers (adding the KL divergence penalty between the SFT model and the RL policy to the reward).
Thank you again for the great work!
@dwyzzy ok sounds good, i'm really curious what the difference is, if any
do email me if you end up trying both approaches
I think that is a quite interesting point. I believe in the original PPO algorithm, the KL divergence should be calculated between $\pi_k$ at iteration $k$ and $\pi_{k+1}$. In other words, sample from the current policy $\pi_k$, then update the policy to find $\pi_{k+1}$.
Reference: https://spinningup.openai.com/en/latest/algorithms/ppo.html
However, in RLHF, it seems that the KL divergence is calculated between (1) instead of (2). Any idea why that is the case?
1. $\pi^{RL}$ and $\pi^{SFT}$
2. $\pi^{RL}_{k+1}$ and $\pi^{RL}_{k}$
Or does it mean that there are actually two different KL divergences? (1) is added to the reward directly, while (2) is still there for the PPO update?
I think it's what you said at the end: there are two different KL divergences.
(1) The KL divergence between $\pi^{RL}$ and $\pi^{SFT}$, which is added to the reward directly.
(2) The KL divergence between $\pi^{RL}_{k+1}$ and $\pi^{RL}_{k}$, which is still there for the PPO update.
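Putting the two together, a compact sketch of where each KL term enters one PPO-RLHF step (all numbers hypothetical; note that the second KL enters implicitly through the clipped probability ratio rather than as an explicit KL term):

```python
import numpy as np

beta, ppo_clip = 0.02, 0.2

# per-token log-probs of the sampled tokens (hypothetical values)
logp_rl  = np.array([-1.2, -0.8])  # current RL policy pi_RL
logp_old = np.array([-1.1, -0.9])  # rollout policy pi_old
logp_sft = np.array([-1.0, -1.1])  # frozen SFT model pi_SFT

reward_model_score = 0.5
advantages = np.array([0.3, -0.2])  # hypothetical, e.g. from GAE

# KL (1): pi_RL vs pi_SFT -- folded into the rewards before advantages are computed
rewards = -beta * (logp_rl - logp_sft)
rewards[-1] += reward_model_score

# KL (2): pi_RL vs pi_old -- constrained implicitly via the clipped PPO ratio
ratio = np.exp(logp_rl - logp_old)
ppo_loss = -np.minimum(ratio * advantages,
                       np.clip(ratio, 1 - ppo_clip, 1 + ppo_clip) * advantages).mean()
```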
You are welcome to point out any mistakes in my comment if I have made some~