princeton-nlp/SimPO

Repeated Addition of Assistant Turn in Prompt/Chosen/Rejected Text Using `apply_chat_template`

iseesaw opened this issue · 1 comments

When using the `apply_chat_template` function from the alignment-handbook and the authors' code, an assistant turn header is added repeatedly to the prompt, chosen, and rejected texts. This may cause a mismatch between the SFT and DPO/SimPO input formats and could affect performance.

Code Snippet:

example["text_prompt"] = tokenizer.apply_chat_template(prompt_messages, tokenize=False)
example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False)
example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False)

Observed Behavior:

The logs show that each section (prompt, chosen, rejected) ends with an extra assistant header (`<|start_header_id|>assistant<|end_header_id|>`), which appears to be appended automatically by `apply_chat_template`. Because the prompt is later concatenated with the chosen/rejected text, the resulting training text contains duplicated assistant headers, which distorts the structure of the training data.

Example Log Output:

2024-05-27 11:45:50 - INFO - __main__ - Prompt sample 3278 of the raw training set:

<|begin_of_text|><|start_header_id|>user<|end_header_id|>

A middle aged man presents with progressive atrophy and weakness of hands and forearms. On examination he is found to have slight spasticity of the legs, generalized hyper-reflexia and increased signal in the cortico-spinal tracts on T2 weighted MRI. The most likely diagnosis is -

A. Multiple selerosis
B. Amyotrophic latetal sclerosis
C. Subacute combined degeneration
D. Progressive spinal muscular atrophy<|eot_id|><|start_header_id|>assistant<|end_header_id|>


2024-05-27 11:45:50 - INFO - __main__ - Chosen sample 3278 of the raw training set:

<|start_header_id|>assistant<|end_header_id|>

The patient presents with progressive atrophy and weakness of hands and forearms, which suggests a motor neuron disease. The presence of slight spasticity of the legs and generalized hyper-reflexia indicates involvement of the upper motor neurons. The increased signal in the cortico-spinal tracts on T2 weighted MRI also supports this finding.

Given these features, the most likely diagnosis is Amyotrophic Lateral Sclerosis (ALS), which is a disease that affects both upper and lower motor neurons.

So, the answer is B.<|eot_id|><|start_header_id|>assistant<|end_header_id|>


2024-05-27 11:45:50 - INFO - __main__ - Rejected sample 3278 of the raw training set:

<|start_header_id|>assistant<|end_header_id|>

Let's break down the clues:

* Progressive atrophy and weakness of hands and forearms: This suggests a peripheral nerve disorder, possibly affecting the motor neurons.
* Slight spasticity of the legs: This could be due to an upper motor neuron lesion, but it's a subtle finding.
* Generalized hyper-reflexia: This is a strong indication of an upper motor neuron lesion, as it suggests that the reflexes are being facilitated by the abnormal signal transmission.
* Increased signal in the cortico-spinal tracts on T2 weighted MRI: This is a key finding, as it suggests that the cortico-spinal tracts are affected, which is a hallmark of upper motor neuron lesions.

Considering these clues, the most likely diagnosis is:

So, the answer is B.<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The prompt and chosen/rejected are directly concatenated:


    def build_tokenized_answer(self, prompt, answer):
        """
        Llama tokenizer does not satisfy `enc(a + b) = enc(a) + enc(b)`.
        It does ensure `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]`.
        Reference:
            https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
        """

        full_tokenized = self.tokenizer(prompt + answer, add_special_tokens=False)
        prompt_input_ids = self.tokenizer(prompt, add_special_tokens=False)["input_ids"]
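
To make the effect of this concatenation concrete, here is a minimal sketch of a check that could be run on the rendered strings before tokenization. The helper name `find_header_problems` and the shortened sample texts are hypothetical; only the marker strings are taken from the log output above.

```python
from typing import List

# Marker strings copied from the log output in this issue.
ASSISTANT_HEADER = "<|start_header_id|>assistant<|end_header_id|>\n\n"
EOT = "<|eot_id|>"


def find_header_problems(prompt: str, answer: str) -> List[str]:
    """Return a list of problems found in the concatenation prompt + answer.

    Hypothetical helper for illustration; it is not part of SimPO or TRL.
    """
    problems = []
    full = prompt + answer
    # Two assistant headers in a row: the prompt already ended with one,
    # and the answer starts with another.
    if ASSISTANT_HEADER + ASSISTANT_HEADER in full:
        problems.append("consecutive assistant headers")
    # An assistant header left open after the final end-of-turn token.
    if full.endswith(ASSISTANT_HEADER):
        problems.append("dangling assistant header at end")
    return problems


# Shortened stand-ins for the logged prompt and chosen texts.
prompt = (
    "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
    "The most likely diagnosis is -" + EOT + ASSISTANT_HEADER
)
chosen = ASSISTANT_HEADER + "So, the answer is B." + EOT + ASSISTANT_HEADER

problems = find_header_problems(prompt, chosen)
print(problems)
```

With the strings shaped like the logs above, both problems are reported; with a prompt that ends at `<|eot_id|>` and an answer that ends at `<|eot_id|>`, the list is empty.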

I look forward to any insights or suggestions on how to address this issue. Thank you!

Solved. I was using the initial chat template, which was later updated.

See "Fix chat template to add generation prompt only if the option is selected" (#9).
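
For readers who hit the same symptom, the following sketch illustrates the behavior described by the title of #9, under the assumption that the updated template appends the assistant header only when a generation prompt is requested. The function `render_llama3_style` is a hypothetical pure-Python stand-in, not the repository's actual Jinja template, and it omits the `<|begin_of_text|>` token for brevity.

```python
from typing import Dict, List

ASSISTANT_HEADER = "<|start_header_id|>assistant<|end_header_id|>\n\n"


def render_llama3_style(
    messages: List[Dict[str, str]], add_generation_prompt: bool = False
) -> str:
    """Render messages in a Llama-3-like format (illustrative stand-in only)."""
    text = ""
    for message in messages:
        text += (
            f"<|start_header_id|>{message['role']}<|end_header_id|>\n\n"
            f"{message['content'].strip()}<|eot_id|>"
        )
    # Assumed post-fix behavior: the header is added only on request.
    if add_generation_prompt:
        text += ASSISTANT_HEADER
    return text


conversation = [
    {"role": "user", "content": "The most likely diagnosis is -"},
    {"role": "assistant", "content": "So, the answer is B."},
]
prompt_messages = conversation[:-1]
chosen_messages = conversation[-1:]

# Assumed post-fix behavior: no generation prompt unless asked for.
fixed = render_llama3_style(prompt_messages) + render_llama3_style(chosen_messages)

# Emulating the initial template, which always appended the header.
initial = render_llama3_style(prompt_messages, True) + render_llama3_style(
    chosen_messages, True
)

print(fixed == render_llama3_style(conversation))
print(initial.count(ASSISTANT_HEADER))
```

In this sketch the post-fix concatenation is identical to rendering the whole conversation at once, while the always-on variant contains three assistant headers for a single assistant turn, matching the pattern in the logs.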