
A repo for distributed training of language models with Reinforcement Learning from Human Feedback (RLHF).

Primary language: Python. License: MIT.


Transformer Reinforcement Learning X

trlX allows you to fine-tune 🤗 Hugging Face supported language models of up to 20B parameters (such as gpt2, gpt-j, and gpt-neox, as well as T5-based models, including google/t5-v1_1 and google/flan-t5) using reinforcement learning with either a provided reward function or a reward-labeled dataset. Two algorithms are implemented: Proximal Policy Optimization (PPO) and Implicit Language Q-Learning (ILQL).
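trlX's own loss code is not reproduced here. As a rough illustration of the two algorithm families, the following is a minimal pure-Python sketch of the textbook PPO clipped surrogate objective and of the expectile regression loss that IQL-style methods such as ILQL use to fit a value function. The function names and the defaults (clip_eps=0.2, tau=0.7) are illustrative assumptions, not part of trlX's API.

```python
import math


def ppo_clip_objective(logp_new, logp_old, advantages, clip_eps=0.2):
    """Mean PPO-clip surrogate objective (to be maximized) over a batch.

    Illustrative sketch only: inputs are per-action log-probabilities under
    the new and old policies plus advantage estimates, as plain floats.
    """
    total = 0.0
    for lp_new, lp_old, adv in zip(logp_new, logp_old, advantages):
        ratio = math.exp(lp_new - lp_old)  # pi_new(a|s) / pi_old(a|s)
        clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
        # Take the pessimistic (smaller) of the unclipped and clipped terms.
        total += min(ratio * adv, clipped * adv)
    return total / len(advantages)


def expectile_loss(diffs, tau=0.7):
    """Mean asymmetric squared loss |tau - 1(u < 0)| * u**2, with u = Q - V.

    Illustrative sketch only: with tau > 0.5, positive residuals are weighted
    more heavily, pushing V toward an upper expectile of Q.
    """
    total = 0.0
    for u in diffs:
        weight = tau if u >= 0 else 1.0 - tau
        total += weight * u * u
    return total / len(diffs)


# Ratio 2.0 with a positive advantage is clipped to 1.2.
print(ppo_clip_objective([math.log(2.0)], [0.0], [1.0]))
# Equal-sized residuals of opposite sign: (0.7 * 1 + 0.3 * 1) / 2.
print(expectile_loss([1.0, -1.0], tau=0.7))
```

Taking the minimum of the unclipped and clipped terms means the policy gains nothing from pushing the probability ratio outside [1 - clip_eps, 1 + clip_eps], which keeps each update close to the policy that generated the samples.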

You can read more about trlX in our documentation.

Want to collect human annotations for your RL application? Check out CHEESE!, our library for human-in-the-loop (HiTL) data collection.

Installation

git clone https://github.com/CarperAI/trlx.git
cd trlx
pip install torch --extra-index-url https://download.pytorch.org/whl/cu116 # CUDA 11.6 build of PyTorch
pip install -e .

Examples

For more usage, see the examples. You can also try the Colab notebooks below:

Description        | Link
Simulacra example  | Open In Colab

How to Train

You can train a model using a reward function or a reward-labeled dataset.

Using a reward function

import trlx
trainer = trlx.train('gpt2', reward_fn=lambda samples, **kwargs: [sample.count('cats') for sample in samples])
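The lambda above can also be written as a named function, which makes the expected contract easier to see: a list of generated samples in, one scalar reward per sample out. This is a sketch assuming `samples` arrives as a list of decoded strings; `cats_reward` is a hypothetical name, and `**kwargs` simply absorbs any extra keyword arguments the trainer may pass.

```python
def cats_reward(samples, **kwargs):
    """Return one scalar reward per generated sample (hypothetical example).

    Assumes `samples` is a list of decoded strings; extra keyword arguments
    the trainer may supply are accepted and ignored.
    """
    return [float(sample.count('cats')) for sample in samples]


rewards = cats_reward(['cats and more cats', 'dogs only'])
# rewards == [2.0, 0.0]
```

It can then be passed as `reward_fn=cats_reward` in place of the lambda. Returning floats rather than ints keeps the rewards ready for numeric processing.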

Using a reward-labeled dataset

import trlx
trainer = trlx.train('EleutherAI/gpt-j-6B', dataset=[('dolphins', 'geese'), (1.0, 100.0)])
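In the call above, `dataset` holds two parallel sequences: the text samples and their rewards, aligned by position. If your labels live as (text, reward) pairs, a small helper can reshape them into that layout. `to_reward_dataset` is a hypothetical helper written for illustration, assuming the two-sequence layout shown in the example is what the trainer expects.

```python
def to_reward_dataset(labeled):
    """Split (text, reward) pairs into [samples, rewards] (hypothetical helper).

    Mirrors the layout of the example above: one tuple of texts and one
    tuple of float rewards, aligned by position.
    """
    samples = tuple(text for text, _ in labeled)
    rewards = tuple(float(reward) for _, reward in labeled)
    return [samples, rewards]


dataset = to_reward_dataset([('dolphins', 1.0), ('geese', 100)])
# dataset == [('dolphins', 'geese'), (1.0, 100.0)]
```

Building both sequences from the same list of pairs guarantees they have equal length, so no sample is left without a reward.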

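Trainers provide a wrapper over their underlying model, so you can generate text directly (here `tokenizer` is the Hugging Face tokenizer for that model):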

trainer.generate(**tokenizer('Q: Who rules the world? A:', return_tensors='pt'), do_sample=True)

Save the resulting model as a Hugging Face pretrained language model (ready to upload to the Hub!):

trainer.save_pretrained('/path/to/output/folder/')

🩹 Warning: Only AcceleratePPOTrainer can currently save a Hugging Face transformers model to disk with save_pretrained. ILQL trainers require inference behavior that the available transformers architectures do not yet support.

Use 🤗 Accelerate to launch distributed training

accelerate config # choose DeepSpeed option
accelerate launch examples/simulacra.py

Use Ray Tune to launch a hyperparameter sweep

python -m trlx.sweep --config configs/sweeps/ppo_sweep.yml examples/ppo_sentiments.py
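Conceptually, a grid-style sweep runs one trial per combination of candidate hyperparameter values. The sketch below only illustrates that expansion in plain Python; it does not read trlX's sweep YAML format or call Ray Tune, and the function `expand_grid` and the parameter names in `space` are hypothetical.

```python
import itertools


def expand_grid(space):
    """Yield one config dict per combination of values (illustrative only)."""
    keys = sorted(space)  # fixed key order makes the trial order deterministic
    for values in itertools.product(*(space[k] for k in keys)):
        yield dict(zip(keys, values))


# Hypothetical search space; parameter names are illustrative only.
space = {'lr': [1e-5, 1e-4], 'batch_size': [32, 64]}
configs = list(expand_grid(space))
# len(configs) == 4
```

The number of trials grows multiplicatively with each added parameter, which is why sweep tools such as Ray Tune also offer sampling-based search strategies.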

Contributing

For development, check out these guidelines and read our docs.

Acknowledgements

Many thanks to Leandro von Werra for contributing trl, the library that initially inspired this repo.