Aligning Large Language Models by On-Policy Self-Judgment
Sangkyu Lee1,*, Sungdong Kim2,3,†, Ashkan Yousefpour1, Minjoon Seo3, Kang Min Yoo2,4, Youngjae Yu1,†
1Yonsei University, 2NAVER Cloud, 3KAIST AI, 4SNU AI Center
*Work done during internship at NAVER Cloud, †Corresponding Author
This is the official repository of "Aligning Large Language Models by On-Policy Self-Judgment", containing:
- Example training scripts for Judge-augmented Supervised Fine-tuning (JSFT)
- Example training scripts and Trainer for Self-Training by On-policy Judgment
- The helper function for performing Self-Rejection by Tournament
Please make sure to set up your environment with Python 3.10, then follow the installation:

```bash
cd src
pip install -r requirements.txt
pip install -e .
cd ../
```
If you want to directly follow the example scripts of this repo, you also need to install flash-attn:

```bash
pip install flash-attn --no-build-isolation
```
Throughout our implementation, we use Templates for Chat Models for Supervised Fine-tuning and pairwise judgment. For pairwise judgments, we use the following nested dictionary structure for the judgment templates:
```python
{'principle_0': {'system': system_message_0,
                 'user': user_template_0,
                 'assistant': assistant_template_0},
 'principle_1': {'system': system_message_1,
                 'user': user_template_1,
                 'assistant': assistant_template_1},
 ...
}
```
For each user template, format fields for `context`, `response_a`, and `response_b` are required. The `context` is the context in which the pairwise judgment between the two responses is performed (e.g., the instruction or prompt given to the assistant), and `response_a` and `response_b` are the responses placed at positions A and B for the pairwise judgment. You can check how these judgment templates are used to construct training examples for pairwise judgments in `preprocess.py`. To utilize rationales while training pairwise judgment, decision and rationale templates as in `configs/uf_jsft.yaml` are required.
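For illustration only, a judgment template following this structure might look like the sketch below, assuming `str.format`-style placeholders for `context`, `response_a`, and `response_b`. The principle name and all message wording here are hypothetical placeholders, not the templates used in the paper; the actual templates are defined in the configs under `configs/`.

```python
# Hypothetical judgment templates illustrating the required structure and the
# {context}, {response_a}, {response_b} format fields. The wording below is a
# placeholder; the real templates live in the configs under configs/.
judge_templates = {
    "helpfulness": {
        "system": "You are a judge comparing two assistant responses.",
        "user": (
            "Context:\n{context}\n\n"
            "Response A:\n{response_a}\n\n"
            "Response B:\n{response_b}\n\n"
            "Which response is better?"
        ),
        "assistant": "The better response is",
    },
}

# Each user template is filled in with str.format, e.g.:
filled = judge_templates["helpfulness"]["user"].format(
    context="How do I sort a list in Python?",
    response_a="Use the sorted() builtin.",
    response_b="Lists cannot be sorted in Python.",
)
```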
The included example JSFT training scripts can be run as:
```bash
# HH-Helpful
accelerate launch --config_file=configs/jsft_accelerate.yaml train_jsft.py --config=configs/hh_jsft.yaml

# UltraFeedback
accelerate launch --config_file=configs/jsft_accelerate.yaml train_jsft.py --config=configs/uf_jsft.yaml

# UltraFeedback with principles and rationale
accelerate launch --config_file=configs/jsft_accelerate.yaml train_jsft.py --config=configs/uf_jsft_pr.yaml
```
We provide `SelfJudgeCollator` and `SelfJudgeTrainer`, which can be used as:
```python
from self_judge import SelfJudgeCollator, SelfJudgeTrainer
...
data_collator = SelfJudgeCollator()
trainer = SelfJudgeTrainer(
    model=model,
    tokenizer=tokenizer,
    data_collator=data_collator,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    **self_judge_args,
)
trainer.train()
```
`SelfJudgeTrainer` extends `DPOTrainer` from the `trl` library, so it is used in much the same way. However, `SelfJudgeTrainer` currently supports only `["sigmoid", "hinge", "ipo"]` for `loss_type`. Additionally, `token_id_a` and `token_id_b`, which correspond to the judge tokens used in JSFT, are required. For the other arguments specific to `SelfJudgeTrainer`, please check the configurations in `/configs`.
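As a rough sketch of how `token_id_a` and `token_id_b` might be obtained, assuming the judge tokens are the single tokens "A" and "B" (the actual judge tokens depend on the JSFT judgment templates and the tokenizer, so verify against the configs in `/configs`):

```python
# Hypothetical lookup of the judge-token ids; whether "A"/"B" are really the
# judge tokens depends on the JSFT templates, so check the configs in /configs.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("path/to/your/jsft-model")  # placeholder path
token_id_a = tokenizer.convert_tokens_to_ids("A")
token_id_b = tokenizer.convert_tokens_to_ids("B")

# Plus any other SelfJudgeTrainer-specific arguments from the configs.
self_judge_args = {"token_id_a": token_id_a, "token_id_b": token_id_b}
```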
`SelfJudgeCollator` and `SelfJudgeTrainer` assume datasets with the following keys (see the sketch after this list):

- `query`: a string from which the model starts its continuation.
- `context`: a context string for performing the judgment between a sampled response pair.
- `chosen`: a string corresponding to an example chosen response (only used during evaluation).
- `rejected`: a string corresponding to an example rejected response (only used during evaluation).
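A minimal sketch of a dataset with these keys, assuming the Hugging Face `datasets` library (all field contents below are purely illustrative):

```python
# Illustrative dataset carrying the keys expected by SelfJudgeCollator / SelfJudgeTrainer.
from datasets import Dataset

train_dataset = Dataset.from_dict({
    "query": ["<prompt rendered with the chat template>"],
    "context": ["How do I sort a list in Python?"],
    "chosen": ["Use the sorted() builtin."],           # only used during evaluation
    "rejected": ["Lists cannot be sorted in Python."], # only used during evaluation
})
```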
The included example training scripts for self-training by on-policy judgment can be run as:
```bash
# HH-Helpful
accelerate launch --config_file=configs/self_judge_accelerate.yaml train_self_judge.py --config=configs/hh_self_judge.yaml

# UltraFeedback
accelerate launch --config_file=configs/self_judge_accelerate.yaml train_self_judge.py --config=configs/uf_self_judge.yaml

# UltraFeedback with principles and rationale
accelerate launch --config_file=configs/self_judge_accelerate.yaml train_self_judge.py --config=configs/uf_self_judge_pr.yaml
```
From the resulting model, you can perform rejection sampling with `tournament_rejection_sampling` as:
```python
from self_judge import tournament_rejection_sampling
...
contexts = ["context_0", "context_1", ...]
responses = [["response_0_0", "response_0_1"],
             ["response_1_0", "response_1_1", "response_1_2", ...],
             ...
            ]
winners = tournament_rejection_sampling(
    model=model,
    tokenizer=tokenizer,
    judge_templates=judge_templates,
    token_id_a=token_id_a,
    token_id_b=token_id_b,
    contexts=contexts,
    responses=responses,
    max_length=max_length,
    judge_batch_size=judge_batch_size,
)
```
Note that the number of responses per context does not have to be the same, but the lengths of `contexts` and `responses` must match. The returned `winners` is a `List` of the best response for each given context, selected through a tournament based on the pairwise judgments.
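Each entry of `winners` therefore lines up with the corresponding entry of `contexts`, for example:

```python
# winners[i] is the tournament winner among responses[i] for contexts[i].
for context, best_response in zip(contexts, winners):
    print(f"Context: {context}")
    print(f"Best response: {best_response}\n")
```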
```bibtex
@article{lee2024aligning,
  title={Aligning Large Language Models by On-Policy Self-Judgment},
  author={Lee, Sangkyu and Kim, Sungdong and Yousefpour, Ashkan and Seo, Minjoon and Yoo, Kang Min and Yu, Youngjae},
  journal={arXiv preprint arXiv:2402.11253},
  year={2024}
}
```