What exactly is "concatenated_inputs" doing?
Closed this issue · 1 comment
I am reading through the code to implement DPO for my VQA model. A comment in `get_batch_samples` says: "We do this to avoid doing two forward passes, because it's faster for FSDP." If my model's forward pass accepts image tensors and text captions, how should I build the concatenated inputs? What exactly is `concatenated_inputs` doing?
```python
from typing import Dict, List, Union

import torch

# pad_to_length is a helper from this repo that right-pads a tensor
# along the sequence dimension to `length` with `pad_value`.


def concatenated_inputs(batch: Dict[str, Union[List, torch.LongTensor]]) -> Dict[str, torch.LongTensor]:
    """Concatenate the chosen and rejected inputs into a single tensor.

    Args:
        batch: A batch of data. Must contain the keys 'chosen_input_ids' and
            'rejected_input_ids', which are tensors of shape (batch_size, sequence_length).

    Returns:
        A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
    """
    print("batch")  # debug print
    print(batch)
    max_length = max(batch['chosen_input_ids'].shape[1], batch['rejected_input_ids'].shape[1])
    concatenated_batch = {}
    for k in batch:
        if k.startswith('chosen') and isinstance(batch[k], torch.Tensor):
            pad_value = -100 if 'labels' in k else 0
            concatenated_key = k.replace('chosen', 'concatenated')
            concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
    for k in batch:
        if k.startswith('rejected') and isinstance(batch[k], torch.Tensor):
            pad_value = -100 if 'labels' in k else 0
            concatenated_key = k.replace('rejected', 'concatenated')
            concatenated_batch[concatenated_key] = torch.cat((
                concatenated_batch[concatenated_key],
                pad_to_length(batch[k], max_length, pad_value=pad_value),
            ), dim=0)
    print("CONCATENATED")  # debug print
    print(concatenated_batch)
    return concatenated_batch
```
`concatenated_inputs` just combines the batch of chosen and rejected sequences, because we need to compute policy logits for both, and doing a single forward pass seemed a bit faster to me with FSDP. So this function simply concatenates the input ids/attention mask/labels for chosen and rejected along the batch dimension.
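To see concretely what the function produces, here is a minimal, self-contained sketch with toy input ids; `pad_to_length` below is a simplified stand-in for the repo's helper of the same name, and the key names match the batch format described above:

```python
from typing import List

import torch


def pad_to_length(tensor: torch.Tensor, length: int, pad_value: int, dim: int = -1) -> torch.Tensor:
    # Right-pad `tensor` along `dim` with `pad_value` until it has size `length`.
    if tensor.size(dim) >= length:
        return tensor
    pad_size: List[int] = list(tensor.shape)
    pad_size[dim] = length - tensor.size(dim)
    padding = torch.full(pad_size, pad_value, dtype=tensor.dtype, device=tensor.device)
    return torch.cat([tensor, padding], dim=dim)


# Toy batch: 2 chosen sequences of length 3, 2 rejected sequences of length 2.
batch = {
    "chosen_input_ids": torch.tensor([[1, 2, 3], [4, 5, 6]]),
    "rejected_input_ids": torch.tensor([[7, 8], [9, 10]]),
}

# Pad both sides to the longer length, then stack along the batch dimension,
# chosen rows first, exactly as concatenated_inputs does.
max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
concatenated = {
    "concatenated_input_ids": torch.cat(
        (
            pad_to_length(batch["chosen_input_ids"], max_length, pad_value=0),
            pad_to_length(batch["rejected_input_ids"], max_length, pad_value=0),
        ),
        dim=0,
    )
}

print(concatenated["concatenated_input_ids"])
# tensor([[ 1,  2,  3],
#         [ 4,  5,  6],
#         [ 7,  8,  0],
#         [ 9, 10,  0]])
```

For a VQA model, one plausible extension (an assumption, not something this repo implements) is that since the image is shared between the chosen and rejected captions, you could repeat the image batch once, e.g. `torch.cat([images, images], dim=0)`, so the doubled text batch and the doubled image batch line up for a single forward pass.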