
Jax/Flax implementation of Denoising Diffusion Implicit Models

Primary language: Jupyter Notebook. License: Apache License 2.0 (Apache-2.0).

Jax DDIM

This DDIM implementation follows the Keras example of Denoising Diffusion Implicit Models.
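For orientation, the sampling loop of the Keras DDIM example can be sketched in JAX as below. This is a minimal sketch that assumes this repository keeps the Keras example's offset cosine schedule (signal rate between 0.02 and 0.95) and its deterministic update; predict_noise is a hypothetical stand-in for the Flax U-Net's apply function, and the other function names are illustrative, not this repo's API.

```python
import jax.numpy as jnp

# Signal-rate bounds from the Keras DDIM example (assumed to match this repo).
MIN_SIGNAL_RATE = 0.02
MAX_SIGNAL_RATE = 0.95


def diffusion_schedule(diffusion_times):
    """Offset cosine schedule: map t in [0, 1] to (noise_rates, signal_rates).

    Because the rates are sin/cos of one angle, noise**2 + signal**2 == 1.
    """
    start_angle = jnp.arccos(MAX_SIGNAL_RATE)
    end_angle = jnp.arccos(MIN_SIGNAL_RATE)
    angles = start_angle + diffusion_times * (end_angle - start_angle)
    return jnp.sin(angles), jnp.cos(angles)


def denoise(predict_noise, noisy_images, noise_rates, signal_rates):
    """Predict the noise, then solve x_t = s * x0 + n * eps for x0."""
    # predict_noise(noisy_images, noise_variances) is a hypothetical model call.
    pred_noises = predict_noise(noisy_images, noise_rates**2)
    pred_images = (noisy_images - noise_rates * pred_noises) / signal_rates
    return pred_noises, pred_images


def ddim_sample(predict_noise, initial_noise, diffusion_steps):
    """Deterministic DDIM reverse process from t = 1 down towards t = 0."""
    num_images = initial_noise.shape[0]
    step_size = 1.0 / diffusion_steps
    next_noisy_images = initial_noise
    pred_images = initial_noise
    for step in range(diffusion_steps):
        noisy_images = next_noisy_images
        # One diffusion time per image, broadcast over (H, W, C).
        diffusion_times = jnp.ones((num_images, 1, 1, 1)) - step * step_size
        noise_rates, signal_rates = diffusion_schedule(diffusion_times)
        pred_noises, pred_images = denoise(
            predict_noise, noisy_images, noise_rates, signal_rates
        )
        # Re-noise the clean estimate to the next, smaller time using the
        # predicted noise; no fresh randomness is injected.
        next_noise_rates, next_signal_rates = diffusion_schedule(
            diffusion_times - step_size
        )
        next_noisy_images = (
            next_signal_rates * pred_images + next_noise_rates * pred_noises
        )
    return pred_images
```

Each step predicts the noise, recovers a clean-image estimate, and moves it to the next diffusion time with that same predicted noise, so the result is fully determined by the initial noise.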

Setup

Main dependencies

  • jax==0.3.14
  • flax==0.5.2
  • tensorflow==2.9.1
  • tensorflow-datasets==4.6.0
  • tensorboard==2.9.1

For example, I recommend GCP Vertex Workbench (a managed JupyterLab environment) with a GPU accelerator, which offers a GPU environment and popular deep learning libraries.
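Before training, it can help to confirm that JAX actually sees the accelerator, on Vertex Workbench or elsewhere. The check below only uses jax.default_backend() and jax.devices(); the helper name describe_accelerator is hypothetical.

```python
import jax


def describe_accelerator():
    """Return JAX's default backend name and the devices it can see (hypothetical helper)."""
    backend = jax.default_backend()  # e.g. "cpu", "gpu" or "tpu"
    devices = jax.devices()
    return backend, devices


backend, devices = describe_accelerator()
print(f"JAX {jax.__version__} backend: {backend}; devices: {devices}")
if backend == "cpu":
    print("Warning: no accelerator visible to JAX; training will run on CPU and be slow.")
```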

Run experiment

Run train.py (or the train.ipynb notebook). The trained model and TensorBoard logs are saved under the outputs directory by default.

According to the Keras example, training for at least 50 epochs is recommended for good results.

python train.py \
--epoch 50 \
<other arguments ...>
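The objective a training step optimizes can be sketched as follows, again assuming the repo mirrors the Keras example: noise each image at a random diffusion time and regress the added noise with a mean absolute error. predict_noise is a hypothetical stand-in for the Flax model's apply function, and the sketch omits data normalization and any weight averaging.

```python
import jax
import jax.numpy as jnp

# Signal-rate bounds from the Keras DDIM example (assumed to match this repo).
MIN_SIGNAL_RATE = 0.02
MAX_SIGNAL_RATE = 0.95


def diffusion_schedule(diffusion_times):
    """Offset cosine schedule returning (noise_rates, signal_rates)."""
    start_angle = jnp.arccos(MAX_SIGNAL_RATE)
    end_angle = jnp.arccos(MIN_SIGNAL_RATE)
    angles = start_angle + diffusion_times * (end_angle - start_angle)
    return jnp.sin(angles), jnp.cos(angles)


def noise_prediction_loss(predict_noise, images, key):
    """Mean absolute error between the true and predicted noise.

    predict_noise(noisy_images, noise_variances) is a hypothetical stand-in
    for the Flax network's apply function.
    """
    noise_key, time_key = jax.random.split(key)
    noises = jax.random.normal(noise_key, images.shape)
    # One random diffusion time per image, broadcast over (H, W, C).
    diffusion_times = jax.random.uniform(time_key, (images.shape[0], 1, 1, 1))
    noise_rates, signal_rates = diffusion_schedule(diffusion_times)
    # Mix signal and noise; the network must recover the noise component.
    noisy_images = signal_rates * images + noise_rates * noises
    pred_noises = predict_noise(noisy_images, noise_rates**2)
    return jnp.mean(jnp.abs(noises - pred_noises))
```

In a full training loop this loss would typically be wrapped in jax.value_and_grad with respect to the model parameters and paired with an optimizer (for example from optax); how train.py does this is not shown here.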

Results

Training loss and generated images for 50 epochs:

[Figure: training losses]

[Figure: generated images]

Notes

This implementation follows the Keras example implementation; see that example for detailed tips and discussion.