This repo contains the PyTorch implementation of the paper "Improving Probabilistic Diffusion Models With Optimal Covariance Matching"
by Zijing Ou, Mingtian Zhang, Andi Zhang, Tim Xiao, Yingzhen Li, and David Barber.
We leverage the covariance moment matching technique and introduce a novel method for learning the diagonal covariances of diffusion models. Unlike traditional data-driven covariance approximation approaches, our method directly regresses the optimal analytic covariance using a new, unbiased objective named Optimal Covariance Matching (OCM). This approach can significantly reduce the approximation error in covariance prediction. We demonstrate how our method can substantially enhance the sampling efficiency, recall rate, and likelihood of both diffusion models and latent diffusion models.
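The method names used below (`pred_eps_hes_pretrained`) suggest that the covariance prediction involves Hessian-like quantities of the score. As background only, the diagonal of a large matrix can be estimated from matrix-vector products alone via Hutchinson-style probing with Rademacher vectors; the sketch below is a generic illustration of that estimator, not the paper's exact OCM objective, and `hutchinson_diag`, `matvec`, and `n_probes` are our own names:

```python
import numpy as np

def hutchinson_diag(matvec, dim, n_probes=1024, seed=0):
    """Estimate diag(A) using only matrix-vector products with A.

    For a Rademacher vector v (i.i.d. +/-1 entries), E[v * (A v)] = diag(A):
    the off-diagonal terms v_i * A_ij * v_j (i != j) have zero mean.
    """
    rng = np.random.default_rng(seed)
    est = np.zeros(dim)
    for _ in range(n_probes):
        v = rng.choice([-1.0, 1.0], size=dim)
        est += v * matvec(v)  # v * (A v), elementwise
    return est / n_probes

# Example: a small symmetric matrix standing in for a Hessian.
A = np.array([[2.0, 0.3, 0.0],
              [0.3, 1.0, 0.1],
              [0.0, 0.1, 4.0]])
diag_est = hutchinson_diag(lambda v: A @ v, dim=3)
```

For a diagonal matrix the estimator is exact with a single probe; for dense matrices the error shrinks with the number of probes.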
Our implementation is based on the Extended Analytic-DPM repository. To set up the environment, please follow the installation instructions provided in that repository. The main functionality of our code closely mirrors the original repo, and we provide detailed usage instructions below.
To train the model, you can use the following command:
```shell
python run_train.py --pretrained_path path/to/pretrained_dpm --dataset dataset --workspace path/to/working_directory $train_hparams
```
- `pretrained_path` is the path to a pretrained diffusion probabilistic model (DPM). Here are the links to the pretrained models: CIFAR10 (LS), CIFAR10 (CS), CelebA64, ImageNet64, LSUN-Bedroom.
- `dataset` represents the training dataset, one of `<cifar10|celeba64|imagenet64|lsun_bedroom>`.
- `workspace` is the place to put training outputs, e.g., logs and checkpoints.
- `train_hparams` specifies other hyperparameters used in training.

We provide the `train_hparams` used to train our models on each dataset:
- CIFAR10 (LS): `--method pred_eps_hes_pretrained`
- CIFAR10 (CS): `--method pred_eps_hes_pretrained --schedule cosine_1000`
- CelebA64: `--method pred_eps_hes_pretrained`
- ImageNet64: `--method pred_eps_hes_pretrained --mode complex`
- LSUN-Bedroom: `--method pred_eps_hes_pretrained --mode complex`
As an example, to train the CIFAR10 (LS) model, you can run:
```shell
python run_train.py --pretrained_path path/to/pretrained_dpm --dataset cifar10 --workspace path/to/working_directory --method pred_eps_hes_pretrained
```
To evaluate the model, you can use the following command:
```shell
python run_eval.py --pretrained_path path/to/evaluated_model --dataset dataset --workspace path/to/working_directory --phase phase --sample_steps sample_steps --batch_size batch_size --method pred_eps_hes_pretrained $eval_hparams
```
- `pretrained_path` is the path to the model to evaluate. We provide all checkpoints trained with the proposed OCM approach here.
- `dataset` represents the dataset the model is trained on, one of `<cifar10|celeba64|imagenet64|lsun_bedroom>`.
- `workspace` is the place to put evaluation outputs, e.g., logs, samples, and bpd values.
- `phase` specifies whether to run FID or likelihood evaluation, one of `<sample4test|nll4test>`.
- `sample_steps` is the number of steps to run during inference; the smaller this value, the faster the inference.
- `batch_size` is the batch size, e.g., 500.
- `eval_hparams` specifies other optional hyperparameters used in evaluation.
We provide the `eval_hparams` used for the FID and NLL results in the paper:
- FID Evaluation (DDPM)
  - CIFAR10 (LS): `--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 2`
  - CIFAR10 (CS): `--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --schedule cosine_1000`
  - CelebA64: `--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 2`
  - ImageNet64: `--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --mode complex`
  - LSUN-Bedroom: `--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --mode complex`
- FID Evaluation (DDIM)
  - CIFAR10 (LS): `--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0`
  - CIFAR10 (CS): `--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0 --schedule cosine_1000`
  - CelebA64: `--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0`
  - ImageNet64: `--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0 --mode complex`
  - LSUN-Bedroom: `--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0 --mode complex`
- NLL Evaluation
  - CIFAR10 (LS): `--rev_var_type optimal`
  - CIFAR10 (CS): `--rev_var_type optimal --schedule cosine_1000`
  - CelebA64: `--rev_var_type optimal`
  - ImageNet64: `--rev_var_type optimal --mode complex`
This link provides precalculated FID statistics on CIFAR10, CelebA64, ImageNet64, and LSUN-Bedroom. They are computed following Appendix F.2 in Analytic-DPM.
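For reference, the FID between two sets of feature statistics (mean `mu` and covariance `sigma` of Inception features) is the Fréchet distance between the corresponding Gaussians. The repo's evaluation code computes this for you; the minimal NumPy sketch below only illustrates the formula, and `frechet_distance` is our own name:

```python
import numpy as np

def frechet_distance(mu1, sigma1, mu2, sigma2):
    """||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 (sigma1 sigma2)^{1/2}).

    Tr((sigma1 sigma2)^{1/2}) is computed via the symmetric form
    Tr((sigma2^{1/2} sigma1 sigma2^{1/2})^{1/2}), which only requires
    eigendecompositions of symmetric PSD matrices.
    """
    w, V = np.linalg.eigh(sigma2)
    sqrt_sigma2 = V @ np.diag(np.sqrt(np.clip(w, 0.0, None))) @ V.T
    inner = sqrt_sigma2 @ sigma1 @ sqrt_sigma2
    tr_sqrt = np.sqrt(np.clip(np.linalg.eigvalsh(inner), 0.0, None)).sum()
    diff = mu1 - mu2
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_sqrt)
```

Identical statistics give distance 0; with equal covariances, shifting the mean by a vector `d` adds exactly `||d||^2`.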
As an example, to evaluate the FID (DDPM) result of the CIFAR10 (LS) model, you can run:
```shell
python run_eval.py --pretrained_path path/to/pretrained_dpm --dataset cifar10 --workspace path/to/working_directory --phase sample4test --sample_steps sample_steps --batch_size batch_size --method pred_eps_hes_pretrained --rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 2
```
To evaluate the NLL result of the CIFAR10 (LS) model, you can run:
```shell
python run_eval.py --pretrained_path path/to/pretrained_dpm --dataset cifar10 --workspace path/to/working_directory --phase nll4test --sample_steps sample_steps --batch_size batch_size --method pred_eps_hes_pretrained --rev_var_type optimal
```
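Since smaller `sample_steps` trades quality for speed, one common workflow is to sweep a few step counts and compare FID. A hypothetical sketch (paths and batch size are placeholders; `eval_cmd` is our own helper, which only prints the command so you can inspect or launch each run):

```shell
# Print the evaluation command for a given number of sampling steps.
eval_cmd() {
  echo "python run_eval.py --pretrained_path path/to/pretrained_dpm" \
       "--dataset cifar10 --workspace path/to/working_directory" \
       "--phase sample4test --sample_steps $1 --batch_size 500" \
       "--method pred_eps_hes_pretrained --rev_var_type optimal" \
       "--clip_sigma_idx 1 --clip_pixel 2"
}

# Fewer steps run faster, usually at some cost in FID.
for steps in 10 25 50 100; do
  eval_cmd "$steps"
done
```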
😄 If you find this repo useful, please consider citing our paper:
```bibtex
@article{ou2024improving,
  title={Improving Probabilistic Diffusion Models With Optimal Covariance Matching},
  author={Ou, Zijing and Zhang, Mingtian and Zhang, Andi and Xiao, Tim Z and Li, Yingzhen and Barber, David},
  journal={arXiv preprint arXiv:2406.10808},
  year={2024}
}
```