/diffusion_restart_sampling

Code for NeurIPS 2023 paper "Restart Sampling for Improving Generative Processes"


Restart sampling

Restart Sampling for Improving Generative Processes

by Yilun Xu, Mingyang Deng, Xiang Cheng, Yonglong Tian, Ziming Liu and Tommi S. Jaakkola

Generative processes that involve solving differential equations, such as diffusion models, frequently necessitate balancing speed and quality. ODE-based samplers are fast but plateau in performance while SDE-based samplers deliver higher sample quality at the cost of increased sampling time. We attribute this difference to sampling errors: ODE-samplers involve smaller discretization errors while stochasticity in SDE contracts accumulated errors. Based on these findings, we propose a novel sampling algorithm called Restart in order to better balance discretization errors and contraction.

Empirically, Restart sampler surpasses previous diffusion SDE and ODE samplers in both speed and accuracy. Restart not only outperforms the previous best SDE results, but also accelerates the sampling speed by 10-fold / 2-fold on CIFAR-10 / ImageNet $64{\times} 64$. In addition, it attains significantly better sample quality than ODE samplers within comparable sampling times. Moreover, Restart better balances text-image alignment/visual quality versus diversity than previous samplers in the large-scale text-to-image Stable Diffusion model pre-trained on LAION $512{\times} 512$.
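As a rough illustration of the idea, the sketch below alternates a forward noise injection with a deterministic backward ODE solve on a 1-D toy problem. It is not the repository's implementation: it assumes the EDM parameterization $\sigma(t)=t$, swaps the Heun solver for plain Euler steps on a uniform time grid, and replaces the trained network with a hypothetical closed-form denoiser for data drawn from $\mathcal{N}(0, 1)$. All function names are illustrative.

```python
import math
import random


def toy_denoiser(x, t, data_std=1.0):
    """Posterior mean E[x0 | x_t] when x0 ~ N(0, data_std^2) and x_t = x0 + t * noise."""
    return x * data_std**2 / (data_std**2 + t**2)


def backward_ode(x, t_from, t_to, n_steps):
    """Euler steps for dx/dt = (x - D(x, t)) / t, going from t_from down to t_to."""
    ts = [t_from + (t_to - t_from) * i / n_steps for i in range(n_steps + 1)]
    for t_cur, t_next in zip(ts, ts[1:]):
        slope = (x - toy_denoiser(x, t_cur)) / t_cur
        x = x + (t_next - t_cur) * slope
    return x


def restart_interval(x, n_restart, k, t_min, t_max, rng):
    """Repeat k times: jump from t_min back up to t_max with fresh noise, then solve down again."""
    if n_restart < 2 or not 0 < t_min < t_max:
        raise ValueError("need n_restart >= 2 and 0 < t_min < t_max")
    for _ in range(k):
        # Forward step: the added noise has variance t_max^2 - t_min^2.
        x = x + math.sqrt(t_max**2 - t_min**2) * rng.gauss(0.0, 1.0)
        # Backward step: deterministic ODE over n_restart time points.
        x = backward_ode(x, t_max, t_min, n_restart - 1)
    return x


rng = random.Random(0)
x_at_t_min = 0.5  # pretend the main backward pass has already reached t_min
x_refined = restart_interval(x_at_t_min, n_restart=3, k=2, t_min=0.06, t_max=0.30, rng=rng)
print(x_refined)
```

The fresh noise in each forward jump is what lets accumulated error contract, while the backward pass stays deterministic so its discretization error remains small.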


Results on Stable Diffusion v1.5:

schematic


Standard Benchmarks (CIFAR-10, ImageNet-64)

The working directory for standard benchmarks is under ./benchmarks

1. Preparing datasets and checkpoints

CIFAR-10: Download the CIFAR-10 python version and convert to ZIP archive:

python dataset_tool.py --source=downloads/cifar10/cifar-10-python.tar.gz \
    --dest=datasets/cifar10-32x32.zip
python fid.py ref --data=datasets/cifar10-32x32.zip --dest=fid-refs/cifar10-32x32.npz

ImageNet: Download the ImageNet Object Localization Challenge and convert to ZIP archive at 64x64 resolution:

python dataset_tool.py --source=downloads/imagenet/ILSVRC/Data/CLS-LOC/train \
    --dest=datasets/imagenet-64x64.zip --resolution=64x64 --transform=center-crop
python fid.py ref --data=datasets/imagenet-64x64.zip --dest=fid-refs/imagenet-64x64.npz

Alternatively, you can download the precomputed FID statistics: CIFAR-10-FID and ImageNet-$64\times 64$-FID.

Please download the CIFAR-10 or ImageNet checkpoints from EDM repo or PFGM++ repo. For example,

| Dataset | Method | Path |
| --- | --- | --- |
| CIFAR-10 | VP unconditional | path |
| CIFAR-10 | EDM unconditional | path |
| CIFAR-10 | PFGM++ ($D=2048$) unconditional | path |
| ImageNet $64\times 64$ | EDM conditional | path |

2. Generate

Generating a large number of images can be time-consuming; the workload can be distributed across multiple GPUs by launching the command below with torchrun. Before generation, please make sure the checkpoint is downloaded into the ./benchmarks/imgs folder.

torchrun --standalone --nproc_per_node=8 generate_restart.py --outdir=./imgs \
--restart_info='{restart_config}' --S_min=0.01 --S_max=1 --S_noise 1.003 \
--steps={steps} --seeds=00000-49999 --name={name} (--pfgmpp=1) (--aug_dim={D})


restart_config: configuration for Restart (details below)
steps: number of steps in the main backward process, default=18
name: name of experiments (for FID evaluation)
pfgmpp: flag for using PFGM++
D: augmented dimension in PFGM++
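When launching many runs, it can help to assemble this command programmatically. The sketch below only builds the argument list from the flags shown above and does not execute anything. The function name, its defaults, the experiment name `cifar10_vp`, and the rule that `--aug_dim` must accompany `--pfgmpp=1` are assumptions made for illustration.

```python
import json
import shlex


def build_generate_command(restart_info, steps=18, name="exp", nproc=8,
                           seeds="00000-49999", pfgmpp=False, aug_dim=None):
    """Return the torchrun argument list for generate_restart.py (does not run it)."""
    cmd = [
        "torchrun", "--standalone", f"--nproc_per_node={nproc}",
        "generate_restart.py", "--outdir=./imgs",
        f"--restart_info={json.dumps(restart_info)}",
        "--S_min=0.01", "--S_max=1", "--S_noise", "1.003",
        f"--steps={steps}", f"--seeds={seeds}", f"--name={name}",
    ]
    if pfgmpp:
        # Assumption of this sketch: the PFGM++ flag is always paired with D.
        if aug_dim is None:
            raise ValueError("aug_dim (D) is required when pfgmpp=True")
        cmd += ["--pfgmpp=1", f"--aug_dim={aug_dim}"]
    return cmd


cmd = build_generate_command({"0": [3, 2, 0.06, 0.30]}, steps=18, name="cifar10_vp")
print(shlex.join(cmd))  # shell-quoted form, ready to paste into a terminal
```

Passing the list to `subprocess.run(cmd)` avoids shell-quoting problems with the JSON in `--restart_info`.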

The restart_info is in the format of $\lbrace i: [N_{\textrm{Restart},i}, K_i, t_{\textrm{min}, i}, t_{\textrm{max}, i}] \rbrace_{i=0}^{l-1}$, such as {"0": [3, 2, 0.06, 0.30]}. Please refer to Table 3 (CIFAR-10) and Table 5 (ImageNet-64) of the paper for detailed configurations. For example, for conditional EDM on ImageNet-64 with NFE=203 and FID=1.41, the command line is:

torchrun --standalone --nproc_per_node=8 generate_restart.py --outdir=./imgs \
--restart_info='{"0": [4, 1, 19.35, 40.79], "1": [4, 1, 1.09, 1.92], "2": [4, 5, 0.59, 1.09], "3": [4, 5, 0.30, 0.59], "4": [6, 6, 0.06, 0.30]}' --S_min=0.01 --S_max=1 --S_noise 1.003 \
--steps=36 --seeds=00000-49999 --name=imagenet_edm
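To see how a configuration relates to the number of function evaluations (NFE), the sketch below validates a restart_info dictionary and estimates its cost. The accounting is an assumption rather than something stated in this README: it supposes a second-order Heun solver in which the main pass costs $2 \cdot \textrm{steps} - 1$ evaluations and each Restart iteration over $N_{\textrm{Restart},i}$ time points costs $2(N_{\textrm{Restart},i} - 1)$. Under that assumption it reproduces the NFE=203 quoted for the command above. The function name is illustrative.

```python
import json


def estimate_nfe(steps, restart_info):
    """Estimated NFE under the Heun-solver accounting assumed in the lead-in."""
    nfe = 2 * steps - 1  # main backward pass; the final step skips the correction
    for n_restart, k, t_min, t_max in restart_info.values():
        if not 0 < t_min < t_max:
            raise ValueError(f"expected 0 < t_min < t_max, got {t_min}, {t_max}")
        nfe += k * 2 * (n_restart - 1)  # K repetitions of (N_restart - 1) Heun steps
    return nfe


restart_info = json.loads(
    '{"0": [4, 1, 19.35, 40.79], "1": [4, 1, 1.09, 1.92], "2": [4, 5, 0.59, 1.09], '
    '"3": [4, 5, 0.30, 0.59], "4": [6, 6, 0.06, 0.30]}'
)
print(estimate_nfe(36, restart_info))  # 203
```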

We also provide extensive Restart configurations in params_cifar10_vp.txt and params_imagenet_edm.txt, corresponding to Table 3 (CIFAR-10) and Table 5 (ImageNet-64) respectively. Each line in these text files has the form $N_{\textrm{main}} \quad \lbrace i: [N_{\textrm{Restart},i}, K_i, t_{\textrm{min}, i}, t_{\textrm{max}, i}]\rbrace_{i=0}^{l-1}$. To sweep the Restart configurations in the text files, please run

python3 hyperparams.py --dataset {dataset} --method {method}

dataset: cifar10 | imagenet
method: edm | vp | pfgmpp

The above sweeping will reproduce the results in the following figure (Fig 3 in the paper):

schematic
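Each line of the parameter files pairs $N_{\textrm{main}}$ with a Restart configuration. A minimal sketch of reading one such line follows, assuming the two fields are separated by whitespace and the configuration is valid JSON with double-quoted keys; the actual files and hyperparams.py may differ. The sample line and the function name are hypothetical.

```python
import json


def parse_config_line(line):
    """Split 'N_main {...}' into (int N_main, dict restart_info)."""
    n_main_text, restart_text = line.strip().split(maxsplit=1)
    restart_info = json.loads(restart_text)
    for key, entry in restart_info.items():
        if len(entry) != 4:
            raise ValueError(f"entry {key!r} should be [N_restart, K, t_min, t_max]")
    return int(n_main_text), restart_info


# Hypothetical sample line built from the default steps and the example configuration.
n_main, restart_info = parse_config_line('18 {"0": [3, 2, 0.06, 0.30]}')
print(n_main, restart_info)
```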

3. Evaluation

For FID evaluation, please run:

python fid.py ./imgs/imgs_{name} stats_path

name: name of experiments (specified in generation command line)
stats_path: path to FID statistics, such as ./cifar10-32x32.npz or ./imagenet-64x64.npz
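For reference, FID is the Fréchet distance between two Gaussians fitted to Inception features. The symbols below are not from this README: $(\mu_g, \Sigma_g)$ denote the feature mean and covariance of the generated images, and $(\mu_r, \Sigma_r)$ those of the reference dataset, which is presumably what the statistics file at stats_path stores.

```latex
\mathrm{FID} = \lVert \mu_g - \mu_r \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_g + \Sigma_r - 2 \left( \Sigma_g \Sigma_r \right)^{1/2} \right)
```

Lower values indicate that the generated feature distribution is closer to the reference one.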

Stable Diffusion

TODO: merge into the diffuser repo.

💡 Because we build on the huggingface/diffusers repo, we keep its structure and files intact. The Restart sampler is applied only in the file /diffuser/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py. For the exact location, see Restart code block in Stable diffusion.

The working directory for Stable Diffusion is under ./diffuser


1. Data processing:

  • Step 1: Follow the instructions at the top of data_process.py

  • Step 2: Run python3 data_process.py to randomly sample 5K image-text pairs from the COCO validation set.

  • Step 3: Calculate FID statistics: python fid_edm.py ref --data=path-to-coco-subset --dest=./coco.npz
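The random-subset step can be sketched as follows. This is not the code in data_process.py: the stand-in annotations, the function names, and the CSV column names are all hypothetical, and the real script reads image-text pairs from the COCO validation set.

```python
import csv
import io
import random


def sample_pairs(pairs, k, seed=0):
    """Draw k distinct (file_name, caption) pairs uniformly at random, reproducibly."""
    return random.Random(seed).sample(pairs, k)


def write_subset_csv(pairs, handle):
    """Write the sampled pairs as CSV rows with a header (column names are hypothetical)."""
    writer = csv.writer(handle)
    writer.writerow(["file_name", "caption"])
    writer.writerows(pairs)


# Stand-in annotations; the real script would read them from the COCO validation set.
all_pairs = [(f"img_{i:05d}.jpg", f"caption {i}") for i in range(20)]
subset = sample_pairs(all_pairs, k=5)
buffer = io.StringIO()
write_subset_csv(subset, buffer)
print(buffer.getvalue())
```

Fixing the seed keeps the subset identical across runs, so FID statistics computed once remain valid for later comparisons.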

2. Generate

torchrun --rdzv_endpoint=0.0.0.0:1201 --nproc_per_node=8 generate.py 
	--steps: number of sampling steps (default=50) 
	--scheduler: baseline method (DDIM | DDPM | Heun)
	--save_path: path to save images
	--w: classifier-free guidance weight ({2,3,5,8})
	--name: name of experiments
	--restart: flag for enabling the Restart sampler

If you would like to visualize the images given prompt, run:

python3 visualization.py --prompt {prompt} --w {w} --steps {steps} --scheduler {scheduler}  (--restart)

prompt: text prompt. default='a photo of an astronaut riding a horse on mars'
steps: number of sampling steps (default=50) 
scheduler: baseline method (DDIM | DDPM)
w: classifier-free guidance weight ({2,3,5,8})

3. Evaluation

  • FID score

    python3 fid.py {path} ./coco.npz
    
    path: path to the directory of generated images
  • Aesthetic score & CLIP score

    python3 eval_clip_score.py --csv_path {path}/subset.csv --dir_path {path}
    
    path: path to the directory of generated images