This repo allows you to design new HumAn-centered LOss functions (HALOs) for aligning LLMs with offline human feedback at scale (read more in our technical report). It was used to create Archangel, the largest-ever suite of human-feedback-aligned LLMs, and has been tested at scales from 1B to 30B.
This repo draws from the excellently written DPO repo and has preserved many design choices from the original. Some of the key changes we introduced are:
- making data loading more modular, so that you can easily write your own dataloader
- making trainers more modular, so that each HALO has its own trainer subclass
- adding code for doing open-ended evaluation with GPT-4 as a judge
- supporting losses beyond SFT and DPO (including KTO, PPO (offline, off-policy variant), and SLiC)
Let's say we want to implement a new HALO called Kahneman-Tversky optimization (KTO). This is already implemented in this repo based on the details in our report, but let's pretend that it's not. What should we do?
First, create and activate the conda environment.
```bash
conda env create -f environment.yml
conda activate halos
```
Next, determine whether you need a new dataloader. KTO doesn't use preference pairs, just knowledge of whether outputs are desirable or undesirable. This means we use `dataloader.UnpairedPreferenceDataLoader`. However, that dataloader assumes that you're working with datasets that originally contain preference pairs, like Anthropic HH or SHP. If you wanted a custom dataloader, you would implement it in the same Python file by extending the base `DataLoader` class.
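The core idea behind an unpaired dataloader can be sketched in plain Python: each preference pair (prompt, chosen, rejected) is split into two independent examples, one labeled desirable and one undesirable. This is a minimal sketch, not the repo's `UnpairedPreferenceDataLoader`; the `UnpairedExample` type and the `unpair` function are hypothetical names used only for illustration, and tokenization and batching are omitted.

```python
from dataclasses import dataclass
from typing import Iterable, List, Tuple


# Hypothetical record type for illustration; not part of the repo's dataloader module
@dataclass(frozen=True)
class UnpairedExample:
    prompt: str
    completion: str
    desirable: bool  # True for a chosen output, False for a rejected one


def unpair(pairs: Iterable[Tuple[str, str, str]]) -> List[UnpairedExample]:
    """Split (prompt, chosen, rejected) preference pairs into unpaired examples.

    The chosen output becomes a desirable example and the rejected output an
    undesirable one; the link between the two outputs is discarded.
    """
    examples: List[UnpairedExample] = []
    for prompt, chosen, rejected in pairs:
        examples.append(UnpairedExample(prompt, chosen, desirable=True))
        examples.append(UnpairedExample(prompt, rejected, desirable=False))
    return examples


if __name__ == "__main__":
    data = [("How do I boil an egg?", "Simmer it for 9 minutes.", "Eggs cannot be boiled.")]
    for ex in unpair(data):
        print(ex.desirable, ex.completion)
```

A dataset that only records a thumbs-up or thumbs-down per output would skip the splitting step entirely and emit `UnpairedExample` records directly, which is why KTO does not need pairs in the first place.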
Write a trainer in `trainers.py`. This should subclass either `UnpairedPreferenceTrainer` or `PairedPreferenceTrainer`, depending on whether it uses pairs of preferences or not. If you need highly custom behavior that is not in either, you can subclass `BasicTrainer` directly.

KTO is simple to implement: we just subclass `trainers.UnpairedPreferenceTrainer` as `trainers.KTOTrainer` and override the loss function definition. KTO has one hyperparameter, beta, which we can access via `self.config.loss.beta`:

```python
from typing import Tuple

import torch

# UnpairedPreferenceTrainer is defined earlier in trainers.py


class KTOTrainer(UnpairedPreferenceTrainer):
    def loss(self,
             policy_chosen_logps: torch.FloatTensor,
             policy_rejected_logps: torch.FloatTensor,
             reference_chosen_logps: torch.FloatTensor,
             reference_rejected_logps: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Compute the Kahneman-Tversky loss for a batch of policy and reference model log probabilities.

        For each batch of n/2 chosen examples and n/2 rejected examples (belonging to n different inputs),
        calculate the loss as follows.

        If generation y ~ p_chosen, where x' ~ are the examples with rejected generations, we have the 'chosen' loss:
            L(x, y) := 1 - sigmoid(beta * (log p_policy(y|x) - log p_reference(y|x) - KL(p_policy(y_rejected|x') || p_reference(y_rejected|x'))))

        If generation y ~ p_rejected, where x' ~ are the examples with chosen generations, we have the 'rejected' loss:
            L(x, y) := 1 - sigmoid(beta * (KL(p_policy(y_chosen|x') || p_reference(y_chosen|x')) - [log p_policy(y|x) - log p_reference(y|x)]))
        """
        beta = self.config.loss.beta

        # Batch-level KL estimates: mean policy/reference log-ratio, clamped at 0 since KL is non-negative
        chosen_KL = (policy_chosen_logps - reference_chosen_logps).mean().clamp(min=0)
        rejected_KL = (policy_rejected_logps - reference_rejected_logps).mean().clamp(min=0)

        # Per-example log-ratios between the policy and the reference model
        chosen_logratios = policy_chosen_logps - reference_chosen_logps
        rejected_logratios = policy_rejected_logps - reference_rejected_logps

        # Chosen losses first, then rejected losses, concatenated along the batch dimension
        losses = torch.cat((1 - torch.sigmoid(beta * (chosen_logratios - rejected_KL)),
                            1 - torch.sigmoid(beta * (chosen_KL - rejected_logratios))), 0)

        # Implied rewards, detached so they carry no gradient
        chosen_rewards = beta * chosen_logratios.detach()
        rejected_rewards = beta * rejected_logratios.detach()

        return losses, chosen_rewards, rejected_rewards
```
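To see what the loss does numerically, here is a dependency-free sketch that mirrors `KTOTrainer.loss` on plain Python lists instead of tensors. It is for illustration only (the trainer itself works on batched `torch` tensors); the helper name `kto_losses` and the sample log-probabilities are made up.

```python
import math
from typing import List, Tuple


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


# Hypothetical helper mirroring KTOTrainer.loss; not part of the repo
def kto_losses(policy_chosen: List[float], policy_rejected: List[float],
               ref_chosen: List[float], ref_rejected: List[float],
               beta: float = 0.1) -> Tuple[List[float], List[float]]:
    """Per-example KTO losses for chosen and rejected outputs.

    Log-ratios are policy minus reference log-probs; each side's KL term is the
    batch-mean log-ratio of that side, clamped at zero.
    """
    chosen_ratios = [p - r for p, r in zip(policy_chosen, ref_chosen)]
    rejected_ratios = [p - r for p, r in zip(policy_rejected, ref_rejected)]
    chosen_kl = max(0.0, sum(chosen_ratios) / len(chosen_ratios))
    rejected_kl = max(0.0, sum(rejected_ratios) / len(rejected_ratios))

    # Desirable outputs: loss shrinks as their log-ratio rises above the rejected KL
    chosen_losses = [1 - sigmoid(beta * (lr - rejected_kl)) for lr in chosen_ratios]
    # Undesirable outputs: loss shrinks as their log-ratio falls below the chosen KL
    rejected_losses = [1 - sigmoid(beta * (chosen_kl - lr)) for lr in rejected_ratios]
    return chosen_losses, rejected_losses


if __name__ == "__main__":
    # Made-up sequence log-probabilities for two chosen and two rejected outputs
    chosen, rejected = kto_losses(
        policy_chosen=[-10.0, -12.0], policy_rejected=[-15.0, -11.0],
        ref_chosen=[-11.0, -12.5], ref_rejected=[-14.0, -11.0], beta=0.1)
    print([round(x, 4) for x in chosen], [round(x, 4) for x in rejected])
```

Two properties of the trainer's design show up here: every loss lies strictly between 0 and 1 because of the `1 - sigmoid` form, and the KL terms are shared across the batch, so each example's loss depends on the other examples it is batched with.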
Add a file to the `config/loss` folder specifying the details of the loss. Since the training command below selects the loss with `loss=kto`, Hydra will look for it at `config/loss/kto.yaml`:

```yaml
name: kto
beta: 0.1 # the temperature parameter for KTO; lower values mean we care less about the reference model
trainer: KTOTrainer # implemented in trainers.py
dataloader: UnpairedPreferenceDataLoader # already exists in dataloader.py
use_reference_model: true # true because the loss definition includes a reference model
```
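The `trainer` and `dataloader` fields are plain strings, so the training script has to turn them into classes at runtime. The sketch below shows one common way to do that, by looking the name up in a registry; whether `train.py` does exactly this is an assumption, and the stand-in classes, `REGISTRY`, and the `resolve` helper are hypothetical.

```python
from types import SimpleNamespace


# Hypothetical stand-ins for the real classes in trainers.py and dataloader.py
class KTOTrainer: ...
class UnpairedPreferenceDataLoader: ...


# Hypothetical registry mapping class names (as written in the YAML) to classes
REGISTRY = {cls.__name__: cls for cls in (KTOTrainer, UnpairedPreferenceDataLoader)}

# Mirrors config/loss/kto.yaml; SimpleNamespace gives the same attribute access as config.loss.beta
loss_config = SimpleNamespace(
    name="kto", beta=0.1, trainer="KTOTrainer",
    dataloader="UnpairedPreferenceDataLoader", use_reference_model=True)


def resolve(name: str) -> type:
    """Map a class name from the config to the class itself, failing loudly on typos."""
    try:
        return REGISTRY[name]
    except KeyError:
        raise ValueError(f"unknown class {name!r}; expected one of {sorted(REGISTRY)}") from None


if __name__ == "__main__":
    trainer_cls = resolve(loss_config.trainer)
    loader_cls = resolve(loss_config.dataloader)
    print(trainer_cls.__name__, loader_cls.__name__, loss_config.beta)
```

Resolving names up front means a misspelled trainer in the YAML fails immediately, before any model weights are loaded.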
Now we can start training a model! Let's train a Llama-7B model on the SHP, Anthropic HH, and Open Assistant datasets. Since the corresponding config for Llama-7B is `config/model/llama7b.yaml`, we run the following command with Hydra:

```bash
python train.py loss=kto model=llama7b datasets=[shp,hh,oasst] exp_name=kto_llama7b mode=train ++cache_dir=/data/models
```

This will align a Llama-7B model from scratch, i.e., starting from the base model rather than from a checkpoint trained with this repo. If we want to align a model that we've already finetuned with the HALOs repo, we can add something like `++model.load_from=/data/models/sft_llama7b/LATEST/policy.pt` to the end of the command.

That's it! Your model will be saved to `/data/models/kto_llama7b/LATEST/policy.pt`.
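Because runs differ only in their Hydra overrides, a two-stage pipeline (SFT first, then KTO on top of the SFT checkpoint via `++model.load_from`) is easy to script. This is a hypothetical helper, not part of the repo: it only builds the argument lists, and it assumes a loss config named `sft` exists and that checkpoints land at `$cache_dir/$exp_name/LATEST/policy.pt`, as in the paths above.

```python
import shlex
from typing import List, Optional

CACHE_DIR = "/data/models"


# Hypothetical helper, not part of the HALOs repo
def train_command(loss: str, model: str, datasets: List[str], exp_name: str,
                  load_from: Optional[str] = None) -> List[str]:
    """Build the argv for one train.py run using Hydra overrides."""
    cmd = ["python", "train.py", f"loss={loss}", f"model={model}",
           f"datasets=[{','.join(datasets)}]", f"exp_name={exp_name}",
           "mode=train", f"++cache_dir={CACHE_DIR}"]
    if load_from is not None:
        # Start from an existing policy checkpoint instead of the base model
        cmd.append(f"++model.load_from={load_from}")
    return cmd


def latest_policy(exp_name: str) -> str:
    """Where a finished run's policy weights are expected to be saved."""
    return f"{CACHE_DIR}/{exp_name}/LATEST/policy.pt"


if __name__ == "__main__":
    data = ["shp", "hh", "oasst"]
    sft = train_command("sft", "llama7b", data, "sft_llama7b")
    kto = train_command("kto", "llama7b", data, "kto_llama7b",
                        load_from=latest_policy("sft_llama7b"))
    for cmd in (sft, kto):
        # Pass `cmd` to subprocess.run(cmd, check=True) to actually launch the run
        print(shlex.join(cmd))
```

Passing an argument list rather than a single string to `subprocess.run` sidesteps shell globbing of `datasets=[...]`; `shlex.join` quotes that argument when printing for the same reason.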
Let's sample some generations from our newly trained model. The sampling configs are in either `config/config.yaml` or under `models/`. We can sample 512 generations in batches of 32 with the following command, which will create a JSON file at `samples/{exp_name}.json`:

```bash
python eval.py -c /data/models/kto_llama7b/config.yaml -m sample -n 512 -b 32
```
After setting `OPENAI_API_KEY`, we can evaluate our aligned model with GPT-4 using the following command, which compares the aligned model's generations to the human-chosen responses in the data:

```bash
python compare.py -f samples/kto_llama7b.json -mc 512 -bk chosen -ck policy -r result.jsonl
```
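The judgments in `result.jsonl` can then be reduced to a single win rate. The sketch below assumes, hypothetically, that each line is a JSON object whose `winner` field names the preferred side (for example `"policy"` or `"chosen"`); check the actual output of `compare.py` and adjust the field name before relying on it. The `tally` and `read_winners` helpers are illustrative, not part of the repo.

```python
import json
from collections import Counter
from typing import Dict, Iterable, List


# Hypothetical helper, not part of the repo
def tally(judgments: Iterable[str], candidate: str = "policy",
          baseline: str = "chosen") -> Dict[str, float]:
    """Count wins for each side and compute the candidate's win rate over decided comparisons."""
    counts = Counter(judgments)
    wins, losses = counts[candidate], counts[baseline]
    other = sum(counts.values()) - wins - losses  # ties, refusals, or malformed judgments
    decided = wins + losses
    return {"wins": wins, "losses": losses, "other": other,
            "win_rate": wins / decided if decided else float("nan")}


# Hypothetical reader; the "winner" field name is an assumption about result.jsonl
def read_winners(path: str, field: str = "winner") -> List:
    """Read one JSON object per line and pull out the winner field."""
    with open(path) as f:
        return [json.loads(line).get(field) for line in f if line.strip()]


if __name__ == "__main__":
    print(tally(read_winners("result.jsonl")))
```

Excluding undecided comparisons from the denominator is one convention; counting ties as half a win is another, so report which one you use.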
Do you support multi-node training?
No, the repo currently supports only single-node training; multi-node training will be added at some point in the future. Every model in the Archangel suite was trained with 8 x A100 GPUs on a single node.
How do I save intermediate checkpoints?
Set `intermediate_checkpoints` to `true` in `config/config.yaml`, or on the command line with `++intermediate_checkpoints=true`. Every `config.eval_every` steps, a checkpoint will be saved in the experiment directory (`$cache_dir/$exp_name`).
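To estimate how many intermediate checkpoints a run will produce, the schedule can be worked out ahead of time. This sketch assumes a checkpoint is written whenever the counter reaches a positive multiple of `eval_every`; whether the trainer counts optimizer steps or examples seen is an assumption to verify in `trainers.py`, and `checkpoint_steps` is a hypothetical helper.

```python
from typing import List


# Hypothetical helper illustrating the assumed checkpoint schedule
def checkpoint_steps(total_steps: int, eval_every: int) -> List[int]:
    """Counter values at which an intermediate checkpoint would be saved."""
    if eval_every <= 0:
        raise ValueError("eval_every must be positive")
    return list(range(eval_every, total_steps + 1, eval_every))


if __name__ == "__main__":
    steps = checkpoint_steps(total_steps=10_000, eval_every=2_000)
    print(len(steps), steps)
```

Multiplying the number of checkpoints by the size of one `policy.pt` gives a rough disk budget for `$cache_dir/$exp_name` before starting a long run.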
Where do I find all the Archangel models?
They are all on the Hugging Face Hub:
Model | PPO | DPO | KTO | SFT | SLIC | SFT+PPO | SFT+DPO | SFT+KTO |
---|---|---|---|---|---|---|---|---|
pythia1-4b | weights | weights | weights | weights | weights | weights | weights | weights |
pythia2-8b | weights | weights | weights | weights | weights | weights | weights | weights |
pythia6-9b | weights | weights | weights | weights | weights | weights | weights | weights |
pythia12-0b | weights | weights | weights | weights | weights | weights | weights | weights |
llama7b | weights | weights | weights | weights | weights | weights | weights | weights |
llama13b | weights | weights | weights | weights | weights | weights | weights | weights |
llama30b | weights | weights | weights | weights | weights | weights | weights | weights |
If you find this repo or the technical paper useful in your research, please feel free to cite our work:
```bibtex
@techreport{ethayarajh2023halos,
  author = {Ethayarajh, Kawin and Xu, Winnie and Jurafsky, Dan and Kiela, Douwe},
  title = {Human-Centered Loss Functions (HALOs)},
  institution = {Contextual AI},
  note = {https://github.com/ContextualAI/HALOs/blob/main/assets/report.pdf},
  year = {2023},
}
```