The way to calculate the log_probs
laurenlong opened this issue · 0 comments
laurenlong commented
Hi,
Since the model's output at each token position represents the probabilities of the next token, shouldn't the calculation of log_probs be shifted by one position?
I mean "diff_logits[range(diff_logits.shape[0]-1), continue_ids[1:]].sum().item()"
instead of "log_probs = diff_logits[range(diff_logits.shape[0]), continue_ids].sum().item()".
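
To make the question concrete, here is a minimal, dependency-free sketch of the alignment I have in mind. It is not the repository's code: the helper names (`log_softmax`, `sequence_log_prob`) and the toy logits are made up for illustration, and it assumes that row t of the logits is the model output at input position t of the same sequence that the token ids come from.

```python
import math

def log_softmax(row):
    # Numerically stable log-softmax over one row of logits.
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - lse for x in row]

def sequence_log_prob(logits, token_ids):
    # Assumption: logits[t] is the output at position t, i.e. it scores token t+1.
    # So rows 0..n-2 are paired with tokens 1..n-1; the last row has no target.
    log_probs = [log_softmax(row) for row in logits]
    return sum(lp[tok] for lp, tok in zip(log_probs[:-1], token_ids[1:]))

# Toy example: vocabulary of 3 tokens, sequence [0, 2, 1].
token_ids = [0, 2, 1]
logits = [
    [0.0, 0.0, 5.0],  # position 0 strongly predicts token 2 (the actual next token)
    [0.0, 5.0, 0.0],  # position 1 strongly predicts token 1 (the actual next token)
    [5.0, 0.0, 0.0],  # position 2 predicts a token beyond the sequence
]

shifted = sequence_log_prob(logits, token_ids)
# Unshifted pairing: row t is scored against token t itself.
unshifted = sum(log_softmax(row)[tok] for row, tok in zip(logits, token_ids))

print(shifted, unshifted)
```

In this toy case the shifted sum is close to zero while the unshifted sum is strongly negative, because the unshifted version scores each position against its own input token. If diff_logits has already been sliced so that its first row is the output one position before the first continuation token, then the unshifted indexing would already be aligned and no change is needed, so the answer depends on how diff_logits is constructed.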