Computing logps faster
alexvishnevskiy opened this issue · 3 comments
Hi, great work! The results and research in this area are truly amazing. I have a question regarding the concatenated_forward part. From my understanding, we just need the logps of both the chosen and rejected responses. Why can't we have a batch that consists of [prompt + chosen_response + rejected_response] instead of [prompt + chosen_response, prompt + rejected_response]? It should be possible to calculate logps for both the chosen and rejected responses without them attending to each other, using an attention mask. Correct me if I'm wrong, thanks! (A rough sketch of what I mean is below.)
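For concreteness, here is a minimal sketch of the masking scheme the question proposes: a single sequence [prompt | chosen | rejected] with a causal mask modified so that rejected tokens cannot attend to chosen tokens. The helper name and everything in it are illustrative, not from this repo; note that the position ids of the rejected span would also have to restart after the prompt for this to be equivalent to two separate sequences.

import torch

def build_block_mask(prompt_len: int, chosen_len: int, rejected_len: int) -> torch.Tensor:
    """Return a (T, T) boolean mask where True means 'query may attend to key'.
    Hypothetical sketch; not part of the DPO codebase."""
    T = prompt_len + chosen_len + rejected_len
    mask = torch.tril(torch.ones(T, T, dtype=torch.bool))  # standard causal mask
    c0, c1 = prompt_len, prompt_len + chosen_len           # chosen span is [c0, c1)
    mask[c1:, c0:c1] = False                               # rejected queries must not see chosen keys
    return mask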
def concatenated_forward(self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
    """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together along the batch dimension.

    We do this to avoid doing two forward passes, because it's faster for FSDP.
    """
    concatenated_batch = concatenated_inputs(batch)
    all_logits = model(concatenated_batch['concatenated_input_ids'], attention_mask=concatenated_batch['concatenated_attention_mask']).logits.to(torch.float32)
    all_logps = _get_batch_logps(all_logits, concatenated_batch['concatenated_labels'], average_log_prob=False)
    # The first batch_size rows of all_logps correspond to the chosen sequences,
    # the remaining rows to the rejected sequences.
    chosen_logps = all_logps[:batch['chosen_input_ids'].shape[0]]
    rejected_logps = all_logps[batch['chosen_input_ids'].shape[0]:]
    return chosen_logps, rejected_logps
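For reference, concatenated_inputs is the helper that builds the combined batch. Below is a simplified sketch of what it does, assuming the pad values (0 for input ids and attention masks, -100 for labels) match the repo's conventions; the real function handles more keys and edge cases.

import torch
import torch.nn.functional as F
from typing import Dict

def concatenated_inputs_sketch(batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Pad chosen and rejected tensors to a common length, then stack them
    along the batch dimension (dim=0), NOT the sequence dimension."""
    max_len = max(batch['chosen_input_ids'].shape[1], batch['rejected_input_ids'].shape[1])

    def pad(t: torch.Tensor, value: int) -> torch.Tensor:
        # Right-pad the sequence dimension up to max_len.
        return F.pad(t, (0, max_len - t.shape[1]), value=value)

    out = {}
    for key, pad_value in [('input_ids', 0), ('attention_mask', 0), ('labels', -100)]:
        out[f'concatenated_{key}'] = torch.cat(
            (pad(batch[f'chosen_{key}'], pad_value),
             pad(batch[f'rejected_{key}'], pad_value)),
            dim=0,  # batch dim: rows 0..B-1 are chosen, rows B..2B-1 are rejected
        )
    return out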
Also, how do you ensure that, during the model forward pass, prompt + rejected does not attend to the chosen response? Where in the code is this check made?
I think you are misunderstanding the implementation. The chosen and rejected sequences are concatenated along the batch dimension, so the logps for both are computed in one forward pass instead of two. They are not concatenated along the sequence dimension, so they never attend to each other.
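A toy illustration (not from the repo) of why the batch-dimension concatenation is safe: each row of the concatenated tensor is an independent sequence, and self-attention only runs within a row.

import torch

B, T = 4, 10  # batch size and sequence length, illustrative values
chosen_ids = torch.randint(0, 100, (B, T))
rejected_ids = torch.randint(0, 100, (B, T))

concatenated = torch.cat((chosen_ids, rejected_ids), dim=0)  # shape (2B, T)
assert concatenated.shape == (2 * B, T)

# Slicing recovers the two halves exactly; no token from one half ever
# appears in the same sequence as a token from the other half.
chosen_back, rejected_back = concatenated[:B], concatenated[B:]
assert torch.equal(chosen_back, chosen_ids) and torch.equal(rejected_back, rejected_ids)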
Yes, I did not read carefully; the concatenation is on dim=0. Thank you for pointing that out.