This is the official implementation of the ICLR 2023 Spotlight paper "Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow"
by Xingchao Liu, Chengyue Gong, and Qiang Liu from UT Austin.
Rectified Flow is a novel method for learning transport maps between two distributions.
A reflow operation then iteratively straightens the ODE trajectories, eventually achieving one-step generation with higher diversity than GANs and better FID than fast diffusion models.
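The training objective behind this idea can be sketched in a few lines. This is a minimal, framework-free illustration and not the code in this repository: the names `interpolate` and `rectified_flow_loss` are hypothetical, and plain Python lists stand in for image tensors. It assumes the formulation from the paper, in which a point on the straight path is X_t = t X_1 + (1 - t) X_0 and the velocity model is regressed onto the direction X_1 - X_0.

```python
def interpolate(x0, x1, t):
    """Point on the straight line from x0 (t=0) to x1 (t=1)."""
    return [(1.0 - t) * a + t * b for a, b in zip(x0, x1)]


def rectified_flow_loss(velocity, x0, x1, t):
    """Mean squared error between the model velocity at X_t and x1 - x0."""
    xt = interpolate(x0, x1, t)
    target = [b - a for a, b in zip(x0, x1)]
    pred = velocity(xt, t)
    return sum((p - g) ** 2 for p, g in zip(pred, target)) / len(target)


# Toy check (hypothetical data): a velocity that already points along
# x1 - x0 has zero loss, while a zero velocity does not.
x0, x1 = [0.0, 0.0], [2.0, -1.0]
perfect = lambda xt, t: [2.0, -1.0]
zero_velocity = lambda xt, t: [0.0, 0.0]

loss_perfect = rectified_flow_loss(perfect, x0, x1, t=0.3)
loss_zero = rectified_flow_loss(zero_velocity, x0, x1, t=0.3)
print(loss_perfect, loss_zero)
```

In practice the loss would be averaged over many sampled pairs and random times t, with a neural network as the velocity model.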
An introductory website can be found here and the main idea is illustrated in the following figure:
Rectified Flow can be applied to both generative modeling and unsupervised domain transfer, as shown in the following figure:
For a more thorough treatment of the theoretical properties of Rectified Flow and its relationship to optimal transport, please refer to "Rectified Flow: A Marginal Preserving Approach to Optimal Transport".
We provide interactive tutorials with Colab notebooks to walk you through the whole pipeline of Rectified Flow. The tutorials come in two versions with different velocity models: a neural network version and a non-parametric version.
The code for image generation is in ./ImageGeneration. Run the following command first:
cd ./ImageGeneration
The following instructions have been tested on a Lambda Labs "1x A100 (40 GB SXM4)" instance, i.e. gpu_1x_a100_sxm4, and Ubuntu 20.04.2 (Driver Version: 495.29.05, CUDA Version: 11.5, CuDNN Version: 8.1.0).
We suggest using Anaconda.
Run the following commands to install the dependencies:
conda create -n rectflow python=3.8
conda activate rectflow
pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
pip install tensorflow==2.9.0 tensorflow-probability==0.12.2 tensorflow-gan==2.0.0 tensorflow-datasets==4.6.0
pip install jax==0.3.4 jaxlib==0.3.2
pip install numpy==1.21.6 ninja==1.11.1 matplotlib==3.7.0 ml_collections==0.1.1
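After installation, it can help to confirm that the environment matches the pinned versions. The following is a small sketch using only the standard library; the helper name `check_pins` is hypothetical, the distribution names are copied verbatim from the pip commands above, and it assumes each package reports its version exactly as pinned (for example, including the `+cu113` suffix for torch), which may not hold on every platform.

```python
from importlib import metadata

# Versions pinned by the pip commands above.
PINNED = {
    "torch": "1.11.0+cu113",
    "torchvision": "0.12.0+cu113",
    "tensorflow": "2.9.0",
    "tensorflow-probability": "0.12.2",
    "tensorflow-gan": "2.0.0",
    "tensorflow-datasets": "4.6.0",
    "jax": "0.3.4",
    "jaxlib": "0.3.2",
    "numpy": "1.21.6",
    "ninja": "1.11.1",
    "matplotlib": "3.7.0",
    "ml_collections": "0.1.1",
}


def check_pins(pinned):
    """Map each package to (pinned, installed); installed is None if missing."""
    report = {}
    for name, wanted in pinned.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None
        report[name] = (wanted, installed)
    return report


if __name__ == "__main__":
    for name, (wanted, installed) in check_pins(PINNED).items():
        status = "ok" if installed == wanted else "MISMATCH"
        print(f"{name}: pinned {wanted}, installed {installed} [{status}]")
```

A mismatch is not necessarily fatal, but the instructions above were only tested with the pinned versions.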
Run the following command to train a 1-Rectified Flow from scratch:
python ./main.py --config ./configs/rectified_flow/cifar10_rf_gaussian_ddpmpp.py --eval_folder eval --mode train --workdir ./logs/1_rectified_flow
- --config: The configuration file for this run.
- --eval_folder: The generated images and other files for each evaluation during training will be stored in ./workdir/eval_folder. In this command, it is ./logs/1_rectified_flow/eval/
- --mode: Mode selection for main.py. Select from train, eval and reflow.
We follow the evaluation pipeline as in Score SDE. You can download cifar10_stats.npz and save it to assets/stats/.
Then run
python ./main.py --config ./configs/rectified_flow/cifar10_rf_gaussian_ddpmpp.py --eval_folder eval --mode eval --workdir ./logs/1_rectified_flow --config.eval.enable_sampling --config.eval.batch_size 1024 --config.eval.num_samples 50000 --config.eval.begin_ckpt 2
which uses a batch size of 1024 to sample 50000 images, starting from checkpoint-2.pth, and computes the FID and IS.
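As a quick sanity check on these numbers, 50000 is not a multiple of 1024, so the last batch either overshoots or has to be trimmed. The sketch below only does the arithmetic; the helper name `sampling_rounds` is hypothetical, and how main.py actually rounds or discards the surplus samples is not stated here and is left as an assumption.

```python
import math


def sampling_rounds(num_samples, batch_size):
    """Number of full batches needed to cover num_samples."""
    return math.ceil(num_samples / batch_size)


rounds = sampling_rounds(50000, 1024)  # 49
generated = rounds * 1024              # 50176
surplus = generated - 50000            # 176
print(rounds, generated, surplus)
```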
To prepare data for reflow, run the following command:
python ./main.py --config ./configs/rectified_flow/cifar10_rf_gaussian_reflow_generate_data.py --eval_folder eval --mode reflow --workdir ./logs/tmp --config.reflow.last_flow_ckpt "./logs/1_rectified_flow/checkpoints/checkpoint-10.pth" --config.reflow.data_root "./assets/reflow_data/1_rectified_flow/" --config.reflow.total_number_of_samples 100000 --config.seed 0
- --config.reflow.last_flow_ckpt: The checkpoint for data generation.
- --config.reflow.data_root: The location where you would like the generated pairs to be saved. The $(Z_0, Z_1)$ pairs will be saved to ./data_root/seed/
- --config.reflow.total_number_of_samples: The total number of pairs you would like to generate.
- --config.seed: The random seed. Change the random seed to perform data generation in parallel.
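For CIFAR10, we suggest generating at least 1M pairs for reflow. With 100000 pairs per run as in the command above, this corresponds to 10 runs with different seeds.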
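One way to reach 1M pairs is to launch the data generation command once per seed. The following sketch only builds the command strings, using the flags shown above; the helper name `generation_commands` is hypothetical. Reusing `--workdir ./logs/tmp` for every run is carried over from the command above and may need adjusting if the runs execute concurrently.

```python
import shlex

# Arguments shared by every run, copied from the command above.
BASE = [
    "python", "./main.py",
    "--config", "./configs/rectified_flow/cifar10_rf_gaussian_reflow_generate_data.py",
    "--eval_folder", "eval",
    "--mode", "reflow",
    "--workdir", "./logs/tmp",
    "--config.reflow.last_flow_ckpt", "./logs/1_rectified_flow/checkpoints/checkpoint-10.pth",
    "--config.reflow.data_root", "./assets/reflow_data/1_rectified_flow/",
]


def generation_commands(total_pairs, pairs_per_run):
    """One command per seed so that the runs together produce total_pairs."""
    num_runs = -(-total_pairs // pairs_per_run)  # ceiling division
    commands = []
    for seed in range(num_runs):
        args = BASE + [
            "--config.reflow.total_number_of_samples", str(pairs_per_run),
            "--config.seed", str(seed),
        ]
        commands.append(" ".join(shlex.quote(a) for a in args))
    return commands


commands = generation_commands(total_pairs=1_000_000, pairs_per_run=100_000)
for line in commands:
    print(line)
```

Each run writes its pairs to its own seed subdirectory under the data root, so the outputs of different seeds do not overwrite each other.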
After the data pairs are generated, run the following command to train 2-Rectified Flow:
python ./main.py --config ./configs/rectified_flow/cifar10_rf_gaussian_reflow_train.py --eval_folder eval --mode reflow --workdir ./logs/2_rectified_flow --config.reflow.last_flow_ckpt "./logs/1_rectified_flow/checkpoints/checkpoint-10.pth" --config.reflow.data_root "./assets/reflow_data/1_rectified_flow/"
This command fine-tunes the checkpoint of 1-Rectified Flow with the data pairs generated in the last step, and saves the logs of 2-Rectified Flow to ./logs/2_rectified_flow.
2-Rectified Flow should perform much better than 1-Rectified Flow in one-step generation.
To evaluate with step N=1, run
python ./main.py --config ./configs/rectified_flow/cifar10_rf_gaussian_ddpmpp.py --eval_folder eval --mode eval --workdir ./logs/2_rectified_flow --config.sampling.use_ode_sampler "euler" --config.sampling.sample_N 1 --config.eval.enable_sampling --config.eval.batch_size 1024 --config.eval.num_samples 50000 --config.eval.begin_ckpt 2
where sample_N refers to the number of sampling steps.
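To see why the number of steps matters less for straighter flows, here is a toy Euler sampler. It is a sketch under stated assumptions, not the sampler in this repository: `euler_sample` and the two toy velocity fields are hypothetical, and `num_steps` plays the role that sample_N plays in the command above.

```python
import math


def euler_sample(velocity, z0, num_steps):
    """Integrate dZ/dt = velocity(Z, t) from t=0 to t=1 with uniform Euler steps."""
    z = list(z0)
    dt = 1.0 / num_steps
    for i in range(num_steps):
        t = i * dt
        v = velocity(z, t)
        z = [zi + dt * vi for zi, vi in zip(z, v)]
    return z


# A perfectly straight flow: constant velocity, so a single step is exact.
straight = lambda z, t: [2.0, -1.0]
# A curved flow: a quarter-turn rotation about the origin over t in [0, 1],
# whose exact endpoint from (1, 0) is (0, 1).
curved = lambda z, t: [-(math.pi / 2) * z[1], (math.pi / 2) * z[0]]

straight_1 = euler_sample(straight, [0.0, 0.0], num_steps=1)
straight_100 = euler_sample(straight, [0.0, 0.0], num_steps=100)
curved_1 = euler_sample(curved, [1.0, 0.0], num_steps=1)
curved_1000 = euler_sample(curved, [1.0, 0.0], num_steps=1000)
print(straight_1, straight_100, curved_1, curved_1000)
```

On the straight field, one step and one hundred steps land on the same point up to rounding. On the curved field, a single step misses the true endpoint by a wide margin, and many steps are needed to get close. Reflow aims to move the learned flow toward the first case.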
We can further improve the quality of 2-Rectified Flow in one-step generation with distillation.
Before distillation, we need new data pairs from 2-Rectified Flow. This can be simply done with
python ./main.py --config ./configs/rectified_flow/cifar10_rf_gaussian_reflow_generate_data.py --eval_folder eval --mode reflow --workdir ./logs/tmp --config.reflow.last_flow_ckpt "./logs/2_rectified_flow/checkpoints/checkpoint-10.pth" --config.reflow.data_root "./assets/reflow_data/2_rectified_flow/" --config.reflow.total_number_of_samples 100000 --config.seed 0
Then we can distill 2-Rectified Flow with
python ./main.py --config ./configs/rectified_flow/cifar10_rf_gaussian_reflow_distill_k=1.py --eval_folder eval --mode reflow --workdir ./logs/2_rectified_flow_k=1_distill --config.reflow.last_flow_ckpt "./logs/2_rectified_flow/checkpoints/checkpoint-10.pth" --config.reflow.data_root "./assets/reflow_data/2_rectified_flow/"
Similarly, we can distill 2-Rectified Flow for k-step generation (k>1) with
python ./main.py --config ./configs/rectified_flow/cifar10_rf_gaussian_reflow_distill_k_g_1.py --eval_folder eval --mode reflow --config.reflow.reflow_t_schedule 2 --workdir ./logs/2_rectified_flow_k=2_distill --config.reflow.last_flow_ckpt "./logs/2_rectified_flow/checkpoints/checkpoint-10.pth" --config.reflow.data_root "./assets/reflow_data/2_rectified_flow/"
Here, we use k=2 as an example. Change --config.reflow.reflow_t_schedule to accommodate different k.
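A k-step Euler sampler with uniform steps only evaluates the velocity at the k start times 0, 1/k, ..., (k-1)/k. One plausible reading, stated here as an assumption rather than a description of the repository's code, is that setting reflow_t_schedule to k restricts distillation to those times. The helper `k_step_time_grid` below is hypothetical and only illustrates the grid.

```python
def k_step_time_grid(k):
    """Start times of k uniform Euler steps on [0, 1): 0, 1/k, ..., (k-1)/k."""
    return [i / k for i in range(k)]


grid_1 = k_step_time_grid(1)  # [0.0]
grid_2 = k_step_time_grid(2)  # [0.0, 0.5]
grid_4 = k_step_time_grid(4)  # [0.0, 0.25, 0.5, 0.75]
print(grid_1, grid_2, grid_4)
```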
To save storage space and simplify the training pipeline in reflow, data pairs can also be generated by a teacher model during training, at the cost of longer training time. To perform reflow with online data generation, run
python ./main.py --config ./configs/rectified_flow/cifar10_rf_gaussian_reflow_train_online.py --eval_folder eval --mode reflow --workdir ./logs/2_rectified_flow --config.reflow.last_flow_ckpt "./logs/1_rectified_flow/checkpoints/checkpoint-10.pth"
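The trade-off between stored and online pairs can be pictured with a small generator. This is a conceptual sketch, not the repository's data pipeline: `online_pairs` and `toy_teacher` are hypothetical names, and the toy teacher stands in for solving the ODE of the previous flow starting from Gaussian noise.

```python
import random


def online_pairs(teacher_sample, dim, num_batches, batch_size, seed=0):
    """Yield batches of (z0, z1) pairs produced by the teacher on the fly."""
    rng = random.Random(seed)
    for _ in range(num_batches):
        batch = []
        for _ in range(batch_size):
            z0 = [rng.gauss(0.0, 1.0) for _ in range(dim)]  # fresh Gaussian noise
            z1 = teacher_sample(z0)  # teacher output paired with this noise
            batch.append((z0, z1))
        yield batch


# Toy teacher (an assumption for illustration only).
toy_teacher = lambda z0: [2.0 * v + 1.0 for v in z0]

batches = list(online_pairs(toy_teacher, dim=3, num_batches=2, batch_size=4))
print(len(batches), len(batches[0]))
```

Nothing is written to disk, but every training batch pays for a teacher call, which is where the longer training time comes from.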
Then, to distill 1-Rectified Flow with online data generation, run
python ./main.py --config ./configs/rectified_flow/cifar10_rf_gaussian_reflow_distill_k=1_online.py --eval_folder eval --mode reflow --workdir ./logs/1_rectified_flow_k=1_distill --config.reflow.last_flow_ckpt "./logs/1_rectified_flow/checkpoints/checkpoint-10.pth"
We provide code and pre-trained checkpoints for high-resolution generation on four
We use CelebA-HQ as an example. To train 1-rectified flow on CelebA-HQ, run
python ./main.py --config ./configs/rectified_flow/celeba_hq_pytorch_rf_gaussian.py --eval_folder eval --mode train --workdir ./logs/celebahq
To sample images from a pre-trained rectified flow, run
python ./main.py --config ./configs/rectified_flow/celeba_hq_pytorch_rf_gaussian.py --eval_folder eval --mode eval --workdir ./logs/celebahq --config.eval.enable_figures_only --config.eval.begin_ckpt 10 --config.eval.end_ckpt 10 --config.training.data_dir YOUR_DATA_DIR
The images will be stored in ./logs/celebahq/eval/ckpt/figs.
As an example, to use pre-trained checkpoints, download the checkpoint_8.pth from CIFAR10 1-Rectified Flow, put it in ./logs/1_rectified_flow/checkpoints/, then run
python ./main.py --config ./configs/rectified_flow/cifar10_rf_gaussian_ddpmpp.py --eval_folder eval --mode eval --workdir ./logs/1_rectified_flow --config.eval.enable_sampling --config.eval.batch_size 1024 --config.eval.num_samples 50000 --config.eval.begin_ckpt 8
The pre-trained checkpoints are listed here:
If you use the code or our work is related to yours, please cite us:
@article{liu2022flow,
title={Flow straight and fast: Learning to generate and transfer data with rectified flow},
author={Liu, Xingchao and Gong, Chengyue and Liu, Qiang},
journal={arXiv preprint arXiv:2209.03003},
year={2022}
}
A large portion of this codebase is built upon Score SDE.