A simple PyTorch implementation of conditional denoising diffusion probabilistic models (DDPM) on MNIST, Fashion-MNIST, and Sprite datasets

Conditional DDPM

Introduction

We implement, in PyTorch, a simple conditional form of the diffusion model described in Denoising Diffusion Probabilistic Models. In preparing this repository, we were inspired by the course How Diffusion Models Work and the repository minDiffusion. For training, we use the MNIST, Fashion-MNIST, and Sprite (see FrootsnVeggies and kyrise) datasets.
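To make the underlying idea concrete, the forward (noising) process of DDPM can be sketched as below. This is a minimal, framework-free illustration and not the code in this repository: the linear schedule with 1000 steps and variances from 1e-4 to 0.02 follows the DDPM paper, and the function names (`linear_beta_schedule`, `cumulative_alphas`, `add_noise`) are hypothetical. During training, a network would be asked to predict the returned noise from the noisy image, the timestep, and (in the conditional setting) the class label.

```python
import math
import random


def linear_beta_schedule(n_steps=1000, beta_start=1e-4, beta_end=0.02):
    """Linearly spaced noise variances beta_1..beta_T (defaults assumed from the DDPM paper)."""
    step = (beta_end - beta_start) / (n_steps - 1)
    return [beta_start + i * step for i in range(n_steps)]


def cumulative_alphas(betas):
    """Return alpha_bar_t, the running product of (1 - beta_s) for s <= t."""
    alpha_bars, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        alpha_bars.append(running)
    return alpha_bars


def add_noise(x0, t, alpha_bars, rng=random):
    """Sample x_t ~ q(x_t | x_0) for a flat list of pixel values.

    x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
    Returns (x_t, noise); the noise is the regression target during training.
    """
    a = alpha_bars[t]
    noise = [rng.gauss(0.0, 1.0) for _ in x0]
    x_t = [math.sqrt(a) * x + math.sqrt(1.0 - a) * e for x, e in zip(x0, noise)]
    return x_t, noise
```

With this schedule, `alpha_bars` decreases monotonically toward zero, so an image noised at the last timestep is close to pure Gaussian noise, which is where sampling starts.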

Setting Up the Environment

  1. Install Conda, if not already installed.
  2. Clone the repository:
    git clone https://github.com/byrkbrk/diffusion-model.git
    
  3. In the diffusion-model directory, on macOS, run:
    conda env create -f diffusion-env_macos.yaml
    
    On Linux or Windows, run:
    conda env create -f diffusion-env_linux_or_windows.yaml
    
  4. Activate the environment:
    conda activate diffusion-env
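
After activating the environment, a quick check that the key packages can be imported may save a failed training run. The snippet below is only a sketch: it assumes the environment files install `torch` and `torchvision` (the exact package list is not stated here), and the helper name `missing_packages` is hypothetical.

```python
import importlib.util


def missing_packages(names=("torch", "torchvision")):
    """Return the subset of top-level package names that cannot be found.

    The default names are an assumption about what the conda environment provides.
    """
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages()
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("Environment looks ready.")
```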
    

Training and Sampling

MNIST

To train the model on the MNIST dataset from scratch, run:

python3 train.py --dataset-name mnist

To sample from our pretrained checkpoint, run:

python3 sample.py pretrained_mnist_checkpoint_49.pth --n-samples 400 --n-images-per-row 20

Results (JPEG and GIF files) are saved to the generated-images directory and are shown below, where every two rows represent one class label (20 rows and 10 classes in total).
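
The row-to-class layout follows from simple arithmetic on the sampling arguments. The sketch below illustrates one way labels could be assigned so that contiguous rows share a class; it is an assumption about the layout, not the repository's sampling code, and the names `grid_labels` and `rows_per_class` are hypothetical.

```python
def grid_labels(n_samples, n_classes):
    """Assign a class label to each sample so equal-sized contiguous blocks share a label."""
    if n_samples % n_classes != 0:
        raise ValueError("n_samples must be divisible by n_classes")
    per_class = n_samples // n_classes
    return [i // per_class for i in range(n_samples)]


def rows_per_class(n_samples, n_images_per_row, n_classes):
    """Number of grid rows occupied by each class when samples are laid out row by row."""
    if n_samples % n_images_per_row != 0:
        raise ValueError("n_samples must be divisible by n_images_per_row")
    n_rows = n_samples // n_images_per_row
    if n_rows % n_classes != 0:
        raise ValueError("row count must be divisible by n_classes")
    return n_rows // n_classes
```

For example, `rows_per_class(400, 20, 10)` gives 2 for the settings above, and `rows_per_class(225, 15, 5)` gives 3 for a 15-row grid with 5 classes.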

Fashion-MNIST

To train the model from scratch on the Fashion-MNIST dataset, run:

python3 train.py --dataset-name fashion_mnist

To sample from our pretrained checkpoint, run:

python3 sample.py pretrained_fashion_mnist_checkpoint_49.pth --n-samples 400 --n-images-per-row 20

Results (JPEG and GIF files) are saved to the generated-images directory and are shown below, where every two rows represent one class label (20 rows and 10 classes in total).

Sprite

To train the model from scratch on the Sprite dataset, run:

python3 train.py --dataset-name sprite

To sample from our pretrained checkpoint, run:

python3 sample.py pretrained_sprite_checkpoint_49.pth --n-samples 225 --n-images-per-row 15

Results (JPEG and GIF files) are saved to the generated-images directory and are shown below, where every three rows represent one class label (15 rows and 5 classes in total).