
Conditional Diffusion Models

This is a simple implementation based on the repository https://github.com/dome272/Diffusion-Models-pytorch. The diffusion model follows the DDPM paper, and the conditioning idea is taken from Classifier-Free Diffusion Guidance.
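The guidance step itself is small: at each denoising step the model is queried twice, once with the class labels and once without, and the two noise predictions are blended. Below is a minimal sketch of that blend; the function name and the assumption that the model accepts `labels=None` for the unconditional branch are illustrative and may differ from the actual code.

```python
import torch

def guided_noise(model, x_t, t, labels, guidance_scale=3.0):
    """Classifier-free guidance: blend conditional and unconditional predictions."""
    eps_cond = model(x_t, t, labels)    # noise prediction conditioned on class labels
    eps_uncond = model(x_t, t, None)    # unconditional prediction (labels dropped)
    # eps = eps_uncond + w * (eps_cond - eps_uncond)
    return torch.lerp(eps_uncond, eps_cond, guidance_scale)
```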


Train the Model on Fashion-MNIST:

  1. Configure the hyperparameters in main.py
  2. Set the dataset usage in utils.py (a data-loading sketch is shown after this list)
  3. Run python main.py
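For step 2, a Fashion-MNIST loader built with torchvision looks roughly like the sketch below. The function name `get_data` and the exact transforms are illustrative; check utils.py for the actual implementation.

```python
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader

def get_data(image_size=32, batch_size=8):
    transforms = T.Compose([
        T.Resize(image_size),                # Fashion-MNIST images are 28x28
        T.ToTensor(),
        T.Normalize((0.5,), (0.5,)),         # scale grayscale pixels to [-1, 1]
    ])
    dataset = torchvision.datasets.FashionMNIST(
        root="./data", train=True, download=True, transform=transforms
    )
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)
```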

Sampling

The generate.py file shows how to sample images using the model's saved checkpoints in "models/DDPM_conditional".

python generate.py
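In rough terms, generate.py loads a checkpoint and samples class-conditional images. The sketch below follows the class and function names of the base repository (UNet_conditional, Diffusion, diffusion.sample) and assumes a checkpoint file named ckpt.pt; the names in this codebase may differ.

```python
import torch
from modules import UNet_conditional          # model definition from the base repository
from ddpm_conditional import Diffusion        # sampler from the base repository

device = "cpu"
model = UNet_conditional(num_classes=10).to(device)
ckpt = torch.load("models/DDPM_conditional/ckpt.pt", map_location=device)
model.load_state_dict(ckpt)

diffusion = Diffusion(img_size=32, device=device)
labels = torch.arange(10).long().to(device)   # one sample per Fashion-MNIST class
images = diffusion.sample(model, n=len(labels), labels=labels, cfg_scale=3)
```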

Result

I trained the model on a CPU for only 6 epochs and got the following results for 2 generated samples:

(figure: 2 generated samples)

Compared to the target image:

(figure: target Fashion-MNIST images)

The results can definitely be improved with longer training and more tuning.

Evaluation

To quantitatively evaluate the generated results, metrics such as FID and the CLIP score can be used. Due to time constraints, these metrics are not implemented yet; a rough FID sketch is outlined below.
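As one possible starting point, FID can be computed with the torchmetrics package (not part of this repo). The sketch assumes real and generated images are provided as uint8 tensors; names and shapes are illustrative.

```python
import torch
from torchmetrics.image.fid import FrechetInceptionDistance

def compute_fid(real_images, fake_images):
    """real_images / fake_images: uint8 tensors of shape (N, 1, H, W) in [0, 255]."""
    fid = FrechetInceptionDistance(feature=2048)
    # The Inception network expects 3-channel inputs; repeat the grayscale channel.
    fid.update(real_images.repeat(1, 3, 1, 1), real=True)
    fid.update(fake_images.repeat(1, 3, 1, 1), real=False)
    return fid.compute()
```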