
The codebase of our paper "Improving the Training of Rectified Flows", NeurIPS 2024

Primary language: Python. License: BSD 3-Clause Clear License (BSD-3-Clause-Clear).

2-Rectified Flow++

This is the codebase of our paper Improving the Training of Rectified Flows.

Setup

Tested environment: PyTorch >= 2.0.0, Linux.

You can install the required packages by running pip install -r requirements.txt.
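To confirm that an environment meets the tested PyTorch >= 2.0.0 requirement before launching a run, a small check like the hypothetical torch_version_ok helper below can be used. It is a sketch, not part of the repository. It compares only the numeric release part of the version string, so build tags such as +cu118 are ignored, and pre-releases such as 2.0.0a0 are treated as 2.0.0.

```python
import re
from typing import Tuple


def torch_version_ok(version: str, minimum: Tuple[int, int, int] = (2, 0, 0)) -> bool:
    """Return True if a PyTorch version string is at least `minimum`.

    Only the leading numeric release segment is compared, e.g. "2.1.0" in "2.1.0+cu118".
    """
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    # A missing patch component (e.g. "2.1") is treated as 0.
    release = tuple(int(part or 0) for part in match.groups())
    return release >= tuple(minimum)
```

For example: import torch, then print(torch_version_ok(torch.__version__)).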

Getting started with pre-trained 2-rectified flow++ models

We provide pre-trained 2-rectified flow++ models here.

Generation

To generate 50,000 samples from the pre-trained CIFAR-10 model using 2 GPUs, run:

python generate.py --gpu 0,1 --dir runs/test \
--solver euler --N 5 --sampler new --num_samples 50000 --batchsize 512 \
--ckpt CKPT_PATH --config configs_unet/cifar10_ve_aug.json

where CKPT_PATH is the path to the pre-trained model checkpoint. To evaluate FID, run:

torchrun --standalone --nproc_per_node=1 fid.py calc --images=runs/test/samples --ref=PATH_TO_cifar10-32x32.npz --num 50000
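For reference, fid.py calc reports the Fréchet Inception Distance: it fits a Gaussian to Inception-v3 features of the generated images (mean μ_g, covariance Σ_g) and compares it with the reference statistics passed via --ref (mean μ_r, covariance Σ_r). The standard definition is:

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\left( \Sigma_r + \Sigma_g - 2 \left( \Sigma_r \Sigma_g \right)^{1/2} \right)
```

Lower is better. Computing it over 50,000 samples (--num 50000) is the usual convention for CIFAR-10.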

You can download inception statistics from here.
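The --solver euler --N 5 flags in the generation command request 5 Euler steps through the learned ODE. The sketch below illustrates such a sampler; it is not the repository's implementation. It assumes the interpolation x_t = (1 − t) x + t z, with t = 1 at the noise z and t = 0 at the data x. It also assumes a uniform time grid, which --sampler new may replace with a different schedule. Finally, it assumes a hypothetical velocity(x, t) callable that stands in for the trained network. The toy field make_straight_velocity (also hypothetical) has perfectly straight trajectories, so any number of Euler steps, even one, lands on the data point. Straighter trajectories are exactly what make few-step sampling of rectified flows accurate.

```python
from typing import Callable, List

Vector = List[float]
VelocityFn = Callable[[Vector, float], Vector]


def euler_sample(velocity: VelocityFn, z: Vector, n_steps: int) -> Vector:
    """Integrate dx/dt = velocity(x, t) from t = 1 (noise) to t = 0 (data).

    Assumed convention: x_t = (1 - t) * x + t * z, so the ideal velocity is z - x.
    """
    if n_steps < 1:
        raise ValueError("n_steps must be positive")
    x = list(z)
    dt = 1.0 / n_steps
    for i in range(n_steps):
        t = 1.0 - i * dt  # uniform grid: 1, 1 - dt, ..., dt
        v = velocity(x, t)
        # Backward-in-time Euler step: x_{t - dt} = x_t - dt * v(x_t, t)
        x = [xi - dt * vi for xi, vi in zip(x, v)]
    return x


def make_straight_velocity(x_data: Vector, z: Vector) -> VelocityFn:
    """Hypothetical toy field with exactly straight paths: v = z - x_data everywhere."""
    direction = [zi - xi for zi, xi in zip(z, x_data)]
    return lambda x, t: direction


if __name__ == "__main__":
    x_data, z = [0.5, -1.0], [1.2, 0.3]
    v = make_straight_velocity(x_data, z)
    for n in (1, 5):
        print(n, euler_sample(v, z, n))  # both close to [0.5, -1.0], up to float rounding
```

With a trained model the trajectories are only approximately straight, so the sample quality depends on N.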

Image to image translation

Lion (ImageNet class 291) -> Tiger (ImageNet class 292):

python generate.py --gpu 0 --dir runs/test-inversion --solver euler --N 4 --N_decode 2 --batchsize 6 --ckpt imagenet-configF.pth --config configs_unet/imagenet64.json --action inversion \
--data_path imagenet-samples/n02129165/ --num_samples 6 --sampler new --label_inv 291 --label_rec 292
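Conceptually, the inversion command first integrates the ODE from data to noise under the source label (--label_inv) and then decodes from noise back to data under the target label (--label_rec). The sketch below assumes that --N sets the number of inversion steps and --N_decode the number of decoding steps. This is our reading of the flags, not something the README confirms. It also assumes the interpolation x_t = (1 − t) x + t z (t = 0 is data, t = 1 is noise), uniform Euler steps, and a hypothetical class-conditional velocity(x, t, label). The toy field only shifts points by a fixed per-class offset, so the result can be checked by hand.

```python
from typing import Callable, Dict, List

Vector = List[float]
CondVelocityFn = Callable[[Vector, float, int], Vector]


def euler(velocity: CondVelocityFn, x: Vector, label: int, t0: float, t1: float, n_steps: int) -> Vector:
    """Euler-integrate dx/dt = velocity(x, t, label) from t0 to t1 (t1 may be smaller than t0)."""
    dt = (t1 - t0) / n_steps
    t = t0
    for _ in range(n_steps):
        v = velocity(x, t, label)
        x = [xi + dt * vi for xi, vi in zip(x, v)]
        t += dt
    return x


def translate(velocity: CondVelocityFn, image: Vector, label_inv: int, label_rec: int,
              n_inv: int = 4, n_dec: int = 2) -> Vector:
    # 1) Inversion: data (t = 0) -> noise (t = 1), conditioned on the source class.
    z = euler(velocity, image, label_inv, 0.0, 1.0, n_inv)
    # 2) Decoding: noise (t = 1) -> data (t = 0), conditioned on the target class.
    return euler(velocity, z, label_rec, 1.0, 0.0, n_dec)


# Hypothetical toy field: each class moves points by a constant offset.
OFFSETS: Dict[int, Vector] = {291: [1.0, 0.0], 292: [0.0, 2.0]}


def toy_velocity(x: Vector, t: float, label: int) -> Vector:
    return OFFSETS[label]


if __name__ == "__main__":
    print(translate(toy_velocity, [0.5, 0.5], label_inv=291, label_rec=292))  # [1.5, -1.5]
```

The defaults n_inv=4 and n_dec=2 mirror --N 4 --N_decode 2 in the command above.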

Training

Generating synthetic pairs

First, download the pre-trained EDM checkpoints from here. Then, generate synthetic pairs using the pre-trained EDM models by running:

# Generate synthetic pairs for CIFAR-10

python generate_edm.py --gpu 0,1,2,3 --dir runs/cifar-pair --num_samples 1000000 --batchsize 2048 --config configs_unet/cifar10_ve_aug.json --ckpt edm_cifar_ve_uncond.pth --N 18 --ext zip

# Generate synthetic pairs for ImageNet 64x64

python generate.py --gpu 0,1,2,3 --dir runs/imagenet-pair --num_samples 5000000 --batchsize 512 --config configs_unet/imagenet64.json --ckpt edm_imagenet64_ve_cond.pth --N 40 --ext zip

This will generate samples.zip (data), z.zip (noise), and images_labels.csv (class labels) in the directory given by --dir.
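The --N 18 (CIFAR-10) and --N 40 (ImageNet 64x64) flags set how many teacher sampling steps are used to turn each noise z into its paired sample. Assuming the scripts follow EDM's default deterministic noise schedule (σ_min = 0.002, σ_max = 80, ρ = 7, from Karras et al., 2022), the noise levels for N steps can be computed as below. This sketches EDM's published schedule; it is not code taken from this repository.

```python
from typing import List


def edm_sigmas(n_steps: int, sigma_min: float = 0.002, sigma_max: float = 80.0,
               rho: float = 7.0) -> List[float]:
    """EDM noise levels sigma_0 > ... > sigma_{N-1}, followed by sigma_N = 0.

    sigma_i = (sigma_max^(1/rho) + i / (N - 1) * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho
    """
    if n_steps < 2:
        raise ValueError("n_steps must be at least 2")
    hi = sigma_max ** (1.0 / rho)
    lo = sigma_min ** (1.0 / rho)
    sigmas = [(hi + i / (n_steps - 1) * (lo - hi)) ** rho for i in range(n_steps)]
    return sigmas + [0.0]  # the final step goes all the way to sigma = 0
```

A large ρ packs the steps densely near σ_min, where fine image detail is resolved.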

Training 2-rectified flow++

Run:

export WORLD_SIZE=1
export NODE_RANK=0
export MASTER_ADDR=localhost

# Train CIFAR-10 using Config F
python train.py --gpu 0,1,2,3 --dir runs/test-cifar --warmup_steps 5000 --learning_rate 2e-4 --batchsize 512 --iterations 800001 --config_de configs_unet/cifar10_ve_aug.json \
--ema_decay 0.9999 \
--im_dir runs/cifar-pair/samples.zip --z_dir runs/cifar-pair/z.zip \
--ckpt edm_cifar_ve_uncond.pth --a 4 --loss_type lpips-huber --lpips_divt --port 12354 \
--compile

# Train ImageNet 64x64 using Config E
python train.py --gpu 0,1,2,3,4,5,6,7 --dir runs/test-imagenet --learning_rate 1e-4 --warmup_steps 2500 \
--batchsize 512 --effective_batchsize 2048 --iterations 700001 --ema_decay 0.9999 \
--config_de configs_unet/imagenet64.json \
--im_dir runs/imagenet-pair/samples.zip --z_dir runs/imagenet-pair/z.zip \
--ckpt edm_imagenet64_ve_cond.pth --a 4 --loss_type lpips-huber --port 12355 \
--loss_scaling 32768 --compile

# Train AFHQ using Config F
python train.py --gpu 0,1,2,3,4,5,6,7 --dir runs/test-afhq --learning_rate 2e-4 --warmup_steps 5000 --batchsize 256 --iterations 1000001 --config_de configs_unet/afhq64_ve_aug.json \
--ema_decay 0.9999 \
--im_dir runs/afhq-pair/samples.zip --z_dir runs/afhq-pair/z.zip \
--ckpt edm_afhq64_ve_cond.pth --a 4 --loss_type lpips-huber --port 12356 --compile
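The three export lines configure distributed launching. Because WORLD_SIZE=1 is used with 4 or 8 GPUs on one machine, it appears to count nodes rather than processes. The sketch below assumes that reading. It also assumes that --port plays the role of the rendezvous (master) port. Under those assumptions, it shows how each per-GPU process could derive its global rank. The helper dist_layout and its gpus_per_node argument are hypothetical.

```python
import os
from typing import Dict, Mapping, Optional


def dist_layout(local_rank: int, gpus_per_node: int,
                env: Optional[Mapping[str, str]] = None) -> Dict[str, object]:
    """Derive the global world size and rank, assuming WORLD_SIZE counts nodes (hypothetical)."""
    env = os.environ if env is None else env
    num_nodes = int(env.get("WORLD_SIZE", "1"))
    node_rank = int(env.get("NODE_RANK", "0"))
    master_addr = env.get("MASTER_ADDR", "localhost")
    if not 0 <= local_rank < gpus_per_node:
        raise ValueError("local_rank must be in [0, gpus_per_node)")
    if not 0 <= node_rank < num_nodes:
        raise ValueError("NODE_RANK must be in [0, WORLD_SIZE)")
    return {
        "master_addr": master_addr,
        "world_size": num_nodes * gpus_per_node,         # total number of processes
        "rank": node_rank * gpus_per_node + local_rank,  # unique id across all nodes
    }
```

With the single-node settings above and --gpu 0,1,2,3, the process on the third GPU would get rank 2 out of 4.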
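The --a 4 flag likely sets the concentration of the paper's U-shaped training-time distribution, which favors times near t = 0 and t = 1 over mid-range times. This README does not state the exact density that train.py uses. The sketch assumes p(t) proportional to cosh(a(t − 1/2)) on [0, 1], which is symmetric and U-shaped for a > 0. It has a closed-form inverse CDF, so times can be drawn by inverse-transform sampling. Both function names are hypothetical.

```python
import math
import random
from typing import List, Optional


def u_shaped_inverse_cdf(r: float, a: float = 4.0) -> float:
    """Map r in [0, 1] to t in [0, 1] for the assumed density p(t) proportional to cosh(a * (t - 1/2)).

    CDF: F(t) = (sinh(a * (t - 1/2)) + sinh(a / 2)) / (2 * sinh(a / 2)); this is its inverse.
    """
    if a <= 0:
        raise ValueError("a must be positive (a -> 0 approaches the uniform distribution)")
    t = 0.5 + math.asinh((2.0 * r - 1.0) * math.sinh(a / 2.0)) / a
    return min(1.0, max(0.0, t))  # guard against floating-point rounding at the ends


def sample_timesteps(n: int, a: float = 4.0, rng: Optional[random.Random] = None) -> List[float]:
    """Draw n training times by inverse-transform sampling."""
    rng = rng or random.Random()
    return [u_shaped_inverse_cdf(rng.random(), a) for _ in range(n)]
```

Under this assumed density with a = 4, only sinh(1) / sinh(2) ≈ 0.32 of the probability mass lies in [0.25, 0.75], compared with 0.5 under a uniform distribution.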
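All three training commands use --loss_type lpips-huber. The name suggests that the LPIPS perceptual distance is combined with a Huber-type distance. How the two are combined, and what --lpips_divt changes, is defined in train.py rather than in this README. The sketch therefore covers only a pseudo-Huber distance of the kind used in improved consistency training (Song & Dhariwal, 2023): d(x, y) = sqrt(‖x − y‖² + c²) − c. It behaves like a scaled squared error for small residuals and like the L2 norm for large ones. The default c = 0.00054·sqrt(dim) follows that paper and is an assumption for this repository.

```python
import math
from typing import Optional, Sequence


def pseudo_huber(x: Sequence[float], y: Sequence[float], c: Optional[float] = None) -> float:
    """sqrt(||x - y||^2 + c^2) - c, with c defaulting to 0.00054 * sqrt(dim) (assumed)."""
    if len(x) != len(y):
        raise ValueError("x and y must have the same length")
    if c is None:
        c = 0.00054 * math.sqrt(len(x))
    squared_dist = sum((xi - yi) ** 2 for xi, yi in zip(x, y))
    return math.sqrt(squared_dist + c * c) - c
```

For a residual r much smaller than c, the value is about r² / (2c); for r much larger than c, it is about r − c.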
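Each command also keeps an exponential moving average of the weights with --ema_decay 0.9999. The sketch below shows the standard update on plain Python lists, which stand in for the model's parameter tensors. Whether the repository adds a warm-up or bias correction to the decay is not stated here.

```python
from typing import List


def ema_update(ema: List[float], params: List[float], decay: float = 0.9999) -> List[float]:
    """Return decay * ema + (1 - decay) * params, elementwise."""
    if not 0.0 <= decay <= 1.0:
        raise ValueError("decay must be in [0, 1]")
    return [decay * e + (1.0 - decay) * p for e, p in zip(ema, params)]
```

With decay 0.9999, the average effectively spans roughly 1 / (1 − 0.9999) = 10,000 recent iterations.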
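The ImageNet command adds --loss_scaling 32768, which is presumably a static loss scale for mixed-precision training. Under that reading, the loss, and therefore every gradient, is multiplied by 32768 = 2^15 before the backward pass. The gradients are then divided by the same factor in full precision before the optimizer step, so small gradients do not underflow in float16. The sketch mimics a float16 round trip with Python's struct half-precision format 'e'. It illustrates the numerics only, not the repository's training loop.

```python
import struct


def to_fp16(value: float) -> float:
    """Round a Python float to IEEE 754 half precision and back."""
    return struct.unpack("<e", struct.pack("<e", value))[0]


def scaled_fp16_grad(grad: float, scale: float = 32768.0) -> float:
    """Store scale * grad in float16, then unscale in full precision."""
    return to_fp16(grad * scale) / scale


if __name__ == "__main__":
    g = 1e-8
    print(to_fp16(g))           # 0.0: below half the smallest float16 subnormal (about 5.96e-8)
    print(scaled_fp16_grad(g))  # close to 1e-8 after scaling and unscaling
```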
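The ImageNet command also sets --batchsize 512 --effective_batchsize 2048. This suggests gradient accumulation over 2048 / 512 = 4 micro-batches per optimizer step, assuming both values are counted the same way (for example, globally across GPUs). The sketch checks the identity that accumulation relies on, using a toy one-parameter least-squares loss. For a mean loss and equal-sized micro-batches, averaging the micro-batch gradients reproduces the full-batch gradient. Both function names are hypothetical.

```python
from typing import List, Tuple

Pair = Tuple[float, float]


def grad_mse(w: float, batch: List[Pair]) -> float:
    """d/dw of mean((w * x - y)^2) over the batch."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


def accumulated_grad(w: float, batch: List[Pair], micro_batch_size: int) -> float:
    """Sum of micro-batch gradients, each scaled by 1 / (number of micro-batches)."""
    if len(batch) % micro_batch_size != 0:
        raise ValueError("batch size must be a multiple of the micro-batch size")
    steps = len(batch) // micro_batch_size
    total = 0.0
    for k in range(steps):
        micro = batch[k * micro_batch_size:(k + 1) * micro_batch_size]
        # Equivalent to dividing each micro-batch loss by `steps` before backward().
        total += grad_mse(w, micro) / steps
    return total
```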

Citation

@article{lee2024improving,
  title={Improving the Training of Rectified Flows},
  author={Lee, Sangyun and Lin, Zinan and Fanti, Giulia},
  journal={arXiv preprint arXiv:2405.20320},
  year={2024}
}