script.py is a minimal, self-contained implementation of a conditional diffusion model. It learns to generate MNIST digits, conditioned on a class label. The neural network architecture is a small U-Net (pretrained weights are also available in this repo). This code is modified from this excellent repo, which does unconditional generation. The diffusion model is a Denoising Diffusion Probabilistic Model (DDPM).
Samples generated from the model.
The conditioning roughly follows the method described in Classifier-Free Diffusion Guidance (also used in Imagen). The model infuses timestep embeddings
At training time,
Increasing
Samples produced with varying guidance strength,
Training the above models took around 20 epochs (~20 minutes).
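
For readers unfamiliar with classifier-free guidance, the sketch below illustrates the general idea from the Classifier-Free Diffusion Guidance paper. It is not taken from script.py. During training, labels are randomly replaced by a "null" label so one network learns both conditional and unconditional noise prediction. At sampling time the two predictions are combined with a guidance weight w. The names `drop_labels`, `guided_eps`, `toy_eps_model`, the null label index 10, and the drop probability are all hypothetical, and the toy model only stands in for a real U-Net.

```python
import numpy as np


def drop_labels(labels, null_label, p_drop, rng):
    """Training-time step: replace each label with null_label with probability p_drop."""
    mask = rng.random(labels.shape[0]) < p_drop
    return np.where(mask, null_label, labels)


def guided_eps(eps_model, x_t, t, labels, null_label, w):
    """Sampling-time step: combine conditional and unconditional noise predictions.

    Uses the form (1 + w) * eps_cond - w * eps_uncond, so w = 0 recovers
    the plain conditional prediction.
    """
    eps_cond = eps_model(x_t, t, labels)
    eps_uncond = eps_model(x_t, t, np.full_like(labels, null_label))
    return (1.0 + w) * eps_cond - w * eps_uncond


def toy_eps_model(x_t, t, labels):
    """Hypothetical stand-in for a U-Net: shifts x_t by the label value.

    The null label (assumed here to be index 10) contributes no shift.
    """
    shift = np.where(labels == 10, 0.0, labels.astype(float))
    return x_t + shift[:, None]


rng = np.random.default_rng(0)
labels = np.array([3, 7])
x_t = np.zeros((2, 4))

# Assumed drop probability; the value used by script.py is not stated here.
train_labels = drop_labels(labels, null_label=10, p_drop=0.1, rng=rng)

# Guided prediction with an illustrative guidance weight.
eps = guided_eps(toy_eps_model, x_t, t=5, labels=labels, null_label=10, w=2.0)
```

With a real model, `guided_eps` would be called once per reverse diffusion step, and larger values of w push samples further toward the conditional prediction.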