Reproduction of the DDPO paper (RLHF for diffusion models)

RLHF for Diffusion Models

This is an implementation of the paper "Training Diffusion Models with Reinforcement Learning". It is meant as an educational codebase, with extensive comments explaining the code and only basic features. It currently implements only the LAION aesthetic classifier as a reward function, but more examples will be added soon.
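For orientation, here is a minimal sketch of the kind of clipped policy-gradient objective the paper describes, written in plain Python so it runs without PyTorch. It assumes the importance-sampling variant with PPO-style clipping and uses batch-normalized rewards as advantages. The function names, the `clip_range` default, and the use of plain lists are illustrative assumptions, not this repository's actual code.

```python
import math


def normalize_rewards(rewards, eps=1e-8):
    """Turn raw rewards (e.g. aesthetic scores) into zero-mean, unit-variance advantages."""
    mean = sum(rewards) / len(rewards)
    var = sum((r - mean) ** 2 for r in rewards) / len(rewards)
    std = math.sqrt(var)
    return [(r - mean) / (std + eps) for r in rewards]


def clipped_policy_loss(new_log_probs, old_log_probs, advantages, clip_range=0.2):
    """PPO-style clipped surrogate loss, averaged over samples.

    Each entry corresponds to one denoising step treated as an action:
    new_log_probs come from the model being trained, old_log_probs from the
    model that generated the samples. clip_range=0.2 is an illustrative value.
    """
    losses = []
    for new_lp, old_lp, adv in zip(new_log_probs, old_log_probs, advantages):
        ratio = math.exp(new_lp - old_lp)  # importance-sampling ratio
        clipped = min(max(ratio, 1.0 - clip_range), 1.0 + clip_range)
        # Take the more pessimistic (larger) of the two losses.
        losses.append(max(-adv * ratio, -adv * clipped))
    return sum(losses) / len(losses)
```

Minimizing this loss raises the likelihood of denoising steps that led to above-average rewards, while the clipping keeps each update close to the model that produced the samples.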

A tutorial blog post is coming soon.

This codebase is for educational purposes only; a separate codebase for scalable training is being developed here.

Installation

git clone https://github.com/tmabraham/ddpo-pytorch.git
cd ddpo-pytorch
pip install -r requirements.txt

Usage

It's as simple as running:

python main.py

To reduce memory usage (which you will likely need to do), pass the arguments --enable_attention_slicing, --enable_xformers_memory_efficient_attention, and --enable_grad_checkpointing.
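As a rough illustration of what such flags typically do, the sketch below parses the three switches and applies them to a pipeline object. It assumes they map onto the similarly named methods of a diffusers-style pipeline (`enable_attention_slicing`, `enable_xformers_memory_efficient_attention`, and `unet.enable_gradient_checkpointing`); this README does not confirm that mapping. The helpers `parse_memory_flags` and `apply_memory_flags` are hypothetical names, not part of main.py.

```python
import argparse


def parse_memory_flags(argv=None):
    """Parse only the three memory-saving switches named in this README."""
    parser = argparse.ArgumentParser(description="Memory-saving options (sketch)")
    parser.add_argument("--enable_attention_slicing", action="store_true")
    parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true")
    parser.add_argument("--enable_grad_checkpointing", action="store_true")
    return parser.parse_args(argv)


def apply_memory_flags(pipe, args):
    """Apply the selected options to `pipe`.

    Assumption: `pipe` behaves like a diffusers pipeline with a `unet` attribute.
    """
    if args.enable_attention_slicing:
        # Compute attention in slices to lower peak memory.
        pipe.enable_attention_slicing()
    if args.enable_xformers_memory_efficient_attention:
        # Requires the xformers package to be installed.
        pipe.enable_xformers_memory_efficient_attention()
    if args.enable_grad_checkpointing:
        # Recompute activations in the backward pass instead of storing them.
        pipe.unet.enable_gradient_checkpointing()
```

All three options trade some speed for lower peak memory, so it is reasonable to enable only as many as your GPU requires.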

Results

Original samples: image

After training for 50 epochs: image