trlx

A repository for distributed training of language models with Reinforcement Learning from Human Feedback (RLHF).

Primary language: Python. License: MIT.

Transformer Reinforcement Learning X

trlX lets you fine-tune 🤗 Hugging Face language models based on GPT-2, GPT-J, GPT-Neo, and GPT-NeoX, with up to 20B parameters, using reinforcement learning driven either by a reward function you provide or by a reward-labeled dataset. Two algorithms are implemented: Proximal Policy Optimization (PPO) and Implicit Language Q-Learning (ILQL).
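For readers new to PPO, the heart of its policy update is the clipped surrogate objective. The sketch below is a minimal, framework-free illustration of that textbook loss over per-token log-probabilities. It is not trlX's implementation; the function name `ppo_clip_loss` and the default `clip_eps=0.2` are assumptions chosen for illustration.

```python
import math
from typing import List


def ppo_clip_loss(
    logp_new: List[float],
    logp_old: List[float],
    advantages: List[float],
    clip_eps: float = 0.2,  # assumed default, for illustration only
) -> float:
    """Textbook PPO clipped surrogate, negated so that lower is better.

    Hypothetical helper, not part of trlX.
    """
    if not (len(logp_new) == len(logp_old) == len(advantages)):
        raise ValueError("all inputs must have the same length")
    total = 0.0
    for new, old, adv in zip(logp_new, logp_old, advantages):
        # Probability ratio pi_new(a|s) / pi_old(a|s), computed from log-probs
        ratio = math.exp(new - old)
        # Keep the ratio inside [1 - clip_eps, 1 + clip_eps]
        clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
        # Take the pessimistic (smaller) of the unclipped and clipped terms
        total += min(ratio * adv, clipped * adv)
    return -total / len(advantages)
```

When the new and old policies agree, every ratio is 1 and the loss reduces to the negative mean advantage; the clip can only take effect once a ratio leaves the interval `[1 - clip_eps, 1 + clip_eps]`, which is what keeps each update close to the previous policy.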

You can read more about trlX in our documentation.

Installation

From Source

git clone https://github.com/CarperAI/trlx.git
cd trlx
pip install torch --extra-index-url https://download.pytorch.org/whl/cu113 # CUDA 11.3 build of PyTorch
pip install -e .

How to Train

You can train your model using a reward function or a reward-labeled dataset.

Using a reward function

import trlx
from transformers import AutoTokenizer

# reward each sample by how many times it mentions 'cats'
model = trlx.train('gpt2', reward_fn=lambda samples: [sample.count('cats') for sample in samples])

# model is a wrapper with some logit preprocessing
tokenizer = AutoTokenizer.from_pretrained('gpt2')
model.generate(**tokenizer('Q: Who rules the world? A:', return_tensors='pt'), do_sample=True)
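The reward function receives a batch of generated strings and must return one score per sample, in the same order. The lambda above can also be written as a named, typed function, which is easier to test on its own before training. The name `cat_reward` is hypothetical; it is a minimal sketch that mirrors the lambda's behavior.

```python
from typing import List


def cat_reward(samples: List[str]) -> List[float]:
    """Hypothetical stand-in for the lambda above.

    Scores each sample by how often the substring 'cats' occurs
    (case-sensitive, exactly like the lambda).
    """
    return [float(sample.count("cats")) for sample in samples]
```

It can then be passed in place of the lambda, e.g. `trlx.train('gpt2', reward_fn=cat_reward)`.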

Using a reward-labeled dataset

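import trlx
from transformers import AutoTokenizer

# Steer a model with a collection of rated samples:
# the first tuple holds the samples, the second their rewards (same order)
model = trlx.train('EleutherAI/gpt-j-6B', dataset=[('dolphins', 'geese'), (1.0, 100.0)])

# model is a wrapper with some logit preprocessing
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-j-6B')
model.generate(**tokenizer('Q: Who rules the world? A:', return_tensors='pt'), do_sample=True)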
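In the example above, `dataset=` is a two-element list: a tuple of samples followed by a tuple of their rewards. If your ratings are stored as (sample, reward) pairs instead, a small helper can regroup them into that layout. `to_reward_dataset` is hypothetical and not part of trlX, and the layout is inferred from the example rather than from a documented specification.

```python
from typing import Iterable, List, Tuple


def to_reward_dataset(rated: Iterable[Tuple[str, float]]) -> List[tuple]:
    """Hypothetical helper: regroup (sample, reward) pairs into
    [samples, rewards], matching the dataset= example above."""
    pairs = list(rated)  # materialize so we can iterate twice
    samples = tuple(text for text, _ in pairs)
    rewards = tuple(float(score) for _, score in pairs)
    return [samples, rewards]
```

For instance, `to_reward_dataset([('dolphins', 1.0), ('geese', 100.0)])` produces the same value as the literal passed to `trlx.train` above.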

Using 🤗 Accelerate to speed up training

Launch distributed training with 🤗 Accelerate. Only the DeepSpeed integration has been tested.

accelerate config
accelerate launch examples/simulacra.py

For more usage examples, see the examples directory.

Contributing

For development, check out these guidelines and read our docs.

Acknowledgements

Thanks to Leandro for starting the original trl.