databrickslabs/dolly

Matching the response tokens

vwxyzjn opened this issue · 1 comments

Hello, thanks for the nice reference code! I noticed that the following code tries to match the response tokens, but it might match the instruction tokens instead:

response_token_ids_start_idx = None
for idx in np.where(batch["labels"][i] == response_token_ids[0])[0]:
    response_token_ids_start_idx = idx
    break

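This happens because the loop breaks as soon as the first token matches. '### Response:\n' is encoded as [21017, 18261, 25, 198] and '### Instruction:\n' as [21017, 46486, 25, 198]; both start with 21017, so the loop stops at the instruction header, which comes first in the prompt, instead of the response header.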
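To make the collision concrete, here is a minimal sketch using the two encodings quoted above. The sequence layout and the filler IDs (1000, 1001, 2000, 2001) are made up for illustration, and a plain list comprehension stands in for the `np.where` call so the example has no dependencies.

```python
# Token IDs quoted in this issue.
instruction_ids = [21017, 46486, 25, 198]  # '### Instruction:\n'
response_ids = [21017, 18261, 25, 198]     # '### Response:\n'

# Hypothetical sequence: instruction header, filler tokens,
# response header, filler tokens. The filler IDs are invented.
labels = instruction_ids + [1000, 1001] + response_ids + [2000, 2001]

# Stand-in for np.where(labels == response_ids[0])[0]:
# every position whose token equals the FIRST response token.
first_token_hits = [i for i, tok in enumerate(labels) if tok == response_ids[0]]

# Breaking on the first hit selects the instruction header (index 0),
# not the response header (index 6).
first_hit = first_token_hits[0]
print(first_token_hits, first_hit)
```

Both headers show up in `first_token_hits`, so stopping at the first one picks the wrong position whenever the instruction header precedes the response header.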

If you did intend to match the response tokens, consider the following snippet instead :)

            for idx in np.where(batch["labels"][i] == response_token_ids[0])[0]:
                # `response_token_ids` is `'### Response:\n'`, here we are just making sure that the token IDs match
                if response_token_ids == examples[i]["input_ids"][idx:idx+len(response_token_ids)]:
                    response_token_ids_start_idx = idx  
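For anyone who wants to try the idea outside the collator, here is a self-contained sketch of the same full-sequence check. The helper name `find_response_start` and the sample sequence are hypothetical and not part of dolly or trl. Unlike the snippet above, it returns the first full match rather than the last, and it converts each slice to a list so the comparison also works if the input happens to be a NumPy array.

```python
from typing import Optional, Sequence


def find_response_start(
    token_ids: Sequence[int], response_token_ids: Sequence[int]
) -> Optional[int]:
    """Return the index where response_token_ids first occurs as a
    contiguous run inside token_ids, or None if it never occurs."""
    pattern = list(response_token_ids)
    n = len(pattern)
    for idx in range(len(token_ids) - n + 1):
        # Compare the whole window, not just the first token.
        if list(token_ids[idx:idx + n]) == pattern:
            return idx
    return None


# Token IDs quoted in this issue; the filler IDs are invented.
instruction_ids = [21017, 46486, 25, 198]  # '### Instruction:\n'
response_ids = [21017, 18261, 25, 198]     # '### Response:\n'
tokens = instruction_ids + [1000, 1001] + response_ids + [2000, 2001]

start = find_response_start(tokens, response_ids)
print(start)
```

With this layout the helper skips the instruction header and lands on the response header. It returns `None` when the response header is missing, so the caller still has to decide how to handle that case.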

Related issue on our side: huggingface/trl#445 (comment)

CC @matthayes, WDYT?