eric-mitchell/direct-preference-optimization

DPO Loss not converging for Encoder-Decoder Models

chansurgeplus opened this issue · 6 comments

Hey Eric,

I've been attempting to fine-tune one of the encoder-decoder models, with the necessary changes made to the Trainer (such as setting the labels during the forward pass, etc.).

The SFT phase went fine, and the loss followed the usual decreasing pattern, as shown below.
Train Loss:
[W&B chart]

Eval loss for SFT:
[W&B chart]

However, during the DPO phase I get the following patterns, which seem abnormal.
Train Loss:
[W&B chart]

Grad Norm:
[W&B chart]

Any thoughts on this?

The modifications that I've made to the code are as follows:

  1. The forward pass for the DPO batch metrics, with the labels now set during the forward call:

def forward(self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
    chosen_logits = model(batch['prompt_input_ids'], attention_mask=batch['prompt_attention_mask'], labels=batch['chosen_input_ids']).logits.to(torch.float32)
    chosen_logps = _get_batch_logps(chosen_logits, batch['chosen_input_ids'], average_log_prob=False)

    rejected_logits = model(batch['prompt_input_ids'], attention_mask=batch['prompt_attention_mask'], labels=batch['rejected_input_ids']).logits.to(torch.float32)
    rejected_logps = _get_batch_logps(rejected_logits, batch['rejected_input_ids'], average_log_prob=False)

    return chosen_logps, rejected_logps
  2. (Should be unrelated to DPO, since SFT worked fine) SFT batch metrics calculation:
policy_chosen_logits = self.policy(batch['prompt_input_ids'], attention_mask=batch['prompt_attention_mask'], labels=batch['chosen_input_ids']).logits.to(torch.float32)
policy_chosen_logps = _get_batch_logps(policy_chosen_logits, batch['chosen_input_ids'], average_log_prob=False)
  3. Note that I've removed the labels from the batch keys and simply retained the relevant targets under the chosen and rejected keys of the batch. The prompt_* key-value pairs are now used to pass the prompt.
from typing import Callable, Dict, List, Union

import torch
from torch.nn.utils.rnn import pad_sequence


def get_collate_fn(tokenizer) -> Callable[[List[Dict]], Dict[str, Union[List, torch.Tensor]]]:
    def collate_fn(batch):
        # first, pad everything to the same length
        padded_batch = {}
        for k in batch[0].keys():
            if k.endswith('_input_ids') or k.endswith('_attention_mask'):
                if 'prompt' in k:  # adapted from https://stackoverflow.com/questions/73256206
                    to_pad = [torch.LongTensor(ex[k][::-1]) for ex in batch]
                else:
                    to_pad = [torch.LongTensor(ex[k]) for ex in batch]
                if k.endswith('_input_ids'):
                    padding_value = tokenizer.pad_token_id
                elif k.endswith('_attention_mask'):
                    padding_value = 0
                else:
                    raise ValueError(f"Unexpected key in batch '{k}'")

                padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value)
                if 'prompt' in k:  # for the prompt, flip back so padding is on left side
                    padded_batch[k] = padded_batch[k].flip(dims=[1])
            else:
                padded_batch[k] = [ex[k] for ex in batch]

        return padded_batch
    return collate_fn


def tokenize_batch_element(prompt: str, chosen: str, rejected: str, truncation_mode: str, tokenizer, max_length: int, max_prompt_length: int) -> Dict:
    chosen_tokens = tokenizer(chosen, add_special_tokens=False)
    rejected_tokens = tokenizer(rejected, add_special_tokens=False)
    prompt_tokens = tokenizer(prompt, add_special_tokens=False)

    assert tokenizer.eos_token_id not in prompt_tokens['input_ids'], f"Prompt contains EOS token: {prompt}"
    assert tokenizer.eos_token_id not in chosen_tokens['input_ids'], f"Chosen response contains EOS token: {chosen}"
    assert tokenizer.eos_token_id not in rejected_tokens['input_ids'], f"Rejected response contains EOS token: {rejected}"

    chosen_tokens['input_ids'].append(tokenizer.eos_token_id)
    chosen_tokens['attention_mask'].append(1)

    rejected_tokens['input_ids'].append(tokenizer.eos_token_id)
    rejected_tokens['attention_mask'].append(1)

    # if prompt sequence is too long, truncate the prompt
    if len(prompt_tokens['input_ids']) > max_prompt_length:
        if truncation_mode == 'keep_start':
            prompt_tokens = {k: v[:max_prompt_length] for k, v in prompt_tokens.items()}
        elif truncation_mode == 'keep_end':
            prompt_tokens = {k: v[-max_prompt_length:] for k, v in prompt_tokens.items()}
        else:
            raise ValueError(f'Unknown truncation mode: {truncation_mode}')

    # if target sequence is too long, truncate that as well.
    longer_response_length = max(len(chosen_tokens['input_ids']), len(rejected_tokens['input_ids']))
    if longer_response_length > max_length:
        chosen_tokens = {k: v[:max_length] for k, v in chosen_tokens.items()}
        rejected_tokens = {k: v[:max_length] for k, v in rejected_tokens.items()}

    batch = {}

    batch['prompt'] = prompt
    batch['chosen'] = prompt + chosen
    batch['rejected'] = prompt + rejected
    batch['chosen_response_only'] = chosen
    batch['rejected_response_only'] = rejected

    for k, toks in {'chosen': chosen_tokens, 'rejected': rejected_tokens, 'prompt': prompt_tokens}.items():
        for type_key, tokens in toks.items():
            if type_key == 'token_type_ids':
                continue
            batch[f'{k}_{type_key}'] = tokens

    return batch
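
For reference, here's a minimal usage sketch of tokenize_batch_element showing the keys each element ends up with (the checkpoint name below is only an example of an encoder-decoder tokenizer):

from transformers import AutoTokenizer

# Example checkpoint only; any encoder-decoder tokenizer should behave similarly.
tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-base')

element = tokenize_batch_element(
    prompt='Summarize: The quick brown fox jumps over the lazy dog.',
    chosen='A fox jumps over a dog.',
    rejected='Dogs are lazy.',
    truncation_mode='keep_start',
    tokenizer=tokenizer,
    max_length=512,
    max_prompt_length=128,
)

# Expected keys: 'prompt', 'chosen', 'rejected', 'chosen_response_only',
# 'rejected_response_only', plus '{prompt,chosen,rejected}_input_ids' and
# '{prompt,chosen,rejected}_attention_mask'.
print(sorted(element.keys()))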

Here's confirmation that the eval loss is blowing up as well (after 3k steps):

[W&B chart]

kashif commented

@chansurgeplus would you mind having a look at the PR for this on the TRL side: huggingface/trl#586

eric-mitchell commented

@chansurgeplus sorry for the delay here, I got behind with ICML last week.

One thing to be wary of is that the prompts are left-padded, since the current DPO code only uses the prompt_ tokens for generation. HuggingFace correctly handles left-padded inputs for generation, but not for training (at least for left-to-right models); during training, the positional encodings can get broken.

Also note that the chosen_* and rejected_* tokens contain not just the response, but the prompt as well! So you probably want to delete the prompt from the beginning of these, if you're using them to train the decoder part of your model.
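
If a concrete sketch helps (purely illustrative; the helper name isn't from the repo), assuming chosen_input_ids was built by concatenating the un-padded prompt tokens with the response tokens, stripping the prompt could look like:

from typing import List

def strip_prompt_prefix(prompt_input_ids: List[int], sequence_input_ids: List[int]) -> List[int]:
    # Return the response-only ids from a (prompt + response) sequence.
    # Assumes the sequence starts with exactly the tokens in prompt_input_ids.
    prompt_len = len(prompt_input_ids)
    assert sequence_input_ids[:prompt_len] == prompt_input_ids, \
        'sequence does not start with the prompt tokens'
    return sequence_input_ids[prompt_len:]

The response-only ids can then serve as decoder labels, while the prompt ids and attention mask go to the encoder (where the attention mask takes care of whichever side the padding is on).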

Hope this is helpful; let me know if it doesn't fix your problem.

chansurgeplus commented

@kashif Thank you for bringing this to my attention. After running through the code in TRL, the only differences between my code and theirs are:

  1. The use of prepare_decoder_input_ids_from_labels in the DPODataCollatorWithPadding class here. My question is: does this have to be enforced for all models?
  2. For teacher forcing, setting both the labels and decoder_input_ids parameters during the forward pass in DPOTrainer here. Once again, isn't it sufficient to provide only the labels to the model for a forward pass, as described here? I remember reading that "decoder_input_ids are automatically calculated using labels".
  3. I'm not using a concatenated forward pass; instead I do two separate forward passes, one for the chosen and one for the rejected responses. This should be fine according to this and the docs.
  4. IMPORTANT: I see that the _get_batch_logps method avoids the following shift of the logits and labels when training an encoder-decoder model, which I do not do in my implementation (see the sketch at the end of this comment). Any specific reason?
if not self.is_encoder_decoder:
  labels = labels[:, 1:].clone()
  logits = logits[:, :-1, :]

From my reading, they too do not embed the prompt in the chosen and rejected responses, so the trained model is expected to produce the response only. If so, I think what I've done so far is identical to what TRL has in the PR, except for the differences mentioned above.
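
For point 4, my current understanding (possibly wrong, hence asking) is that when only labels are passed, T5/BART-style models shift them right internally to build decoder_input_ids, so logits[:, t] already scores labels[:, t]; the extra shift is only needed for decoder-only models, where the logits at position t predict the token at t + 1. Here's a sketch of what a _get_batch_logps handling both cases might look like (the is_encoder_decoder and label_pad_token_id arguments are illustrative, not the repo's current signature; in my collator the targets are padded with the tokenizer's pad id rather than -100, so that would be the value to mask):

import torch

def _get_batch_logps(logits: torch.FloatTensor,
                     labels: torch.LongTensor,
                     average_log_prob: bool = False,
                     is_encoder_decoder: bool = False,
                     label_pad_token_id: int = -100) -> torch.FloatTensor:
    # Sum (or mean) of the per-token log-probs of `labels` under `logits`.
    # logits: (batch, seq_len, vocab_size); labels: (batch, seq_len).
    if not is_encoder_decoder:
        # Decoder-only: logits at position t predict the token at t + 1,
        # so drop the first label and the last logit.
        labels = labels[:, 1:].clone()
        logits = logits[:, :-1, :]
    # Encoder-decoder: labels are shifted right inside the model to build
    # decoder_input_ids, so logits[:, t] already lines up with labels[:, t].

    loss_mask = labels != label_pad_token_id
    labels = labels.clone()
    labels[labels == label_pad_token_id] = 0  # any valid index; masked out below

    per_token_logps = torch.gather(logits.log_softmax(-1), dim=2,
                                   index=labels.unsqueeze(2)).squeeze(2)

    if average_log_prob:
        return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
    return (per_token_logps * loss_mask).sum(-1)

If that reading is right, prepare_decoder_input_ids_from_labels is just doing the same right shift explicitly, so providing only labels should be equivalent for models that implement it.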


No worries @eric-mitchell. Hope you had a great time and presented your contributions at ICML. Keep inspiring!

Regarding the use of the prompt_ tokens: with the changes that I've made to the code here, I think the prompt tokens should be fine for use in the trainer, unless, as you've implied, the HuggingFace tokenizer does something stupid. Did you mean that the tokenizer adds padding itself? (Apologies in advance if this is a stupid question; I just wanted to be sure.)

For the chosen_* and rejected_* tokens, I've eliminated the code that prepends the prompt tokens and the -100 label tokens to them, as you can see in the tokenize_batch_element method I posted above; a sketch of the pattern that was removed follows.
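
For clarity, here's roughly the decoder-only pattern that was removed, from my reading of the original tokenize_batch_element (a sketch, not the repo's verbatim code): the prompt ids are prepended to the response ids, and the prompt positions in the labels are masked with -100 so only the response tokens contribute to the loss.

def build_decoder_only_element(prompt_tokens: dict, response_tokens: dict) -> dict:
    # Illustrative sketch of the removed decoder-only layout.
    # prompt_tokens / response_tokens are dicts with 'input_ids' and
    # 'attention_mask' lists, as built inside tokenize_batch_element.
    sequence = {
        'input_ids': prompt_tokens['input_ids'] + response_tokens['input_ids'],
        'attention_mask': prompt_tokens['attention_mask'] + response_tokens['attention_mask'],
    }
    # Prompt positions are masked with -100 so the loss only covers the response.
    sequence['labels'] = [-100] * len(prompt_tokens['input_ids']) + response_tokens['input_ids']
    return sequence

In the encoder-decoder setup I keep only the response ids as decoder targets and send the prompt through the encoder instead.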