Matching the response tokens
vwxyzjn opened this issue
vwxyzjn commented
Hello, thanks for the nice reference code! I noticed that the following code is meant to match the response tokens, but it may match the instruction tokens instead:
Lines 60 to 63 in aaa0ecb
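This is because the loop breaks as soon as the first token matches. `'### Response:\n'` is encoded as `[21017, 18261, 25, 198]`, while `'### Instruction:\n'` is encoded as `[21017, 46486, 25, 198]`. Both keys start with token `21017`, so the loop stops at the instruction key, which appears earlier in the prompt, instead of the response key.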
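To make the failure concrete, here is a minimal sketch that replays a first-token search on a toy label sequence built from the token IDs quoted above. The filler IDs (101, 102, 201, 202) and the helper name `first_token_match` are hypothetical; the loop mirrors the behavior described (stop at the first position whose token equals the key's first token), not the exact repository code.

```python
import numpy as np

# Token IDs quoted in the issue.
INSTRUCTION_KEY_IDS = [21017, 46486, 25, 198]  # '### Instruction:\n'
RESPONSE_KEY_IDS = [21017, 18261, 25, 198]     # '### Response:\n'

# Hypothetical tokenized example: instruction key, two made-up prompt tokens,
# response key, two made-up response tokens.
labels = np.array(INSTRUCTION_KEY_IDS + [101, 102] + RESPONSE_KEY_IDS + [201, 202])


def first_token_match(labels, key_ids):
    """Return the first position whose token equals key_ids[0] (the buggy approach)."""
    for idx in np.where(labels == key_ids[0])[0]:
        return int(idx)  # stops at the first hit, like the `break` in the original loop
    return None


print(first_token_match(labels, RESPONSE_KEY_IDS))  # 0 -> start of '### Instruction:\n'
```

Because only token `21017` is checked, the search for the response key lands on position 0, the start of the instruction key, rather than position 6.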
If the intent is indeed to match the response tokens, consider the following snippet instead :)
```python
import numpy as np

for idx in np.where(batch["labels"][i] == response_token_ids[0])[0]:
    # `response_token_ids` is `'### Response:\n'`; check that the whole token sequence matches, not just the first token
    if response_token_ids == examples[i]["input_ids"][idx : idx + len(response_token_ids)]:
        response_token_ids_start_idx = idx
        break
```
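A standalone sketch of the same fix, assuming `labels` is a 1-D NumPy array for a single example (a stand-in for `batch["labels"][i]`). Unlike the snippet, it compares against the same `labels` array it searches, which avoids depending on `examples[i]["input_ids"]` being aligned with the padded batch. The helper names, filler IDs, and the masking step (setting everything through the end of the response key to -100, the default `ignore_index` of `torch.nn.CrossEntropyLoss`) are assumptions for illustration, not the repository's code.

```python
import numpy as np

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def find_response_start(labels, response_token_ids):
    """Return the index where the full response_token_ids sequence starts, or None."""
    n = len(response_token_ids)
    for idx in np.where(labels == response_token_ids[0])[0]:
        # Compare the whole window, not just the first token.
        if labels[idx : idx + n].tolist() == list(response_token_ids):
            return int(idx)
    return None


def mask_prompt(labels, response_token_ids):
    """Hypothetical helper: copy labels and mask everything through the end of the response key."""
    start = find_response_start(labels, response_token_ids)
    if start is None:
        raise RuntimeError("response key not found in labels")
    masked = labels.copy()
    masked[: start + len(response_token_ids)] = IGNORE_INDEX
    return masked


response_token_ids = [21017, 18261, 25, 198]  # '### Response:\n'
# Hypothetical example: instruction key, filler prompt tokens, response key, filler response tokens.
labels = np.array([21017, 46486, 25, 198, 101, 102] + response_token_ids + [201, 202])
print(find_response_start(labels, response_token_ids))  # 6
print(mask_prompt(labels, response_token_ids).tolist())  # [-100] * 10 + [201, 202]
```

Checking the full four-token window rejects position 0 (the instruction key shares only token `21017`) and returns 6, so only the response tokens keep their labels.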
Our related issue: huggingface/trl#445 (see the linked comment)
srowen commented
CC @matthayes, WDYT?