This repo is the official code for the paper Improved Techniques for Maximum Likelihood Estimation for Diffusion ODEs (ICML 2023).
The code implementation is based on google-research/vdm by Diederik P. Kingma.
We achieve state-of-the-art likelihood estimation results with diffusion ODEs on image datasets (2.56 on CIFAR-10 and 3.43/3.69 on ImageNet-32, in bits/dim) without variational dequantization or data augmentation, and 2.42 on CIFAR-10 with data augmentation.
Our improved techniques include:

For maximum likelihood training:
- Velocity parameterization with likelihood weighting (see the sketch after this list)
- Error-bounded second-order flow matching (finetuning)
- Variance reduction with importance sampling

For likelihood evaluation:
- Training-free truncated-normal dequantization
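As a rough illustration of the velocity parameterization mentioned above: for a VP-type schedule with alpha_t^2 + sigma_t^2 = 1, the network predicts v = alpha_t * eps - sigma_t * x_0 instead of the noise eps, and the noise/data predictions can be recovered linearly. This is only a sketch of the parameterization itself; the likelihood weighting and the exact conventions used in this repo may differ.

```python
def v_to_eps_and_x0(v_pred, x_t, alpha_t, sigma_t):
    """Recover noise/data predictions from a velocity prediction.

    Sketch for a VP-type schedule with alpha_t**2 + sigma_t**2 = 1, where
    x_t = alpha_t * x_0 + sigma_t * eps and the network outputs
    v = alpha_t * eps - sigma_t * x_0.
    """
    eps_pred = alpha_t * v_pred + sigma_t * x_t   # predicted noise
    x0_pred = alpha_t * x_t - sigma_t * v_pred    # predicted clean data
    return eps_pred, x0_pred
```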
Following VDM, the code is based on Jax/Flax instead of PyTorch.
To run on a GPU machine, you need to find a jaxlib wheel matching your Python and CUDA versions at https://storage.googleapis.com/jax-releases/jax_cuda_releases.html. We recommend python==3.8 and cuda==11.2. To install the required libraries:
wget https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.3.15+cuda11.cudnn805-cp38-none-manylinux2014_x86_64.whl
pip install jaxlib-0.3.15+cuda11.cudnn805-cp38-none-manylinux2014_x86_64.whl
pip install -r requirements.txt
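After installation, you can optionally check that JAX sees the GPUs (a generic JAX sanity check, not part of the repo):

```python
import jax

# Should list GPU devices rather than only the CPU if the CUDA wheel was
# installed correctly.
print(jax.devices())
```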
The experiments are conducted on CIFAR10 and ImageNet32. We use the dataloader provided by tensorflow_datasets.
As stated in the paper, there are two different versions of the ImageNet32 dataset. For fair comparisons, we use both: one is downloaded from https://image-net.org/data/downsample/Imagenet32_train.zip, following Flow Matching [3], and the other from http://image-net.org/small/train_32x32.tar (the old version, which is no longer available), following ScoreSDE and VDM. The former applies anti-aliasing and is easier for maximum likelihood training.
For convenience, we adapt both versions of ImageNet32 to the tfrecord format. You can download them from https://drive.google.com/drive/folders/1VCGusEV-yGle_9I6ncd5CpmsUYUu-kid and put them under ~/tensorflow_datasets.
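For reference, a minimal tensorflow_datasets loading loop looks like the sketch below, using the standard `cifar10` builder; the repo's actual input pipeline, preprocessing, and the builder names of the converted ImageNet32 tfrecords may differ.

```python
import tensorflow_datasets as tfds

# Minimal sketch: iterate over CIFAR-10 training batches as numpy arrays.
ds = tfds.load("cifar10", split="train", shuffle_files=True)
ds = ds.batch(128)
for batch in tfds.as_numpy(ds):
    images = batch["image"]  # uint8, shape (128, 32, 32, 3)
    break
```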
There are four config files, corresponding to CIFAR10, CIFAR10 (with data augmentation), ImageNet32 (old version) and ImageNet32 (new version), respectively.
configs
├── cifar10.py
├── cifar10_aug.py
├── imagenet32.py
└── imagenet32_new.py
Some key configurations:
config.model = d(
    schedule="VP",        # noise schedule, "VP" (variance preserving) / "SP" (straight line)
    second_order=False,   # enable second order loss for finetuning
    velocity=True,        # enable velocity parameterization
    importance=True,      # enable importance sampling
    dequantization="tn",  # dequantization type, "u" (uniform) / "v" (variational) / "tn" (truncated-normal)
    num_importance=1,     # number of importance samples for likelihood evaluation
)
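To give a rough idea of the `dequantization="tn"` option: instead of adding uniform noise within each quantization bin, the discrete pixels are dequantized with noise from a normal distribution truncated to the bin. The sketch below only illustrates the sampling step with an arbitrary noise scale; in the paper the dequantization is training-free and its parameters are derived from the noise schedule, and the log-density of the dequantization noise also enters the BPD computation.

```python
import jax
import jax.numpy as jnp

def tn_dequantize(key, x_uint8, sigma=0.01):
    """Illustrative truncated-normal dequantization (a sketch, not the repo's code).

    Pixels are mapped to [-1, 1], where adjacent 8-bit levels are 2/255 apart,
    so each quantization bin has half-width 1/255. Zero-mean Gaussian noise with
    std `sigma` (illustrative value) is added, truncated so the result stays
    inside its bin.
    """
    x = x_uint8.astype(jnp.float32) / 127.5 - 1.0  # map {0, ..., 255} to [-1, 1]
    half_bin = 1.0 / 255.0                         # half-width of one quantization bin
    # Standard normal truncated to [-half_bin/sigma, half_bin/sigma], rescaled to a
    # N(0, sigma^2) truncated to [-half_bin, half_bin]:
    noise = sigma * jax.random.truncated_normal(
        key, -half_bin / sigma, half_bin / sigma, shape=x.shape)
    return x + noise
```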
Train from scratch:
python main.py --config=configs/cifar10.py --workdir=experiments/cifar10_VP --mode=train
python main.py --config=configs/cifar10.py --workdir=experiments/cifar10_SP --mode=train --config.model.schedule="SP"
Continue training from a previous checkpoint:
python main.py --config=configs/cifar10.py --workdir=experiments/cifar10_VP --mode=train --config.ckpt_restore_dir=experiments/cifar10_VP/cifar10/20230101-012345/checkpoints-0
Finetune with the second-order loss:
python main.py --config=configs/cifar10.py --workdir=experiments/cifar10_VP_finetuned --mode=train --config.ckpt_restore_dir=experiments/cifar10_VP/cifar10/20230101-012345/checkpoints-0 --config.model.second_order=True --config.training.num_steps_train=6200000
# For ImageNet32 (old version), we finetune with a batch size of 128 and accumulate the gradient every 4 steps to reduce the GPU memory requirements, which results in an effective batch size of 512.
python main.py --config=configs/imagenet32.py --workdir=experiments/imagenet32_VP_finetuned --mode=train --config.ckpt_restore_dir=experiments/imagenet32_VP/imagenet32/20230101-012345/checkpoints-0 --config.model.second_order=True --config.training.num_steps_train=2500000 --config.training.batch_size_train=128 --config.optimizer.gradient_accumulation=4
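For reference, gradient accumulation in a JAX/Optax training loop can be expressed with `optax.MultiSteps`, as sketched below; whether the repo's `config.optimizer.gradient_accumulation` flag is implemented this way is an assumption.

```python
import optax

# Sketch: accumulate gradients over 4 optimizer steps, so a per-step batch of
# 128 behaves like an effective batch size of 512. The learning rate shown is
# illustrative only.
base_optimizer = optax.adamw(learning_rate=2e-4)
optimizer = optax.MultiSteps(base_optimizer, every_k_schedule=4)
```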
Compute the SDE/ODE likelihood (measured in BPD) of a checkpoint:
python main.py --config=configs/cifar10.py --workdir=experiments/cifar10_VP --mode=eval --checkpoint=experiments/cifar10_VP/cifar10/20230101-012345/checkpoints-0
Sample from a checkpoint:
python main.py --config=configs/cifar10.py --workdir=experiments/cifar10_VP --mode=sample --checkpoint=experiments/cifar10_VP/cifar10/20230101-012345/checkpoints-0
The sampled images will be saved to workdir/samples in .npz format.
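The saved arrays can be inspected with numpy; the file path below is only a placeholder for whatever appears under workdir/samples, and the exact keys stored in the file should be checked.

```python
import numpy as np

samples = np.load("experiments/cifar10_VP/samples/<some_file>.npz")  # placeholder path
print(samples.files)                # list the array keys stored in the file
images = samples[samples.files[0]]  # first stored array; check its shape/dtype
```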
| | CIFAR10 | ImageNet32 (new) | ImageNet32 (old) | CIFAR10 (with data augmentation) |
|---|---|---|---|---|
| GPUs | 8 x NVIDIA A40 | 8 x NVIDIA A40 | 8 x NVIDIA A100 | 64 x NVIDIA A800 |
| Steps | 6M | 2M | 2M | 1.95M |
| Batch size | 128 | 128 | 512 | 1024 |
| Steps (Finetuning) | +200k | +250k | +500k | 0 |
| Batch size (Finetuning) | 128 | 128 | 128 (gradient accumulation every 4 steps) | \ |
| NLL (BPD) | 2.56 | 3.43 | 3.69 | 2.42 |
You can download the checkpoints in different settings from https://drive.google.com/drive/folders/1cZ7zkWWooD2E8n_9slDMAA8w3KmkjzOC.
If you find the code useful for your research, please consider citing:
@inproceedings{zheng2023improved,
  title={Improved Techniques for Maximum Likelihood Estimation for Diffusion ODEs},
  author={Zheng, Kaiwen and Lu, Cheng and Chen, Jianfei and Zhu, Jun},
  booktitle={International Conference on Machine Learning},
  year={2023},
  organization={PMLR}
}
There are some related papers which might also interest you:
[1] Diederik P Kingma, Tim Salimans, Ben Poole, Jonathan Ho. "Variational diffusion models". Advances in Neural Information Processing Systems, 2021.
[2] Cheng Lu, Kaiwen Zheng, Fan Bao, Jianfei Chen, Chongxuan Li, Jun Zhu. "Maximum likelihood training for score-based diffusion ODEs by high order denoising score matching". International Conference on Machine Learning, 2022.
[3] Yaron Lipman, Ricky T. Q. Chen, Heli Ben-Hamu, Maximilian Nickel, Matt Le. "Flow matching for generative modeling". International Conference on Learning Representations, 2023.