
Conditional diffusion model to generate MNIST. Minimal script. Based on 'Classifier-Free Diffusion Guidance'.

Primary language: Python. License: MIT.

Conditional Diffusion MNIST

script.py is a minimal, self-contained implementation of a conditional diffusion model. It learns to generate MNIST digits, conditioned on a class label. The neural network architecture is a small U-Net (pretrained weights also available in this repo). This code is modified from this excellent repo which does unconditional generation. The diffusion model is a Denoising Diffusion Probabilistic Model (DDPM).

Samples generated from the model.

The conditioning roughly follows the method described in Classifier-Free Diffusion Guidance (also used in Imagen). The model combines timestep embeddings $t_e$ and context embeddings $c_e$ with the U-Net activations at a certain layer, $a_L$, via

$a_{L+1} = c_e a_L + t_e.$

(Though in our experimentation, we found variants of this also work, e.g. concatenating embeddings together.)
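As a rough illustration of the update $a_{L+1} = c_e a_L + t_e$, here is a minimal NumPy sketch. It assumes the embeddings have one value per channel and are broadcast over the spatial dimensions; the function name `modulate` and the tensor shapes are hypothetical and chosen only for this example, and the repo itself may implement this step differently.

```python
import numpy as np

def modulate(a_L, c_e, t_e):
    """Compute a_{L+1} = c_e * a_L + t_e.

    a_L: activations, shape (batch, channels, height, width)
    c_e, t_e: embeddings, shape (batch, channels) (assumed layout)
    """
    # Add two trailing axes so each per-channel value broadcasts over H and W.
    return c_e[:, :, None, None] * a_L + t_e[:, :, None, None]

rng = np.random.default_rng(0)
a_L = rng.standard_normal((2, 4, 7, 7))  # illustrative shapes only
c_e = rng.standard_normal((2, 4))
t_e = rng.standard_normal((2, 4))
a_next = modulate(a_L, c_e, t_e)
```

Under this layout the context embedding acts as a per-channel scale and the timestep embedding as a per-channel shift, so the output keeps the shape of `a_L`.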

At training time, $c_e$ is randomly set to zero with probability $0.1$, so the model learns both unconditional generation (say $\psi(z_t)$ for a noisy sample $z_t$ at timestep $t$) and conditional generation (say $\psi(z_t, c)$ for context $c$). This matters because at generation time we choose a weight, $w \geq 0$, and guide the model toward the context with the following equation,

$\hat{\epsilon}_{t} = (1+w)\psi(z_t, c) - w \psi(z_t).$
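Regrouping the terms (a purely algebraic rewrite of the equation above, not an extra step in the method) makes the role of $w$ explicit:

$\hat{\epsilon}_{t} = \psi(z_t, c) + w \left( \psi(z_t, c) - \psi(z_t) \right).$

With $w = 0$ this reduces to the plain conditional prediction $\psi(z_t, c)$; larger $w$ moves the prediction further along the direction from the unconditional output to the conditional one.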

Increasing $w$ produces images that are more typical but less diverse.
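The training-time context dropout and the sampling-time guidance described above can be sketched in a few lines of NumPy. This is an illustration under stated assumptions rather than the code in `script.py`: the helper names `drop_context` and `guided_eps` are hypothetical, the embeddings are assumed to have shape (batch, features), and the two model outputs are stand-in arrays instead of real U-Net predictions.

```python
import numpy as np

def drop_context(c_e, rng, p_drop=0.1):
    """Zero each sample's context embedding with probability p_drop."""
    # One keep/drop decision per sample; True means the context is kept.
    keep = rng.random(c_e.shape[0]) >= p_drop
    return c_e * keep[:, None]

def guided_eps(eps_cond, eps_uncond, w):
    """eps_hat = (1 + w) * psi(z_t, c) - w * psi(z_t)."""
    return (1.0 + w) * eps_cond - w * eps_uncond

rng = np.random.default_rng(0)

# Training side: roughly 10% of the rows end up zeroed.
c_e = np.ones((1000, 8))
masked = drop_context(c_e, rng)
drop_rate = 1.0 - masked[:, 0].mean()

# Sampling side: stand-in values for the two model outputs.
eps_cond = np.array([1.0, 2.0])    # placeholder for psi(z_t, c)
eps_uncond = np.array([0.5, 1.0])  # placeholder for psi(z_t)
eps_w0 = guided_eps(eps_cond, eps_uncond, w=0.0)  # same as eps_cond
eps_w2 = guided_eps(eps_cond, eps_uncond, w=2.0)
```

In an actual sampler the two predictions would come from running the same network twice per step, once with the context and once with it zeroed, before combining them.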

Samples produced with varying guidance strength, $w$.

Training the models above took around 20 epochs (~20 minutes).