# Variational Diffusion Models (VDM)
This is a PyTorch implementation of Variational Diffusion Models,
where the focus is on optimizing likelihood rather than sample quality,
in the spirit of probabilistic generative modeling.
This implementation should match the
official one in JAX.
However, the purpose is mainly educational and the focus is on simplicity.
So far, the repo only includes CIFAR10, and variance minimization
with the $\gamma_{\eta}$ network (see Appendix I.2 in the paper) is not
implemented (it's only used for CIFAR10 with augmentations and, according
to the paper, it does not have a significant impact).
## Results
The samples below are from a model trained on CIFAR10 for 2M steps with gradient clipping and with a fixed noise
schedule such that $\log \mathrm{SNR}(t)$ is linear, with $\log \mathrm{SNR}(0) = 13.3$ and $\log \mathrm{SNR}(1) = -5$.
These samples are generated from the EMA model in 1000 denoising steps.
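The fixed schedule above can be sketched as follows. The linear interpolation of $\log \mathrm{SNR}(t)$ uses the endpoints stated above; the variance-preserving mapping from log-SNR to $\alpha_t, \sigma_t$ via a sigmoid is an assumption consistent with the VDM paper, and the function names are hypothetical:

```python
import numpy as np

def log_snr(t, logsnr_max=13.3, logsnr_min=-5.0):
    """Fixed noise schedule: log SNR(t) linear in t on [0, 1]."""
    return logsnr_max + (logsnr_min - logsnr_max) * np.asarray(t, dtype=float)

def alpha_sigma(t):
    """Variance-preserving alpha_t, sigma_t with alpha_t^2 = sigmoid(log SNR(t))."""
    alpha2 = 1.0 / (1.0 + np.exp(-log_snr(t)))  # sigmoid of the log-SNR
    return np.sqrt(alpha2), np.sqrt(1.0 - alpha2)
```

With this parameterization $\alpha_t^2 + \sigma_t^2 = 1$ at every $t$, and $\alpha_t^2 / \sigma_t^2 = \mathrm{SNR}(t)$ as required.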
Without gradient clipping (as in the paper), the test set variational lower bound (VLB) is 2.715 bpd after 2M steps
(the paper reports 2.65 after 10M steps).
However, training is somewhat unstable and requires care (there is a tendency to overfit),
and the train-test gap is rather large.
With gradient clipping, the test set VLB is slightly worse, but training appears better behaved.
## Overview of the model
### Diffusion process
Let $\mathbf{x}$ be a data point and $\mathbf{z}_t$ the latent variable at time $t \in [0,1]$. The forward diffusion process is defined by

$$q(\mathbf{z}_t \mid \mathbf{x}) = \mathcal{N}\left(\alpha_t \mathbf{x},\ \sigma_t^2 \mathbf{I}\right),$$

where the noise schedule determines $\alpha_t$ and $\sigma_t$. The generative model inverts this process in $T$ discrete steps:

$$p(\mathbf{x}) = \int_{\mathbf{z}} p(\mathbf{z}_1)\, p(\mathbf{x} \mid \mathbf{z}_0) \prod_{i=1}^{T} p(\mathbf{z}_{s(i)} \mid \mathbf{z}_{t(i)}),$$

where $s(i) = \frac{i-1}{T}$ and $t(i) = \frac{i}{T}$.
We then choose the one-step denoising distribution to be equal to the
true denoising distribution given the data (which is available in
closed form), except that we substitute the unavailable data
with the model's prediction of the clean data:

$$p(\mathbf{z}_{s(i)} \mid \mathbf{z}_{t(i)}) = q\left(\mathbf{z}_{s(i)} \mid \mathbf{z}_{t(i)},\ \mathbf{x} = \hat{\mathbf{x}}_\theta(\mathbf{z}_{t(i)};\, t(i))\right).$$
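A minimal sketch of one such denoising step, using the standard closed-form Gaussian posterior of the forward process $q(\mathbf{z}_t \mid \mathbf{x}) = \mathcal{N}(\alpha_t \mathbf{x}, \sigma_t^2 \mathbf{I})$; the function and argument names are hypothetical, and `x_hat` stands in for the UNet's clean-data prediction:

```python
import numpy as np

def denoise_step(z_t, x_hat, alpha_s, sigma_s, alpha_t, sigma_t, rng):
    """Sample z_s ~ p(z_s | z_t) = q(z_s | z_t, x = x_hat), with s < t.

    Uses the Gaussian posterior of the forward process
    q(z_u | x) = N(alpha_u * x, sigma_u^2 * I).
    """
    alpha_ts = alpha_t / alpha_s                        # transition scale s -> t
    sigma2_ts = sigma_t**2 - alpha_ts**2 * sigma_s**2   # transition variance s -> t
    # Posterior mean and variance of q(z_s | z_t, x)
    mean = (alpha_ts * sigma_s**2 / sigma_t**2) * z_t \
         + (alpha_s * sigma2_ts / sigma_t**2) * x_hat
    var = sigma2_ts * sigma_s**2 / sigma_t**2
    return mean + np.sqrt(var) * rng.standard_normal(z_t.shape)
```

Iterating this step from $\mathbf{z}_1 \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$ down to $\mathbf{z}_0$ gives ancestral sampling with $T$ denoising steps.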
One of the key components to reach SOTA likelihood is the
concatenation of Fourier features to $\mathbf{z}_t$ before feeding it into the
UNet. For each element $z_t^i$ of $\mathbf{z}_t$ (e.g., one channel of
a specific pixel), we concatenate:
$$f_n^{i} = \sin \left(2^n z_t^{i} 2\pi\right)$$
$$g_n^{i} = \cos \left(2^n z_t^{i} 2\pi\right)$$
with $n$ taking a set of integer values.
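The feature construction above can be sketched as follows; concatenating along the channel axis and the default frequency range $n \in \{0, \dots, 6\}$ are assumptions (the choice $n \le 6$ is motivated below), and the function name is hypothetical:

```python
import numpy as np

def fourier_features(z, n_min=0, n_max=6):
    """Concatenate sin/cos Fourier features to z along the channel axis.

    z: array of shape (batch, channels, height, width).
    Adds features sin(2^n * z * 2*pi) and cos(2^n * z * 2*pi)
    for each integer n in [n_min, n_max].
    """
    feats = [z]
    for n in range(n_min, n_max + 1):
        arg = (2.0 ** n) * z * 2.0 * np.pi
        feats.append(np.sin(arg))
        feats.append(np.cos(arg))
    return np.concatenate(feats, axis=1)
```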
Assume that each scalar variable takes values:
$$\frac{2k + 1}{2^{m+1}} \ \text{ with }\ k = 0, ..., 2^m - 1 \ \text{ and }\ m \in \mathbb{N}.$$
E.g., in our case the $2^m = 256$ pixel values are $\left\{\frac{1}{512}, \frac{3}{512}, ..., \frac{511}{512} \right\}$.
The argument of $\sin$ and $\cos$ is then

$$2^n \, \frac{2k + 1}{2^{m+1}} \, 2\pi = \frac{(2k+1)\,\pi}{2^{m-n}},$$

which means the features have period $2^{m-n}$ in $k$.
Therefore, at very high SNR (i.e., almost discrete values with negligible noise), where
Fourier features are expected to be most useful to deal with fine details, we should choose
$n < m$, such that the period is greater than 1.
For the cosine, the condition is even stricter, because if $n = m-1$ then
$g_n^i = \cos\left(\frac{\pi}{2} + k\pi\right) = 0$.
Since in our case $m=8$, we take $n \leq 7$.
In the code we use $n \leq 6$ because images have twice the range
(between $\pm \frac{255}{256}$).
Below we visualize the feature values for pixel values 0 to 25, varying the
frequency $2^n$ with $n$ from 0 to 7. At $n=m-1=7$, the cosine features are constant,
and the sine features measure the least significant bit of the pixel value.
On clean data, any frequency $2^n$ with $n$ integer and $n > 7$ would
be useless (1 would be a multiple of the period).
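These claims can be checked numerically; a small sketch with $m = 8$:

```python
import numpy as np

m = 8
k = np.arange(2 ** m)
z = (2 * k + 1) / 2 ** (m + 1)      # clean pixel values (2k+1)/2^(m+1)

# n = m - 1: cosines vanish, sines alternate with the least significant bit.
arg = 2 ** (m - 1) * z * 2 * np.pi  # = (2k+1) * pi / 2
assert np.allclose(np.cos(arg), 0.0, atol=1e-9)
assert np.allclose(np.abs(np.sin(arg)), 1.0)

# n = m: the unit step in k is a multiple of the period, so on clean data
# the features are constant (sin = 0, cos = -1) and carry no information.
arg = 2 ** m * z * 2 * np.pi        # = (2k+1) * pi
assert np.allclose(np.sin(arg), 0.0, atol=1e-9)
assert np.allclose(np.cos(arg), -1.0)
```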
Below are the sine features on the Mandrill image (and detail on the right) with smoothly increasing frequency
from $2^0$ to $2^{4.5}$.
## Setup
The environment can be set up with requirements.txt, for example with conda.
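A typical sequence might look as follows; the environment name, Python version, and exact launch invocation are assumptions, inferred from the Accelerate options described below:

```shell
# Create and activate a conda environment (name and version are arbitrary).
conda create -n vdm python=3.10 -y
conda activate vdm
pip install -r requirements.txt

# Launch training (invocation inferred from the options described below).
accelerate launch --config_file accelerate_config.yaml train.py
```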
Append --resume to the command above to resume training from the latest checkpoint.
See train.py for more training options.
Here we provide a sensible configuration for training on 2 GPUs in the file
accelerate_config.yaml. This can be modified directly, or overridden
on the command line by adding flags before "train.py" (e.g., --num_processes N
to train on N GPUs).
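For example, overriding the number of processes from the command line might look like this (the exact invocation is an assumption; note the Accelerate flags go before train.py):

```shell
# Hypothetical example: train on 4 GPUs instead of the configured 2.
accelerate launch --config_file accelerate_config.yaml --num_processes 4 train.py
```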
See the Accelerate docs for more configuration options.
After initialization, we print an estimate of the required GPU memory for the given
batch size, so that the number of GPUs can be adjusted accordingly.
The training loop periodically logs train and validation metrics to a JSONL file,
and generates samples.
This implementation is based on the VDM paper and official code. The code structure for training diffusion models with Accelerate is inspired by this repo.