DPO Loss not converging for Encoder-Decoder Models
chansurgeplus opened this issue · 6 comments
Hey Eric,
I've been attempting to fine-tune one of the encoder-decoder models, with the necessary changes made to the Trainer (such as setting the labels during the forward pass, etc.).
The SFT phase went fine, with the loss following the expected downward trend:
Train Loss: *(plot omitted)*
However, during the DPO phase, I get the following patterns, which seem abnormal:
Train Loss: *(plot omitted)*
Any thoughts on this?
The modifications I've made to the code are as follows:
```python
def forward(self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
    chosen_logits = model(batch['prompt_input_ids'], attention_mask=batch['prompt_attention_mask'], labels=batch['chosen_input_ids']).logits.to(torch.float32)
    chosen_logps = _get_batch_logps(chosen_logits, batch['chosen_input_ids'], average_log_prob=False)

    rejected_logits = model(batch['prompt_input_ids'], attention_mask=batch['prompt_attention_mask'], labels=batch['rejected_input_ids']).logits.to(torch.float32)
    rejected_logps = _get_batch_logps(rejected_logits, batch['rejected_input_ids'], average_log_prob=False)

    return chosen_logps, rejected_logps
```
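For context, these per-sequence log-probs are what get plugged into the DPO objective. A minimal sketch of that loss (following the paper's formulation, not necessarily the exact function in this repo):

```python
import torch.nn.functional as F

def dpo_loss_sketch(policy_chosen_logps, policy_rejected_logps,
                    reference_chosen_logps, reference_rejected_logps,
                    beta=0.1):
    """DPO loss: -log sigmoid(beta * (chosen log-ratio - rejected log-ratio))."""
    chosen_logratio = policy_chosen_logps - reference_chosen_logps
    rejected_logratio = policy_rejected_logps - reference_rejected_logps
    losses = -F.logsigmoid(beta * (chosen_logratio - rejected_logratio))
    # implicit rewards, handy for logging reward margins / accuracies
    chosen_rewards = beta * chosen_logratio.detach()
    rejected_rewards = beta * rejected_logratio.detach()
    return losses, chosen_rewards, rejected_rewards
```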
- (Should be unrelated to DPO since SFT worked fine) SFT Batch Metrics calculation:
```python
policy_chosen_logits = self.policy(batch['prompt_input_ids'], attention_mask=batch['prompt_attention_mask'], labels=batch['chosen_input_ids']).logits.to(torch.float32)
policy_chosen_logps = _get_batch_logps(policy_chosen_logits, batch['chosen_input_ids'], average_log_prob=False)
```
- Note that I've removed the labels from the batch keys and simply retained the relevant targets under the `chosen` and `rejected` keys; the `prompt` key-value pairs are now used for sending the prompt. The collation and tokenization code is below, followed by a quick sanity check.
```python
from typing import Callable, Dict, List, Union

import torch
from torch.nn.utils.rnn import pad_sequence


def get_collate_fn(tokenizer) -> Callable[[List[Dict]], Dict[str, Union[List, torch.Tensor]]]:
    def collate_fn(batch):
        # first, pad everything to the same length
        padded_batch = {}
        for k in batch[0].keys():
            if k.endswith('_input_ids') or k.endswith('_attention_mask'):
                if 'prompt' in k:  # adapted from https://stackoverflow.com/questions/73256206
                    to_pad = [torch.LongTensor(ex[k][::-1]) for ex in batch]
                else:
                    to_pad = [torch.LongTensor(ex[k]) for ex in batch]
                if k.endswith('_input_ids'):
                    padding_value = tokenizer.pad_token_id
                elif k.endswith('_attention_mask'):
                    padding_value = 0
                else:
                    raise ValueError(f"Unexpected key in batch '{k}'")
                padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value)
                if 'prompt' in k:  # for the prompt, flip back so padding is on the left side
                    padded_batch[k] = padded_batch[k].flip(dims=[1])
            else:
                padded_batch[k] = [ex[k] for ex in batch]
        return padded_batch
    return collate_fn
```
```python
def tokenize_batch_element(prompt: str, chosen: str, rejected: str, truncation_mode: str, tokenizer, max_length: int, max_prompt_length: int) -> Dict:
    chosen_tokens = tokenizer(chosen, add_special_tokens=False)
    rejected_tokens = tokenizer(rejected, add_special_tokens=False)
    prompt_tokens = tokenizer(prompt, add_special_tokens=False)

    assert tokenizer.eos_token_id not in prompt_tokens['input_ids'], f"Prompt contains EOS token: {prompt}"
    assert tokenizer.eos_token_id not in chosen_tokens['input_ids'], f"Chosen response contains EOS token: {chosen}"
    assert tokenizer.eos_token_id not in rejected_tokens['input_ids'], f"Rejected response contains EOS token: {rejected}"

    chosen_tokens['input_ids'].append(tokenizer.eos_token_id)
    chosen_tokens['attention_mask'].append(1)

    rejected_tokens['input_ids'].append(tokenizer.eos_token_id)
    rejected_tokens['attention_mask'].append(1)

    # if the prompt sequence is too long, truncate the prompt
    if len(prompt_tokens['input_ids']) > max_length:
        if truncation_mode == 'keep_start':
            prompt_tokens = {k: v[:max_prompt_length] for k, v in prompt_tokens.items()}
        elif truncation_mode == 'keep_end':
            prompt_tokens = {k: v[-max_prompt_length:] for k, v in prompt_tokens.items()}
        else:
            raise ValueError(f'Unknown truncation mode: {truncation_mode}')

    # if a target sequence is too long, truncate that as well
    longer_response_length = max(len(chosen_tokens['input_ids']), len(rejected_tokens['input_ids']))
    if longer_response_length > max_length:
        chosen_tokens = {k: v[:max_length] for k, v in chosen_tokens.items()}
        rejected_tokens = {k: v[:max_length] for k, v in rejected_tokens.items()}

    batch = {}

    batch['prompt'] = prompt
    batch['chosen'] = prompt + chosen
    batch['rejected'] = prompt + rejected
    batch['chosen_response_only'] = chosen
    batch['rejected_response_only'] = rejected

    for k, toks in {'chosen': chosen_tokens, 'rejected': rejected_tokens, 'prompt': prompt_tokens}.items():
        for type_key, tokens in toks.items():
            if type_key == 'token_type_ids':
                continue
            batch[f'{k}_{type_key}'] = tokens

    return batch
```
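For reference, here is roughly how I sanity-check the two helpers above (`t5-small` is just a stand-in; any encoder-decoder tokenizer with EOS and pad tokens should behave the same):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")

examples = [
    ("Translate to German: Hello.", "Hallo.", "Bonjour."),
    ("Translate to German: How are you today?", "Wie geht es dir heute?", "Comment vas-tu aujourd'hui?"),
]
elements = [
    tokenize_batch_element(p, c, r, 'keep_start', tokenizer, max_length=128, max_prompt_length=64)
    for p, c, r in examples
]

batch = get_collate_fn(tokenizer)(elements)

# Expect padded prompt_/chosen_/rejected_ input_ids and attention_mask tensors
# (prompt_* left-padded, the rest right-padded), plus the raw string fields.
print({k: (tuple(v.shape) if hasattr(v, 'shape') else type(v).__name__) for k, v in batch.items()})
```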
@chansurgeplus would you mind having a look at the PR for this on the TRL side: huggingface/trl#586
@chansurgeplus sorry for the delay here- got behind with ICML last week.
One thing to be wary of is that prompts are left-padded, since the current DPO code only uses the `prompt_` tokens for generation. HuggingFace correctly handles left-padded inputs for generation, but not for training (at least for left-to-right models); for training, the positional encodings can get broken.

Also note that the `chosen_*` and `rejected_*` tokens contain not just the response, but the prompt as well! So you probably want to delete the prompt from the beginning of these, if you're using them to train the decoder part of your model.
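(Purely for illustration, stripping that prefix could look something like this; `strip_prompt_prefix` is a hypothetical helper, not code from this repo:)

```python
# Hypothetical helper, not from this repo: drop the prompt prefix from a
# prompt+response token id list so only the response is used as a decoder target.
def strip_prompt_prefix(prompt_ids, full_ids):
    n = len(prompt_ids)
    assert full_ids[:n] == prompt_ids, "expected the prompt as a prefix"
    return full_ids[n:]

# toy token ids: prompt [5, 6, 7], response [42, 43] + EOS id 1
print(strip_prompt_prefix([5, 6, 7], [5, 6, 7, 42, 43, 1]))  # -> [42, 43, 1]
```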
Hope this is helpful- let me know if it doesn't fix your problem.
@kashif Thank you for bringing this to my attention. After running through the code in TRL, the only differences between my code and theirs are:

- The use of `prepare_decoder_input_ids_from_labels` in the class `DPODataCollatorWithPadding` here. My question is: does this have to be enforced for all models?
- For teacher forcing, setting both the `labels` and `decoder_input_ids` parameters during the forward pass in `DPOTrainer` here. Once again, isn't it sufficient to provide only the labels to the model for a forward pass, as described here? I remember reading that "`decoder_input_ids` are automatically calculated using `labels`".
- I'm not using a concatenated forward pass; instead I do two forward passes, one each for the chosen and rejected responses. This should be fine according to this and the docs.
- IMPORTANT: I see that the method `_get_batch_logps` avoids the following changes to `logits` and `labels` when training an encoder-decoder model, which I do not do in my implementation. Any specific reasons?
```python
if not self.is_encoder_decoder:
    labels = labels[:, 1:].clone()
    logits = logits[:, :-1, :]
```
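My reading of why that shift is skipped for encoder-decoder models, as a rough sketch (my own paraphrase, not the actual TRL code): the decoder already consumes the right-shifted labels (e.g. via `prepare_decoder_input_ids_from_labels`), so `logits[:, t]` lines up with `labels[:, t]` directly; only decoder-only models need the usual off-by-one slice.

```python
import torch

def batch_logps_sketch(logits, labels, pad_token_id, is_encoder_decoder, average_log_prob=False):
    """Sum (or mean) of the per-token log-probs of `labels` under `logits`.

    Decoder-only: logits at position t predict token t+1, so shift both tensors.
    Encoder-decoder: decoder inputs are already the right-shifted labels, so no shift.
    """
    if not is_encoder_decoder:
        labels = labels[:, 1:].clone()
        logits = logits[:, :-1, :]

    loss_mask = labels != pad_token_id                 # here, pad positions are ignored
    safe_labels = labels.masked_fill(~loss_mask, 0)    # any valid vocab index; masked out below
    per_token_logps = torch.gather(logits.log_softmax(-1), dim=2,
                                   index=safe_labels.unsqueeze(2)).squeeze(2)

    if average_log_prob:
        return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
    return (per_token_logps * loss_mask).sum(-1)
```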
From my reading, they too do not keep the prompt embedded in the chosen and rejected responses, so the trained model is supposed to produce the response only. If so, I think what I've done so far is identical to what TRL has in the PR, except for the differences listed above.
> @chansurgeplus sorry for the delay here- got behind with ICML last week.
>
> One thing to be wary of is that prompts are left-padded, since the current DPO code only uses the `prompt_` tokens for generation. HuggingFace correctly handles left-padded inputs for generation, but not for training (at least for left-to-right models); for training, the positional encodings can get broken.
>
> Also note that the `chosen_*` and `rejected_*` tokens contain not just the response, but the prompt as well! So you probably want to delete the prompt from the beginning of these, if you're using them to train the decoder part of your model.
>
> Hope this is helpful- let me know if it doesn't fix your problem.
No worries @eric-mitchell. Hope you had a great time and presented your contributions at ICML. Keep inspiring!
Regarding the use of the `prompt_` tokens: with the changes that I've made to the code here, I think the `prompt_` tokens should be good for use in the trainer, unless, as you've specified, the HuggingFace `tokenizer` does something stupid. Did you mean that the `tokenizer` is adding padding itself? (Apologies in advance if this is a stupid question, I just wanted to be sure.)
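Just to double-check my own understanding (using `t5-small` as a stand-in), the tokenizer should not pad anything on its own unless padding is explicitly requested; all padding in my setup happens in `collate_fn`:

```python
# Stand-in check with t5-small: padding defaults to False, so no pad tokens are
# added here; pad_token_id is only used later by collate_fn's pad_sequence call.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("t5-small")
enc = tok("an example prompt", add_special_tokens=False)
print(enc["input_ids"])     # raw ids, no padding, no EOS appended
print(tok.pad_token_id)     # 0 for T5, used only at collation time
```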
For the `chosen_*` and `rejected_*` tokens, I've eliminated the code that prepends the `prompt_tokens` (and the corresponding `-100` label tokens) to them, as you can see in the `tokenize_batch_element` method I published above.
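For comparison, my understanding of what the original decoder-only path does there (paraphrased, not the exact repo code):

```python
# Paraphrase of the decoder-only setup I removed: the full sequence is
# prompt + response, and the prompt positions in `labels` are masked with -100
# so the loss only covers the response tokens.
def build_decoder_only_labels(prompt_ids, response_ids, ignore_index=-100):
    input_ids = prompt_ids + response_ids
    labels = [ignore_index] * len(prompt_ids) + list(response_ids)
    return input_ids, labels

# toy ids: prompt [5, 6], response [42, 43, 1]
print(build_decoder_only_labels([5, 6], [42, 43, 1]))
# -> ([5, 6, 42, 43, 1], [-100, -100, 42, 43, 1])
```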