eric-mitchell/direct-preference-optimization

What exactly is "concatenated_inputs" doing?


I am reading through the code to implement DPO for my VQA model. The comment in get_batch_samples says "We do this to avoid doing two forward passes, because it's faster for FSDP." If my model takes image tensors and text captions as forward inputs, how should I build the concatenated_inputs? What exactly is "concatenated_inputs" doing?

```python
from typing import Dict, List, Union

import torch

from utils import pad_to_length  # helper defined in the DPO repo's utils.py


def concatenated_inputs(batch: Dict[str, Union[List, torch.LongTensor]]) -> Dict[str, torch.LongTensor]:
    """Concatenate the chosen and rejected inputs into a single tensor.

    Args:
        batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length).

    Returns:
        A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
    """
    # Pad both halves to a common length so they can be stacked along the batch dimension.
    max_length = max(batch['chosen_input_ids'].shape[1], batch['rejected_input_ids'].shape[1])
    concatenated_batch = {}
    # First pass: each chosen_* tensor becomes the first half of concatenated_*.
    for k in batch:
        if k.startswith('chosen') and isinstance(batch[k], torch.Tensor):
            # Labels are padded with -100 (ignored by the loss); ids and masks with 0.
            pad_value = -100 if 'labels' in k else 0
            concatenated_key = k.replace('chosen', 'concatenated')
            concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
    # Second pass: append each rejected_* tensor below its chosen counterpart (dim=0).
    for k in batch:
        if k.startswith('rejected') and isinstance(batch[k], torch.Tensor):
            pad_value = -100 if 'labels' in k else 0
            concatenated_key = k.replace('rejected', 'concatenated')
            concatenated_batch[concatenated_key] = torch.cat((
                concatenated_batch[concatenated_key],
                pad_to_length(batch[k], max_length, pad_value=pad_value),
            ), dim=0)
    return concatenated_batch
```
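To see the effect concretely, here is a minimal sketch on a toy batch (batch size 2, chosen length 3, rejected length 5, first token of each row treated as the prompt). `concat_chosen_rejected` is a hypothetical compact equivalent: it uses `torch.nn.functional.pad` instead of the repo's `pad_to_length` and only handles the input ids, attention mask, and labels, whereas the original loops over every tensor key starting with `chosen`/`rejected`.

```python
import torch
import torch.nn.functional as F


def concat_chosen_rejected(batch):
    """Hypothetical compact equivalent: right-pad, then stack chosen rows above rejected rows."""
    max_len = max(batch['chosen_input_ids'].shape[1], batch['rejected_input_ids'].shape[1])
    out = {}
    for field in ('input_ids', 'attention_mask', 'labels'):
        pad = -100 if field == 'labels' else 0  # -100 marks positions the loss ignores
        chosen, rejected = batch[f'chosen_{field}'], batch[f'rejected_{field}']
        out[f'concatenated_{field}'] = torch.cat([
            F.pad(chosen, (0, max_len - chosen.shape[1]), value=pad),
            F.pad(rejected, (0, max_len - rejected.shape[1]), value=pad),
        ], dim=0)
    return out


# Toy batch: the first token of each row is the prompt, so its label is -100.
batch = {
    'chosen_input_ids': torch.tensor([[1, 2, 3], [4, 5, 6]]),
    'chosen_attention_mask': torch.ones(2, 3, dtype=torch.long),
    'chosen_labels': torch.tensor([[-100, 2, 3], [-100, 5, 6]]),
    'rejected_input_ids': torch.tensor([[1, 2, 7, 8, 9], [4, 5, 10, 11, 12]]),
    'rejected_attention_mask': torch.ones(2, 5, dtype=torch.long),
    'rejected_labels': torch.tensor([[-100, 2, 7, 8, 9], [-100, 5, 10, 11, 12]]),
}
out = concat_chosen_rejected(batch)
for key, value in out.items():
    print(key, tuple(value.shape))
```

Rows 0-1 of each output are the chosen examples and rows 2-3 are the rejected ones, so the model sees a single batch of size 2 * batch_size.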

concatenated_inputs just combines the chosen and rejected sequences into one batch, because we need policy logits for both, and a single forward pass seemed a bit faster to me with FSDP. So this function only pads the chosen and rejected input ids, attention masks, and labels to a common length and concatenates them along the batch dimension.
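The reply implies how the concatenated batch is consumed: one forward pass over all 2 * batch_size rows, after which the per-sequence log-probabilities are split back into the first half (chosen) and the second half (rejected). For the VQA case in the question, one plausible adaptation (an assumption, not something the repo implements) is to repeat the image tensor along the batch dimension so that rows i and i + batch_size see the same image. `ToyVQAModel`, `sequence_logps`, `concatenated_forward`, and the `pixel_values` name are all hypothetical stand-ins; `sequence_logps` follows the usual shift-by-one log-prob sum that skips `-100` labels.

```python
import torch
import torch.nn as nn


class ToyVQAModel(nn.Module):
    """Hypothetical stand-in for a VQA model: (image, tokens) -> per-token logits."""

    def __init__(self, vocab_size: int = 16, dim: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.img_proj = nn.Linear(3, dim)  # mean-pooled RGB channels -> embedding
        self.head = nn.Linear(dim, vocab_size)

    def forward(self, pixel_values, input_ids, attention_mask):
        img = self.img_proj(pixel_values.mean(dim=(2, 3)))  # (B, dim)
        h = self.embed(input_ids) + img[:, None, :]          # (B, T, dim)
        h = h * attention_mask[..., None]                    # zero out padded positions
        return self.head(h)                                  # (B, T, vocab)


def sequence_logps(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Sum of log-probs of the label tokens, skipping positions labeled -100."""
    labels = labels[:, 1:].clone()  # token t is predicted from position t - 1
    logits = logits[:, :-1, :]
    mask = labels != -100
    labels[~mask] = 0  # dummy index; these positions are masked out below
    per_token = torch.gather(logits.log_softmax(-1), 2, labels.unsqueeze(2)).squeeze(2)
    return (per_token * mask).sum(-1)


def concatenated_forward(model, pixel_values, concat):
    """One forward pass over chosen + rejected rows (keys as produced by concatenated_inputs)."""
    n = pixel_values.shape[0]
    # Assumption: chosen and rejected answers share one image per example,
    # so the images are repeated to line up with the 2n concatenated rows.
    images = torch.cat([pixel_values, pixel_values], dim=0)
    logits = model(images, concat['concatenated_input_ids'], concat['concatenated_attention_mask'])
    logps = sequence_logps(logits.float(), concat['concatenated_labels'])
    return logps[:n], logps[n:]  # first n rows are chosen, last n are rejected


torch.manual_seed(0)
model = ToyVQAModel()
pixel_values = torch.randn(2, 3, 4, 4)  # (batch, channels, height, width)
concat = {
    # Rows 0-1: chosen (right-padded with 0 / -100); rows 2-3: rejected.
    'concatenated_input_ids': torch.tensor([[1, 2, 3, 0, 0], [4, 5, 6, 0, 0],
                                            [1, 2, 7, 8, 9], [4, 5, 10, 11, 12]]),
    'concatenated_attention_mask': torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 0, 0],
                                                 [1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]),
    'concatenated_labels': torch.tensor([[-100, 2, 3, -100, -100], [-100, 5, 6, -100, -100],
                                         [-100, 2, 7, 8, 9], [-100, 5, 10, 11, 12]]),
}
chosen_logps, rejected_logps = concatenated_forward(model, pixel_values, concat)
print(chosen_logps, rejected_logps)
```

If a preference pair used different images for the chosen and rejected answers, you would concatenate the two image tensors (chosen first) instead of repeating one; the only invariant is that the row order of the images matches the row order of the concatenated text.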