denoising-diffusion-mnist

Training and sampling from a denoising diffusion model on the MNIST dataset


Applying a Denoising Diffusion Model to the MNIST Dataset

[Image: 10 example images from the MNIST dataset]

[Animation: diffusion model trained on MNIST generating images, for Seed = 1 and Seed = 2]

Forward Diffusion Process

[Image: MNIST images after 200 timesteps of added Gaussian noise]

The forward process is a Markov chain that sequentially adds Gaussian noise to a sample image $x_0 \sim q(x)$ over $T$ timesteps, producing a sequence of increasingly noisy samples $x_1, \dots, x_T$. The step size is controlled by a variance schedule $\beta_t$:

$$q(x_t \mid x_{t-1}) = \mathcal{N}\big(x_t;\, \sqrt{1 - \beta_t}\, x_{t-1},\, \beta_t \mathbf{I}\big)$$
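
A useful property of this chain is the closed form $x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon$ with $\bar{\alpha}_t = \prod_{s=1}^{t} (1 - \beta_s)$, which lets us jump straight to any timestep. Below is a minimal sketch of this forward (noising) step in PyTorch; the linear schedule values and the name q_sample are illustrative assumptions, not taken from this repository:

```python
import torch

# Illustrative linear variance schedule (the repo may use different values)
T = 200
betas = torch.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)  # \bar{alpha}_t

def q_sample(x0, t, noise=None):
    """Sample x_t ~ q(x_t | x_0) in closed form, jumping straight to timestep t."""
    if noise is None:
        noise = torch.randn_like(x0)
    abar_t = alphas_cumprod[t].view(-1, 1, 1, 1)  # broadcast over (B, C, H, W)
    return abar_t.sqrt() * x0 + (1.0 - abar_t).sqrt() * noise
```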

Reverse Diffusion Process

Correctly reversing the above noising process would recreate the true sample image. However, that would require $q(x_{t-1} \mid x_t)$, which is intractable: computing it would require the entire dataset. We therefore train a model $p_\theta$ to approximate this conditional probability. Running the reverse diffusion process then looks like this:

$$p_\theta(x_{t-1} \mid x_t) = \mathcal{N}\big(x_{t-1};\, \mu_\theta(x_t, t),\, \Sigma_\theta(x_t, t)\big)$$
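
As a rough sketch of what this looks like in code (reusing T, betas, alphas, and alphas_cumprod from the forward-process sketch above, and assuming the model predicts the added noise $\epsilon_\theta(x_t, t)$, as in standard DDPM):

```python
import torch

@torch.no_grad()
def p_sample_loop(model, shape):
    """Run the reverse diffusion process: start from pure noise x_T and
    iteratively denoise down to x_0 (standard DDPM ancestral sampling)."""
    x = torch.randn(shape)  # x_T ~ N(0, I)
    for t in reversed(range(T)):
        t_batch = torch.full((shape[0],), t, dtype=torch.long)
        eps = model(x, t_batch)  # predicted noise eps_theta(x_t, t)
        # Posterior mean: mu = (x_t - beta_t / sqrt(1 - abar_t) * eps) / sqrt(alpha_t)
        mean = (x - betas[t] / (1.0 - alphas_cumprod[t]).sqrt() * eps) / alphas[t].sqrt()
        if t > 0:
            x = mean + betas[t].sqrt() * torch.randn_like(x)  # sigma_t^2 = beta_t
        else:
            x = mean  # last step: return the mean as the final sample
    return x
```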

Given a sufficiently trained model, we can then feed in random Gaussian noise to generate new images resembling the original dataset. Sample generated handwritten digits are depicted below.

[Image: generated handwritten digit samples]
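
A hypothetical usage, assuming the p_sample_loop sketch above and MNIST-shaped tensors:

```python
# Generate 10 samples of shape (1, 28, 28) from pure Gaussian noise
model.eval()
samples = p_sample_loop(model, shape=(10, 1, 28, 28))
```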

Training

To train the model yourself, run train.py with the following arguments: -e [num_epochs], -l [learning_rate], -b [batch_size]. Training for 5 epochs took around 3 hours on my machine; the corresponding checkpoint file can be found in ./models/checkpoint_5ep.pth. An example invocation is shown after the sketch below.
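
Under the hood, DDPM training typically minimizes the simplified noise-prediction loss $L_{\text{simple}} = \mathbb{E}\big[\lVert \epsilon - \epsilon_\theta(x_t, t) \rVert^2\big]$. Here is a minimal sketch of one such training step, assuming the q_sample helper from above; the repo's actual objective and loop may differ:

```python
import torch
import torch.nn.functional as F

def train_step(model, x0, optimizer):
    """One DDPM training step with the simplified noise-prediction MSE loss."""
    t = torch.randint(0, T, (x0.shape[0],))    # random timestep per image
    noise = torch.randn_like(x0)               # target noise eps
    x_t = q_sample(x0, t, noise)               # noisy input x_t ~ q(x_t | x_0)
    loss = F.mse_loss(model(x_t, t), noise)    # predict the added noise
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```

An example invocation (the flag values here are illustrative, not tuned recommendations): python train.py -e 5 -l 0.0001 -b 64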
