eric-mitchell/direct-preference-optimization

Understanding loss

monksgoyal opened this issue · 6 comments

I tried training with DPO.
My training loss looks like this:

Step 500
{'loss': 0.3441, 'learning_rate': 0.0004605091805714775, 'rewards/chosen': -2.3275821208953857, 'rewards/rejected': -9.353358268737793, 'rewards/accuracies': 0.8579999804496765, 'rewards/margins': 7.02577543258667, 'logps/rejected': -130.95912170410156, 'logps/chosen': -35.571929931640625, 'logits/rejected': -0.9097328782081604, 'logits/chosen': -0.9106643199920654, 'epoch': 0.18}

Step 1000
{'loss': 0.3073, 'learning_rate': 0.00041412581074747143, 'rewards/chosen': -2.4493091106414795, 'rewards/rejected': -10.947866439819336, 'rewards/accuracies': 0.8774999976158142, 'rewards/margins': 8.498558044433594, 'logps/rejected': -147.00184631347656, 'logps/chosen': -36.60841751098633, 'logits/rejected': -0.9469738006591797, 'logits/chosen': -0.9475268721580505, 'epoch': 0.27}

Although the loss is decreasing, rewards/chosen is decreasing too, even though it should be increasing.
rewards/rejected looks fine.

Is the model lowering the probability of both the chosen and the rejected responses, but lowering the rejected one more, so that the loss still decreases?

Is my understanding correct?
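The logged quantities fit together as follows, assuming these logs come from a Hugging Face TRL-style `DPOTrainer` (the `rewards/*` and `logps/*` key names suggest it, but that is an assumption): each implicit reward is beta times the policy-minus-reference log-prob of the response, the margin is chosen minus rejected, and the sigmoid loss only sees the margin. The helper `dpo_metrics` below is hypothetical and the numbers are made up, chosen to show that both rewards can be negative while the loss is close to zero.

```python
import math


def softplus(z: float) -> float:
    """Numerically stable log(1 + exp(z))."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def dpo_metrics(policy_chosen: float, policy_rejected: float,
                ref_chosen: float, ref_rejected: float, beta: float) -> dict:
    """Per-pair DPO quantities from summed sequence log-probs (hypothetical helper)."""
    # Implicit rewards: beta * (log pi_theta(y|x) - log pi_ref(y|x)).
    reward_chosen = beta * (policy_chosen - ref_chosen)
    reward_rejected = beta * (policy_rejected - ref_rejected)
    margin = reward_chosen - reward_rejected
    # Sigmoid DPO loss: -log sigmoid(margin) == softplus(-margin).
    loss = softplus(-margin)
    return {
        "rewards/chosen": reward_chosen,
        "rewards/rejected": reward_rejected,
        "rewards/margins": margin,
        "rewards/accuracies": float(reward_chosen > reward_rejected),
        "loss": loss,
    }


# Made-up numbers: the policy lowers BOTH log-probs relative to the reference,
# but lowers the rejected one much more, so the loss is still small.
m = dpo_metrics(policy_chosen=-37.0, policy_rejected=-140.0,
                ref_chosen=-35.0, ref_rejected=-60.0, beta=0.1)
for key, value in m.items():
    print(f"{key}: {value:.4f}")
```

Nothing in the loss term rewards the chosen log-prob for staying at or above the reference; only the gap to the rejected one counts.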

I also ran predictions with these checkpoints, and the results seem worse than the reference model's.

The above are the training logs.
The eval logs look the same:
rewards/chosen is decreasing.

"log_history": [
{
"epoch": 0.09,
"eval_logits/chosen": -0.8852691054344177,
"eval_logits/rejected": -0.8777562379837036,
"eval_logps/chosen": -28.69784927368164,
"eval_logps/rejected": -106.26754760742188,
"eval_loss": 0.2819069027900696,
"eval_rewards/accuracies": 0.86328125,
"eval_rewards/chosen": -1.6413846015930176,
"eval_rewards/margins": 5.6007561683654785,
"eval_rewards/rejected": -7.2421417236328125,
"eval_runtime": 204.6538,
"eval_samples_per_second": 1.251,
"eval_steps_per_second": 0.313,
"step": 500
}
{
"epoch": 0.18,
"eval_logits/chosen": -0.7912834286689758,
"eval_logits/rejected": -0.7919993996620178,
"eval_logps/chosen": -27.437620162963867,
"eval_logps/rejected": -122.26292419433594,
"eval_loss": 0.2211304008960724,
"eval_rewards/accuracies": 0.87109375,
"eval_rewards/chosen": -1.5153619050979614,
"eval_rewards/margins": 7.326316833496094,
"eval_rewards/rejected": -8.84167766571045,
"eval_runtime": 208.8261,
"eval_samples_per_second": 1.226,
"eval_steps_per_second": 0.306,
"step": 1000
}
{
"epoch": 0.27,
"eval_logits/chosen": -0.5182173252105713,
"eval_logits/rejected": -0.5083359479904175,
"eval_logps/chosen": -30.679859161376953,
"eval_logps/rejected": -128.99560546875,
"eval_loss": 0.2379799485206604,
"eval_rewards/accuracies": 0.87890625,
"eval_rewards/chosen": -1.8395859003067017,
"eval_rewards/margins": 7.675361156463623,
"eval_rewards/rejected": -9.514947891235352,
"eval_runtime": 199.2648,
"eval_samples_per_second": 1.285,
"eval_steps_per_second": 0.321,
"step": 1500
}
{
"epoch": 0.36,
"eval_logits/chosen": -0.4966026544570923,
"eval_logits/rejected": -0.4896504580974579,
"eval_logps/chosen": -25.403793334960938,
"eval_logps/rejected": -129.7777862548828,
"eval_loss": 0.28067469596862793,
"eval_rewards/accuracies": 0.8515625,
"eval_rewards/chosen": -1.311979055404663,
"eval_rewards/margins": 8.281187057495117,
"eval_rewards/rejected": -9.59316635131836,
"eval_runtime": 199.9027,
"eval_samples_per_second": 1.281,
"eval_steps_per_second": 0.32,
"step": 2000
}
{
"epoch": 0.45,
"eval_logits/chosen": -1.0843309164047241,
"eval_logits/rejected": -1.0715372562408447,
"eval_logps/chosen": -26.54033660888672,
"eval_logps/rejected": -144.78355407714844,
"eval_loss": 0.22287404537200928,
"eval_rewards/accuracies": 0.89453125,
"eval_rewards/chosen": -1.425633430480957,
"eval_rewards/margins": 9.668107032775879,
"eval_rewards/rejected": -11.093740463256836,
"eval_runtime": 197.8472,
"eval_samples_per_second": 1.294,
"eval_steps_per_second": 0.323,
"step": 2500
}
{
"epoch": 0.54,
"eval_logits/chosen": -1.0749739408493042,
"eval_logits/rejected": -1.0576728582382202,
"eval_logps/chosen": -36.400691986083984,
"eval_logps/rejected": -159.21218872070312,
"eval_loss": 0.21191225945949554,
"eval_rewards/accuracies": 0.921875,
"eval_rewards/chosen": -2.4116692543029785,
"eval_rewards/margins": 10.124935150146484,
"eval_rewards/rejected": -12.536603927612305,
"eval_runtime": 199.0354,
"eval_samples_per_second": 1.286,
"eval_steps_per_second": 0.322,
"step": 3000
}
{
"epoch": 0.63,
"eval_logits/chosen": -0.7880315184593201,
"eval_logits/rejected": -0.7736120820045471,
"eval_logps/chosen": -36.00906753540039,
"eval_logps/rejected": -160.41616821289062,
"eval_loss": 0.1770719289779663,
"eval_rewards/accuracies": 0.9296875,
"eval_rewards/chosen": -2.372506618499756,
"eval_rewards/margins": 10.28449821472168,
"eval_rewards/rejected": -12.657003402709961,
"eval_runtime": 198.7832,
"eval_samples_per_second": 1.288,
"eval_steps_per_second": 0.322,
"step": 3500
}
{
"epoch": 0.73,
"eval_logits/chosen": -0.9893665313720703,
"eval_logits/rejected": -0.9705553650856018,
"eval_logps/chosen": -33.47208023071289,
"eval_logps/rejected": -176.95199584960938,
"eval_loss": 0.1922326683998108,
"eval_rewards/accuracies": 0.921875,
"eval_rewards/chosen": -2.118807792663574,
"eval_rewards/margins": 12.191776275634766,
"eval_rewards/rejected": -14.31058406829834,
"eval_runtime": 198.8575,
"eval_samples_per_second": 1.287,
"eval_steps_per_second": 0.322,
"step": 4000
}
{
"epoch": 0.82,
"eval_logits/chosen": -0.9812024831771851,
"eval_logits/rejected": -0.9520618915557861,
"eval_logps/chosen": -37.32331466674805,
"eval_logps/rejected": -198.68295288085938,
"eval_loss": 0.22656424343585968,
"eval_rewards/accuracies": 0.9296875,
"eval_rewards/chosen": -2.5039308071136475,
"eval_rewards/margins": 13.979748725891113,
"eval_rewards/rejected": -16.483680725097656,
"eval_runtime": 198.639,
"eval_samples_per_second": 1.289,
"eval_steps_per_second": 0.322,
"step": 4500
}
{
"epoch": 0.91,
"eval_logits/chosen": -0.8927727937698364,
"eval_logits/rejected": -0.8585475087165833,
"eval_logps/chosen": -35.953147888183594,
"eval_logps/rejected": -197.0956573486328,
"eval_loss": 0.21030420064926147,
"eval_rewards/accuracies": 0.921875,
"eval_rewards/chosen": -2.366915225982666,
"eval_rewards/margins": 13.958037376403809,
"eval_rewards/rejected": -16.324951171875,
"eval_runtime": 198.2515,
"eval_samples_per_second": 1.291,
"eval_steps_per_second": 0.323,
"step": 5000
}
{
"epoch": 1.0,
"eval_logits/chosen": -0.9086193442344666,
"eval_logits/rejected": -0.875652015209198,
"eval_logps/chosen": -34.556678771972656,
"eval_logps/rejected": -191.85987854003906,
"eval_loss": 0.2012159377336502,
"eval_rewards/accuracies": 0.9296875,
"eval_rewards/chosen": -2.2272677421569824,
"eval_rewards/margins": 13.574106216430664,
"eval_rewards/rejected": -15.801374435424805,
"eval_runtime": 200.2405,
"eval_samples_per_second": 1.278,
"eval_steps_per_second": 0.32,
"step": 5500
}
]
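The eval log can be cross-checked against itself. Assuming the TRL-style reward definition (reward = beta * (policy log-prob - reference log-prob)) and the beta = 0.1 reported later in this thread, the reference model's log-probs can be backed out of each entry; because the logged values are means over the eval set and the reward is linear in the log-probs, the identity also holds for the means. Only the step 500 and step 5500 entries are copied in; `implied_reference_logps` is a hypothetical helper.

```python
import json

# Two entries copied from the eval log above (steps 500 and 5500); the rest are omitted.
LOG_HISTORY = json.loads("""[
  {"step": 500,
   "eval_logps/chosen": -28.69784927368164,
   "eval_logps/rejected": -106.26754760742188,
   "eval_rewards/chosen": -1.6413846015930176,
   "eval_rewards/rejected": -7.2421417236328125,
   "eval_rewards/margins": 5.6007561683654785},
  {"step": 5500,
   "eval_logps/chosen": -34.556678771972656,
   "eval_logps/rejected": -191.85987854003906,
   "eval_rewards/chosen": -2.2272677421569824,
   "eval_rewards/rejected": -15.801374435424805,
   "eval_rewards/margins": 13.574106216430664}
]""")

BETA = 0.1  # value reported later in this thread


def implied_reference_logps(entry: dict, beta: float = BETA) -> tuple:
    """Back out reference log-probs, assuming reward = beta * (logp_policy - logp_ref)."""
    ref_chosen = entry["eval_logps/chosen"] - entry["eval_rewards/chosen"] / beta
    ref_rejected = entry["eval_logps/rejected"] - entry["eval_rewards/rejected"] / beta
    return ref_chosen, ref_rejected


for e in LOG_HISTORY:
    # The logged margin should equal chosen reward minus rejected reward.
    gap = e["eval_rewards/chosen"] - e["eval_rewards/rejected"] - e["eval_rewards/margins"]
    ref_c, ref_r = implied_reference_logps(e)
    print(f"step {e['step']}: margin check {gap:+.1e}, "
          f"implied ref logp chosen {ref_c:.2f}, rejected {ref_r:.2f}")
```

Under those assumptions the implied reference log-probs come out the same at both steps (about -12.28 for chosen and -33.85 for rejected), as expected for a frozen reference model on a fixed eval set. It also shows how far the policy has drifted: its mean chosen log-prob (-28.7 at step 500, -34.6 at step 5500) sits well below the reference's.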

Did you do an SFT stage on your chosen responses before running DPO? TL;DR: this behavior is not unexpected.

As you've pointed out, DPO optimizes the reward margin between chosen and rejected responses, not the absolute rewards themselves. Because SFT with maximum likelihood on the chosen responses gives us, roughly, a model that assigns the highest probability possible to the chosen responses, fine-tuning with any other objective is likely to lower the probability assigned to them.
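To make the margin point concrete: with the sigmoid loss of Eq. 7 in the paper, the per-pair loss depends only on the difference of the two log-ratios, so a pair where the policy raised the chosen log-prob and a pair where it lowered it can get exactly the same loss. Plain-Python sketch; `dpo_loss` and the numbers are illustrative, not taken from this repo.

```python
import math


def dpo_loss(chosen_logratio: float, rejected_logratio: float, beta: float) -> float:
    """-log sigmoid(beta * (chosen_logratio - rejected_logratio)), computed stably.

    Each *_logratio is log pi_theta(y|x) - log pi_ref(y|x) for one response.
    """
    margin = beta * (chosen_logratio - rejected_logratio)
    # -log sigmoid(m) = log(1 + exp(-m)), written to avoid overflow for large |m|.
    return max(-margin, 0.0) + math.log1p(math.exp(-abs(margin)))


beta = 0.1
# Same difference of log-ratios (70), opposite signs for the chosen log-ratio.
up = dpo_loss(chosen_logratio=+10.0, rejected_logratio=-60.0, beta=beta)
down = dpo_loss(chosen_logratio=-23.0, rejected_logratio=-93.0, beta=beta)
print(up, down)  # the same value twice
```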

Do you mean that you evaluated the DPO policy and it was worse than the reference model? Can you share some samples? What beta did you use?

Information about the model and training:
Task: question answering over the context of documents.
Architecture: Llama-2-7B-Chat
Fine-tuning: standard LoRA fine-tuning with the DPO loss.
Beta: 0.1.

Yes, SFT was done on data from the same distribution, but not on the same data used for DPO training: 70% of the data was reserved for SFT and the remaining 30% for DPO.
Also, the cases where the DPO-trained model does worse than the reference model are the ones where the model should not answer (the "I don't know" cases).

I have one more observation, about the gradient $\nabla_\theta \mathcal{L}_{\text{DPO}}(\pi_\theta; \pi_{\text{ref}})$ given in the paper.
I can see that the loss moves in the right direction (increasing the chosen probabilities and decreasing the rejected ones) if the initial chosen probabilities are lower than the rejected probabilities on the data used for DPO training.

But this might not hold if the chosen probabilities are higher than the rejected probabilities initially.
Again, I just want to confirm this observation against the equation.
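For reference, the loss (Eq. 7 in the DPO paper) and the gradient discussed here, written with the implicit reward $\hat r_\theta(x, y) = \beta \log \frac{\pi_\theta(y \mid x)}{\pi_{\text{ref}}(y \mid x)}$, chosen response $y_w$ and rejected response $y_l$:

```math
\begin{aligned}
\mathcal{L}_{\text{DPO}}(\pi_\theta; \pi_{\text{ref}}) &= -\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}}\Big[\log \sigma\big(\hat r_\theta(x, y_w) - \hat r_\theta(x, y_l)\big)\Big] \\
\nabla_\theta \mathcal{L}_{\text{DPO}}(\pi_\theta; \pi_{\text{ref}}) &= -\beta\, \mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}}\Big[\sigma\big(\hat r_\theta(x, y_l) - \hat r_\theta(x, y_w)\big)\big[\nabla_\theta \log \pi_\theta(y_w \mid x) - \nabla_\theta \log \pi_\theta(y_l \mid x)\big]\Big]
\end{aligned}
```

In this form the bracketed direction term is the same for every pair; the current rewards enter only through the weight $\sigma(\hat r_\theta(x, y_l) - \hat r_\theta(x, y_w))$, which is near 1 when a pair is ranked the wrong way and near 0 once the margin is large. Since both log-probs depend on the same parameters, a step along this direction can still lower both of them, as long as the rejected one drops further.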
So I was thinking of dropping the rejected term from the loss,

i.e. removing the rejected reward term from Eq. 7 in the paper completely.

@eric-mitchell let me know if my understanding is correct.
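A minimal sketch of what dropping the rejected reward term from Eq. 7 would leave, next to the full per-pair loss. `chosen_only_loss` is a hypothetical variant written for illustration, not something from the paper or this repo, and the log-ratio values are made up.

```python
import math


def neg_log_sigmoid(m: float) -> float:
    """Stable -log(sigmoid(m))."""
    return max(-m, 0.0) + math.log1p(math.exp(-abs(m)))


def full_dpo_loss(chosen_logratio: float, rejected_logratio: float, beta: float) -> float:
    # Eq. 7 for a single pair: depends on the difference of the two log-ratios.
    return neg_log_sigmoid(beta * (chosen_logratio - rejected_logratio))


def chosen_only_loss(chosen_logratio: float, beta: float) -> float:
    # Hypothetical variant with the rejected term dropped: only the chosen log-ratio matters.
    return neg_log_sigmoid(beta * chosen_logratio)


beta = 0.1
# Chosen log-prob 23 below the reference, rejected 93 below (illustrative values).
print(full_dpo_loss(-23.0, -93.0, beta))  # ~0.0009: the pair already looks "solved"
print(chosen_only_loss(-23.0, beta))      # ~2.40: still a strong push on chosen
```

The variant only approaches zero once the chosen log-ratio is well above zero, so it keeps pushing chosen log-probs above the reference; on the other hand it no longer sees the rejected response at all, so it carries no signal for ranking chosen over rejected.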

Hi @monksgoyal, I've observed a similar pattern of losses to yours during the DPO phase in my own experiment. So far, one thing I'm certain about is that the DPO objective, i.e. Eq. 7, can be satisfied if the rejected responses' log-probs decrease more than the chosen responses' log-probs. In the provided wandb log, the log-probs of both the chosen and the rejected responses are indeed decreasing, so Eq. 7 is satisfied in the sense that the decrease on the rejected is larger than the decrease on the chosen. What do you think?
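Plugging in the eval log posted earlier in this thread (steps 500 and 5500, with $r_w$ = `eval_rewards/chosen` and $r_l$ = `eval_rewards/rejected`) shows this pattern: both rewards fall, and the margin grows because the rejected reward falls much faster.

```math
\begin{aligned}
\Delta r_w &= -2.2273 - (-1.6414) = -0.5859 \\
\Delta r_l &= -15.8014 - (-7.2421) = -8.5593 \\
\Delta r_w - \Delta r_l &= -0.5859 + 8.5593 = 7.9734 \approx 13.5741 - 5.6008
\end{aligned}
```

The last line matches the change in the logged `eval_rewards/margins` between the two steps.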

BTW, did you manage to run any follow-up experiments to further understand the DPO loss?

> Because doing SFT with maximum likelihood on the chosen responses gives us roughly a model that assigns the highest probability possible to the chosen responses, fine-tuning with any other objective is likely to lower the probability assigned to chosen.

Hi @eric-mitchell, I really doubt the above claim. I've plotted the log-prob curves of the chosen and rejected responses, and both increase during SFT. More importantly, on more than half of the responses, the log-prob of the rejected response is higher than that of the chosen one. So, basically, after SFT the model does not assign the highest probability to the chosen responses.

From my perspective, what the model learns during SFT is more likely a shift of its distribution toward the generative distribution of the preference data.
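The comparison described above (how often the SFT model gives the rejected response a higher log-prob than the chosen one) can be sketched as follows. The helper `frac_rejected_higher` and the toy values are hypothetical. Summed sequence log-probs also favour shorter responses, so the sketch additionally reports a per-token-averaged version; that normalization is an added assumption, not something described in this thread.

```python
from typing import Sequence


def frac_rejected_higher(chosen_logps: Sequence[float], rejected_logps: Sequence[float],
                         chosen_lens: Sequence[int], rejected_lens: Sequence[int]) -> tuple:
    """Fraction of pairs where the rejected response out-scores the chosen one.

    Returns (using summed log-probs, using per-token average log-probs).
    """
    n = len(chosen_logps)
    summed = sum(r > c for c, r in zip(chosen_logps, rejected_logps)) / n
    per_token = sum(
        r / rl > c / cl
        for c, r, cl, rl in zip(chosen_logps, rejected_logps, chosen_lens, rejected_lens)
    ) / n
    return summed, per_token


# Toy values for four pairs (not from the thread).
chosen = [-40.0, -12.0, -55.0, -20.0]
rejected = [-30.0, -15.0, -50.0, -80.0]
chosen_len = [20, 10, 25, 10]
rejected_len = [10, 12, 40, 40]
print(frac_rejected_higher(chosen, rejected, chosen_len, rejected_len))
```

With real data, the per-sequence log-probs and lengths would come from scoring each pair with the SFT model; the two fractions can disagree noticeably when chosen and rejected responses differ a lot in length.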