This repo contains a reference implementation of the TDPO algorithm for training language models from preference data, as described in the paper Token-level Direct Preference Optimization (ICML 2024). Our implementation is based on DPO, and follows the same usage guidelines.
The TDPO pipeline has two stages:
- Run supervised fine-tuning (SFT) on the dataset(s) of interest. Generally, $(x, y_w)$ from the preference dataset is used directly as the supervised fine-tuning target.
- Run preference learning on the model from step 1, using preference data (ideally from the same distribution as the SFT examples). The dataset generally takes the form $\mathcal{D} = \{(x, y_w, y_l)_i\}_{i=1}^N$, where $x$ is the prompt and $y_w$ and $y_l$ denote the preferred and dispreferred completions (see the example below).
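Concretely, a single preference example is just a prompt paired with a preferred and a dispreferred completion. A minimal illustration (the field names here are illustrative, not necessarily the exact format used by `preference_datasets.py`):

```python
# One preference example (x, y_w, y_l):
#   x   -> "prompt"
#   y_w -> "chosen"   (preferred completion)
#   y_l -> "rejected" (dispreferred completion)
preference_example = {
    "prompt": "Explain why the sky appears blue.",
    "chosen": "Shorter wavelengths of sunlight are scattered more strongly by the atmosphere ...",
    "rejected": "Because the sky reflects the color of the ocean.",
}
```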
During training, we generally train for a single epoch in the SFT stage, while in the RL fine-tuning stage we run multiple epochs (e.g., three) to further improve the performance of our algorithm.
The files in this repo are:

- `train.py`: the main entry point for training (either SFT or TDPO preference-based training)
- `trainers.py`: the trainer classes (e.g., implementing the training loop as well as the multi-GPU logic)
- `utils.py`: some convenience functions used by multiple other files
- `preference_datasets.py`: dataset processing logic for both SFT and TDPO preference-based training; this is where you'll need to make some additions to train on your own data
The code here supports any causal HuggingFace model; take a look at our examples in `config/model` to add your own. Adding your own datasets is also easy; see the README section on adding datasets.
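As a rough sketch of what adding a dataset involves (the function name, signature, and return structure below are hypothetical; mirror the existing loaders in `preference_datasets.py` for the exact format the trainers expect):

```python
from typing import Dict, List, Tuple

def get_my_dataset(split: str) -> Dict[str, Dict]:
    """Hypothetical loader for a custom preference dataset.

    Maps each prompt to its candidate responses, the preference pairs
    (indices of the preferred and dispreferred responses), and the SFT target.
    """
    raw = [  # in practice, load this from disk or the HuggingFace Hub for the given split
        {
            "prompt": "Explain why the sky appears blue.",
            "chosen": "Shorter wavelengths of sunlight are scattered more strongly ...",
            "rejected": "Because the sky reflects the color of the ocean.",
        },
    ]

    data: Dict[str, Dict] = {}
    for ex in raw:
        responses: List[str] = [ex["chosen"], ex["rejected"]]
        pairs: List[Tuple[int, int]] = [(0, 1)]  # (preferred index, dispreferred index)
        data[ex["prompt"]] = {
            "responses": responses,
            "pairs": pairs,
            "sft_target": ex["chosen"],  # (x, y_w) is used as the SFT target
        }
    return data
```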
Let's work through a complete example training the Pythia 2.8B model on the Anthropic-HH dataset. First, set up the environment and install the dependencies:
python3 -m venv env
source env/bin/activate
pip install -r requirements.txt
Then, run SFT:
python -u train.py model=pythia28 datasets=[hh] loss=sft exp_name=anthropic_tdpo_pythia28 gradient_accumulation_steps=2 batch_size=64 eval_batch_size=32 trainer=FSDPTrainer sample_during_eval=false model.fsdp_policy_mp=bfloat16
For running TDPO2, we recommend the following command:
python -u train.py model=pythia28 datasets=[hh] loss=tdpo loss.alpha=0.5 loss.beta=0.1 exp_name=anthropic_tdpo_pythia28 gradient_accumulation_steps=2 batch_size=64 eval_batch_size=32 trainer=FSDPTrainer sample_during_eval=false model.fsdp_policy_mp=bfloat16 model.archive=/path/to/archive/from/sft/LATEST/policy.pt
To run TDPO1, we only need to pass the additional parameter `loss.if_tdpo2=false`:
python -u train.py model=pythia28 datasets=[hh] loss=tdpo loss.beta=0.1 loss.if_tdpo2=false exp_name=anthropic_tdpo_pythia28 gradient_accumulation_steps=2 batch_size=64 eval_batch_size=32 trainer=FSDPTrainer sample_during_eval=false model.fsdp_policy_mp=bfloat16 model.archive=/path/to/archive/from/sft/LATEST/policy.pt
When the learning rate is low, we recommend the TDPO1 algorithm; conversely, for higher learning rates, the TDPO2 algorithm is preferable.
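For reference, here is a simplified sketch of how the two variants differ, based on our reading of the paper (the authoritative implementation is in `trainers.py`; the function and tensor names below are assumptions). Both start from the DPO-style implicit reward margin and subtract a token-level (sequential) KL margin; TDPO2 additionally scales that margin by `loss.alpha` and stops the gradient through the chosen-response KL term:

```python
import torch.nn.functional as F

def tdpo_loss_sketch(chosen_logps, rejected_logps,
                     ref_chosen_logps, ref_rejected_logps,
                     chosen_seq_kl, rejected_seq_kl,
                     beta=0.1, alpha=0.5, if_tdpo2=True):
    """Illustrative sketch of the TDPO objective (not the repo's exact code).

    chosen_logps / rejected_logps: sequence log-probs log pi_theta(y|x) under the policy.
    ref_chosen_logps / ref_rejected_logps: the same quantities under the frozen reference model.
    chosen_seq_kl / rejected_seq_kl: sequential KL, sum_t KL(pi_ref(.|x, y<t) || pi_theta(.|x, y<t)).
    """
    # DPO-style implicit reward margin between preferred and dispreferred completions.
    reward_margin = (chosen_logps - ref_chosen_logps) - (rejected_logps - ref_rejected_logps)

    if if_tdpo2:
        # TDPO2: scale the token-level KL margin by alpha and stop the gradient
        # through the chosen-response KL term.
        kl_margin = alpha * (rejected_seq_kl - chosen_seq_kl.detach())
    else:
        # TDPO1: use the raw token-level KL margin.
        kl_margin = rejected_seq_kl - chosen_seq_kl

    return -F.logsigmoid(beta * (reward_margin - kl_margin)).mean()
```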
We have included the training curves from wandb here. We also provide comparison results with DPO on the IMDb experiment, as shown below.
For more experimental details and information, please refer to our paper.
Many thanks to the contributors of DPO for their valuable contributions to the RLHF community. For more detailed information, please refer to the DPO repository.
If you find TDPO or this repository useful in your research, please cite our paper using the following BibTeX entry:
@misc{zeng2024tokenlevel,
      title={Token-level Direct Preference Optimization},
      author={Yongcheng Zeng and Guoqing Liu and Weiyu Ma and Ning Yang and Haifeng Zhang and Jun Wang},
      year={2024},
      eprint={2404.11999},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}