Hi @eric-mitchell ,
In your formula (the image below), it seems that log[π(y|x)] is calculated through .sum(-1) after logits.softmax(-1), then .log().
But in your code (the image below), log[π(y|x)] is calculated through .sum(-1) after logits.log_softmax(-1).
These two ways of calculating log[π(y|x)] seem different. Could you please tell me whether they conflict with each other?
Originally posted by @Gryff1ndor in #57 (comment)
@Gryff1ndor, I think it might be because the sequence probability factorizes into a product over tokens:

$$\pi(y \mid x) = \prod_{t} \pi(y_t \mid x, y_{<t})$$

And taking a log turns that product into a sum:

$$\log \pi(y \mid x) = \sum_{t} \log \pi(y_t \mid x, y_{<t})$$

Which would correspond to the sum at the bottom of `_get_batch_logps()`.
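Here's a quick numerical check of that identity, just a toy sketch with made-up shapes (not the repo's actual batching or masking), showing that multiplying softmax probabilities and then taking the log matches summing the log_softmax values:

```python
import torch

torch.manual_seed(0)

# Toy shapes, not tied to the repo: (seq_len, vocab_size) logits and target token ids
seq_len, vocab = 6, 10
logits = torch.randn(seq_len, vocab)
labels = torch.randint(0, vocab, (seq_len,))

# Reading the formula literally: per-token probabilities, multiplied, then log
per_token_probs = logits.softmax(-1).gather(-1, labels.unsqueeze(-1)).squeeze(-1)
logp_via_product = per_token_probs.prod(-1).log()

# What the code does: per-token log-probs, then sum (as at the end of _get_batch_logps())
per_token_logps = logits.log_softmax(-1).gather(-1, labels.unsqueeze(-1)).squeeze(-1)
logp_via_sum = per_token_logps.sum(-1)

print(torch.allclose(logp_via_product, logp_via_sum, atol=1e-5))  # True: log(prod p) == sum(log p)
```

In practice the sum of log-probs is the form you want, since the raw product of token probabilities underflows for long sequences.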
Thanks a lot! There is another problem that bothers me:
When using the DPO loss in my work, I found that the sigmoid function in the DPO loss caused a gradient explosion, because sigmoid(x) comes out as 0 or 1 (when x tends to -∞ or +∞). But in fact the value of x was only around -10 or +10.
Do you know how to fix it?
Happy to help! I'm learning this stuff as well, so take it with a grain of salt, but I think there are a couple of things you can do:
- Play around with the hyperparameters. Try lowering the learning rate, or lowering beta.
- Gradient clipping. Maybe reducing the configured gradient norm limit, `self.config.max_grad_norm`, will prevent exploding gradients (see the sketch after this list).
- Perhaps another loss function like the 'conservative' DPO, or IPO, will work in your case. I also wonder what would happen if you modified the DPO loss to use a clip function, something like PPO-clip, to keep the sigmoid input in the desirable range.
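Here's a rough sketch of what I mean by gradient clipping, not the repo's actual training loop: the tiny linear model, `beta`, and the fake preference margins are all placeholders, and `-F.logsigmoid(...)` is just the numerically stable way of writing `-log(sigmoid(...))`.

```python
import torch
import torch.nn.functional as F

# Toy stand-in for a DPO-style update; everything here is made up for illustration.
torch.manual_seed(0)
policy = torch.nn.Linear(4, 1)
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)
beta = 0.1
max_grad_norm = 1.0  # plays the role of config.max_grad_norm

x = torch.randn(8, 4)
logits_diff = policy(x).squeeze(-1)              # stand-in for the (chosen - rejected) log-ratio margin
loss = -F.logsigmoid(beta * logits_diff).mean()  # stable form of -log(sigmoid(beta * margin))

loss.backward()
# Rescale all gradients so their global norm is at most max_grad_norm, then step
grad_norm = torch.nn.utils.clip_grad_norm_(policy.parameters(), max_grad_norm)
optimizer.step()
optimizer.zero_grad()
print(f"grad norm before clipping: {grad_norm.item():.3f}")
```

`clip_grad_norm_` returns the pre-clipping norm, so logging it is a cheap way to see whether your gradients really are blowing up.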
Let me know how it goes! Right now I'm working with another fascinating repository that introduces a related loss called KTO. It's actually designed to be an extension of this repository, so it's very possible that the authors of the KTO repo addressed the exploding gradient problem in their code. I will be training with DPO afterwards, so it'll be helpful to know what worked for you.