Hierarchical Semi-Implicit Variational Inference with Application to Diffusion Model Acceleration (NeurIPS 2023)
Longlin Yu*, Tianyu Xie*, Yu Zhu*, Tong Yang, Xiangyu Zhang and Cheng Zhang
This repository contains the implementation for the paper Hierarchical Semi-Implicit Variational Inference with Application to Diffusion Model Acceleration (NeurIPS 2023).
HSIVI is a variational inference method that assumes the target density is accessible. For example, the density function may be known up to a normalizing constant, or the score function may be available. When used for diffusion model acceleration, HSIVI-SM (the score-matching variant of HSIVI) does not target the generative model directly. Instead, it relies on a sequence of auxiliary distributions that bridges a simple distribution and the target distribution. These auxiliary distributions are available given the learned score functions of a pre-trained diffusion model (see Example 2 in the original paper).
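The hierarchical structure can be illustrated with a minimal pure-Python sketch of sampling. A draw starts from a simple base distribution and passes through a stack of Gaussian conditional layers. Because each layer's input is itself random, every intermediate marginal is an intractable mixture, which is what makes the family semi-implicit. The mean functions and noise scales are hypothetical stand-ins for the learned layers; only the sampling structure is meant to be illustrative. Each layer costs one function evaluation. This is presumably why `--n_discrete_steps` (the number of time points) is NFE + 1.

```python
import math
import random
from typing import Callable, List, Sequence

Vector = List[float]


def sample_hierarchical(
    mean_fns: Sequence[Callable[[Vector], Vector]],
    log_sigmas: Sequence[float],
    dim: int,
    seed: int = 0,
) -> List[Vector]:
    """Draw one sample path through a stack of Gaussian conditional layers.

    Each layer maps the current sample x to N(mean_fn(x), exp(log_sigma)^2 I).
    The functions and scales here are hypothetical placeholders, not HSIVI's
    trained networks.
    """
    if len(mean_fns) != len(log_sigmas):
        raise ValueError("need one noise scale per conditional layer")
    rng = random.Random(seed)
    x = [rng.gauss(0.0, 1.0) for _ in range(dim)]  # simple base distribution N(0, I)
    path = [x]
    for mean_fn, log_sigma in zip(mean_fns, log_sigmas):
        mu = mean_fn(x)  # one function evaluation per layer
        sigma = math.exp(log_sigma)
        x = [m + sigma * rng.gauss(0.0, 1.0) for m in mu]
        path.append(x)
    return path


# Usage: 10 layers (10 function evaluations) that shrink the sample towards 0.
layers = [lambda v: [0.9 * vi for vi in v]] * 10
path = sample_hierarchical(layers, [math.log(0.1)] * 10, dim=3)
print(len(path) - 1, "layers; final sample:", path[-1])
```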
Before running our code, install the requirements with the following commands:
# python=3.8.16
pip install --upgrade pip
pip install -r requirements.txt -f https://download.pytorch.org/whl/torch_stable.html
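A small check script can confirm that the interpreter matches the pinned version (3.8.16) and that PyTorch is importable. The following is a sketch; the function name is hypothetical and not part of the repository. It only looks packages up and does not import them, so it also runs before installation.

```python
import importlib.util
import sys
from typing import Dict, Tuple, Union


def check_environment(packages: Tuple[str, ...] = ("torch",)) -> Dict[str, Union[str, bool]]:
    """Report the Python version and whether each package can be found."""
    report: Dict[str, Union[str, bool]] = {
        "python": "{}.{}.{}".format(*sys.version_info[:3]),
        "python_matches_3.8": sys.version_info[:2] == (3, 8),
    }
    for name in packages:
        # find_spec returns None when a top-level package is not installed.
        report[name] = importlib.util.find_spec(name) is not None
    return report


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key}: {value}")
```

Once PyTorch is installed, `torch.cuda.is_available()` additionally tells you whether the GPUs are visible.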
For multi-node training, two environment variables must be set: $IP_ADDR, the IP address of the machine that hosts the rank-0 process during training, and $NODE_RANK, the index of the current node among all nodes.
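How main.py consumes these variables is not shown in this README. The sketch below assumes the usual PyTorch data-parallel convention:

- Processes are numbered node by node, so $NODE_RANK together with a GPU's index on its node determines the process's global rank.
- $IP_ADDR is the rendezvous host.

The function, its `n_nodes` argument, the fallback address, and the example IP are all hypothetical.

```python
import os
from typing import Dict, Mapping, Optional


def distributed_layout(
    n_nodes: int,
    n_gpus_per_node: int,
    local_rank: int,
    env: Optional[Mapping[str, str]] = None,
) -> Dict[str, object]:
    """Derive a process's global rank from NODE_RANK (assumed DDP-style numbering)."""
    env = os.environ if env is None else env
    if not 0 <= local_rank < n_gpus_per_node:
        raise ValueError("local_rank must index a GPU on this node")
    node_rank = int(env.get("NODE_RANK", "0"))
    if not 0 <= node_rank < n_nodes:
        raise ValueError("NODE_RANK must be in [0, n_nodes)")
    return {
        "master_address": env.get("IP_ADDR", "127.0.0.1"),  # hypothetical fallback
        "world_size": n_nodes * n_gpus_per_node,
        "global_rank": node_rank * n_gpus_per_node + local_rank,
    }


# GPU 3 on the second of two 8-GPU nodes -> global rank 11 of 16.
print(distributed_layout(2, 8, 3, env={"NODE_RANK": "1", "IP_ADDR": "10.0.0.1"}))
```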
For real image generation, we use MNIST, CIFAR10, Celeba64, and ImageNet64 in this work. Note that HSIVI allows data-free training: accelerating diffusion model sampling requires only the pre-trained model, not the training images.
For diffusion model acceleration, we use a pre-trained noise model $\bm{\epsilon}^*(\bm{x}, s)$ with a UNet architecture. The checkpoints of $\bm{\epsilon}^*(\bm{x}, s)$ that we used can be downloaded from the links below:
Dataset | Download link |
---|---|
Cifar10 | checkpoints |
Celeba64 | checkpoints |
Imagenet64 | checkpoints |
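The pre-trained noise model is what supplies the score function that HSIVI-SM needs. Assume the standard variance-preserving (DDPM-style) parameterization: a noisy sample at time $s$ is $\bm{x}_s = \alpha_s \bm{x}_0 + \sigma_s \bm{\epsilon}$ with $\bm{\epsilon} \sim \mathcal{N}(\bm{0}, \bm{I})$, where $\alpha_s$ and $\sigma_s$ are the schedule's signal and noise scales. Under that assumption (check the paper for the exact parameterization of each checkpoint), the noise prediction and the score are related by

$$
\nabla_{\bm{x}} \log p_s(\bm{x}) \approx -\frac{\bm{\epsilon}^*(\bm{x}, s)}{\sigma_s}.
$$

The score can therefore be queried at any time point without access to training data.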
With the downloaded pre-trained models, train HSIVI through main.py:
python main.py \
-cc <path of training configuration file> \
--root './' \
--mode train \
--n_gpus_per_node <number of GPUs per node> \
--training_batch_size <training batch size on a single GPU> \
--testing_batch_size 16 \
--sampling_batch_size 64 \
--independent_log_gamma <dis|use> \
--f_learning_times <number of f_net updates per phi_net update> \
--image_gamma <dis|use> \
--skip_type <uniform|quad> \
--n_discrete_steps <number of function evaluations (NFE) + 1> \
--phi_learning_rate <phi learning rate> \
--f_learning_rate <f learning rate> \
--n_train_iters <number of training iterations> \
--pretrained_model <path of pretrained model> \
--workdir <path of working directory> \
--master_address 127.0.0.10 \
--master_port 4372
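If you launch runs from Python, you can assemble the documented command as an argument list. The sketch below does this and derives `--n_discrete_steps` from a target NFE. The helper name and its `extra` pass-through are hypothetical. This README does not say which flags main.py treats as required, so any remaining flags from the template are passed through `extra`.

```python
import shlex
from typing import Dict, List, Optional


def build_train_command(
    config: str,
    pretrained_model: str,
    workdir: str,
    nfe: int,
    n_gpus_per_node: int = 8,
    extra: Optional[Dict[str, object]] = None,
) -> List[str]:
    """Assemble the documented `main.py --mode train` invocation as an argv list."""
    if nfe < 1:
        raise ValueError("NFE must be at least 1")
    argv = [
        "python", "main.py",
        "-cc", config,
        "--root", "./",
        "--mode", "train",
        "--n_gpus_per_node", str(n_gpus_per_node),
        "--n_discrete_steps", str(nfe + 1),  # the flag counts NFE + 1
        "--pretrained_model", pretrained_model,
        "--workdir", workdir,
    ]
    for flag, value in (extra or {}).items():
        argv += ["--" + flag, str(value)]
    return argv


argv = build_train_command(
    "configs/default_cifar.txt",
    "./pretrained_model/target_epsilon_cifar10.pt",
    "./work_dir/cifar10_10steps",
    nfe=10,
    extra={"skip_type": "quad"},
)
print(shlex.join(argv))  # pass argv to subprocess.run(argv, check=True) to launch
```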
Notes on some of the options:
independent_log_gamma: 'use' gives each conditional layer its own (independent) log_gamma, while 'dis' shares one log_gamma across all conditional layers.
image_gamma: 'dis' uses isotropic conditional layers, while 'use' uses non-isotropic conditional layers.
skip_type: the strategy for selecting the discrete time steps, either 'uniform' or 'quad'.
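The option descriptions above can be made concrete with the sketch below, which rests on two assumptions this README does not confirm:

- log_gamma is a learnable array holding either one entry per conditional layer ('use') or a single shared entry ('dis'). Each entry is either a scalar (isotropic) or one value per pixel and channel (non-isotropic).
- 'uniform' and 'quad' follow the convention of DDIM's sampling code, where 'quad' spaces the time points quadratically so that they cluster at the low-noise end.

Both helper names and the `t_min`/`t_max` endpoints are hypothetical.

```python
from typing import List, Tuple


def log_gamma_shape(
    n_layers: int,
    image_shape: Tuple[int, ...],
    independent_log_gamma: str,
    image_gamma: str,
) -> Tuple[int, ...]:
    """Hypothetical shape of the log_gamma parameter implied by the two flags."""
    if independent_log_gamma not in ("use", "dis") or image_gamma not in ("use", "dis"):
        raise ValueError("both flags take 'use' or 'dis'")
    per_layer = (n_layers,) if independent_log_gamma == "use" else (1,)
    # 'use' -> non-isotropic: one value per pixel and channel; 'dis' -> one scalar.
    per_dim = image_shape if image_gamma == "use" else ()
    return per_layer + per_dim


def discrete_time_steps(
    n_discrete_steps: int,
    skip_type: str = "uniform",
    t_min: float = 1e-3,
    t_max: float = 1.0,
) -> List[float]:
    """Hypothetical grid of n_discrete_steps (= NFE + 1) time points in [t_min, t_max]."""
    if n_discrete_steps < 2:
        raise ValueError("need at least two time points, i.e. NFE >= 1")
    n = n_discrete_steps - 1
    if skip_type == "uniform":
        fractions = [i / n for i in range(n + 1)]
    elif skip_type == "quad":
        # Quadratic spacing: small steps near t_min, larger steps near t_max.
        fractions = [(i / n) ** 2 for i in range(n + 1)]
    else:
        raise ValueError("skip_type must be 'uniform' or 'quad'")
    return [t_min + (t_max - t_min) * f for f in fractions]


# Settings from the CIFAR10 example: 10 layers on 3x32x32 images.
print(log_gamma_shape(10, (3, 32, 32), independent_log_gamma="dis", image_gamma="use"))
print([round(t, 3) for t in discrete_time_steps(11, "quad")])
```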
For example, to train HSIVI-SM on CIFAR10 on a single node with 8 GPUs (recommended), run:
python main.py \
-cc configs/default_cifar.txt \
--root './' \
--mode train \
--n_gpus_per_node 8 \
--training_batch_size 16 \
--testing_batch_size 16 \
--sampling_batch_size 64 \
--independent_log_gamma dis \
--f_learning_times 20 \
--image_gamma use \
--skip_type quad \
--n_discrete_steps 11 \
--phi_learning_rate 0.000016 \
--f_learning_rate 0.00008 \
--n_train_iters 200000 \
--pretrained_model ./pretrained_model/target_epsilon_cifar10.pt \
--workdir ./work_dir/cifar10_10steps \
--master_address 127.0.0.10 \
--master_port 4372
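The example's settings imply the figures computed below. With --n_discrete_steps 11 it uses 10 NFE, matching the working directory name cifar10_10steps. The batch arithmetic assumes standard data-parallel training: one process per GPU, each using the per-GPU --training_batch_size, so the global batch is their sum. With fewer GPUs, one option is to raise --training_batch_size so the global batch stays the same, though this has not been checked against the paper's results. The helper names are hypothetical.

```python
from typing import Dict


def summarize_run(n_nodes: int, n_gpus_per_node: int, training_batch_size: int) -> Dict[str, int]:
    """Effective batch size under standard data-parallel training (assumed)."""
    n_processes = n_nodes * n_gpus_per_node  # one process per GPU (assumption)
    return {
        "n_processes": n_processes,
        # --training_batch_size is per GPU, so the global batch scales with the GPU count.
        "global_batch_size": n_processes * training_batch_size,
    }


def per_gpu_batch_for(global_batch_size: int, n_nodes: int, n_gpus_per_node: int) -> int:
    """Per-GPU --training_batch_size that keeps a given global batch size."""
    n_processes = n_nodes * n_gpus_per_node
    if global_batch_size % n_processes != 0:
        raise ValueError("global batch size must split evenly across GPUs")
    return global_batch_size // n_processes


example = summarize_run(n_nodes=1, n_gpus_per_node=8, training_batch_size=16)
print(example)  # {'n_processes': 8, 'global_batch_size': 128}
print(per_gpu_batch_for(example["global_batch_size"], n_nodes=1, n_gpus_per_node=4))  # 32
```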
We provide the complete training scripts for each dataset: ./run_celeba64.sh (for Celeba64), ./run_cifar10.sh (for Cifar10), and ./run_imagenet64.sh (for ImageNet64).
During training, FID is computed from 50,000 generated images. The pre-computed FID statistics of CIFAR-10 are provided in ./inception/fid_stats_cifar10_train_pytorch.npz.
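FID compares Gaussian fits of Inception features from generated and reference images: $\lVert \mu_1 - \mu_2 \rVert^2 + \operatorname{Tr}(\Sigma_1 + \Sigma_2 - 2(\Sigma_1 \Sigma_2)^{1/2})$. The numpy-only sketch below computes that distance. The evaluation code used here (from EDM) takes the matrix square root with `scipy.linalg.sqrtm`. The sketch instead gets the trace term from the eigenvalues of $\Sigma_1 \Sigma_2$, which are real and non-negative for covariance matrices. The loader assumes the `.npz` file stores arrays under the keys `mu` and `sigma`; inspect `np.load(path).files` to confirm. Both function names are hypothetical.

```python
import numpy as np


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1 S2)^{1/2}) for two Gaussians."""
    mu1, mu2 = np.atleast_1d(mu1), np.atleast_1d(mu2)
    sigma1, sigma2 = np.atleast_2d(sigma1), np.atleast_2d(sigma2)
    diff = mu1 - mu2
    # For PSD covariances, S1 S2 has real, non-negative eigenvalues, so the
    # trace of its square root is the sum of their square roots.
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    tr_covmean = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_covmean)


def load_reference_stats(path):
    """Load (mu, sigma) from an .npz file; the key names are an assumption."""
    with np.load(path) as stats:
        return stats["mu"], stats["sigma"]
```

Usage, given the feature statistics `mu_gen` and `sigma_gen` of the 50,000 samples: load the reference with `mu_ref, sigma_ref = load_reference_stats("./inception/fid_stats_cifar10_train_pytorch.npz")`, then compute `frechet_distance(mu_gen, sigma_gen, mu_ref, sigma_ref)`.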
After training, evaluate the FID of samples generated by the trained HSIVI model with the following script:
python main.py \
-cc <path of training configuration file> \
--root './' \
--mode eval \
--n_gpus_per_node 8 \
--training_batch_size <training batch size on a single GPU> \
--workdir <path of working directory> \
--eval_weight <path of checkpoints of trained epsilon model> \
--n_discrete_steps <number of function evaluations (NFE) + 1>
Our training code for diffusion model acceleration is adapted from the repository of Score-Based Generative Modeling with Critically-Damped Langevin Diffusion. For FID evaluation, we use the code provided with Elucidating the Design Space of Diffusion-Based Generative Models (EDM).
If you find this code useful for your research, please consider citing
@inproceedings{
yu2023hierarchical,
title={Hierarchical Semi-Implicit Variational Inference with Application to Diffusion Model Acceleration},
author={Longlin Yu and Tianyu Xie and Yu Zhu and Tong Yang and Xiangyu Zhang and Cheng Zhang},
booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
year={2023},
url={https://openreview.net/forum?id=ghIBaprxsV}
}