
k-diffusion

An implementation of Elucidating the Design Space of Diffusion-Based Generative Models (Karras et al., 2022) for PyTorch. The patching method in Improving Diffusion Model Efficiency Through Patching is implemented as well.

Training:

To train models:

$ ./train.py --config CONFIG_FILE --name RUN_NAME

For instance, to train a model on MNIST:

$ ./train.py --config configs/config_mnist.json --name RUN_NAME

The configuration file allows you to specify the dataset type. Currently supported types are "imagefolder" (a folder with one subfolder per image class; the class labels are currently ignored), "cifar10" (CIFAR-10), and "mnist" (MNIST).
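
For illustration, the dataset section of a config might look like the sketch below. The key names are assumptions based on the bundled configs; treat configs/config_mnist.json as the authoritative reference.

"dataset": {
    "type": "imagefolder",
    "location": "/path/to/dataset"
}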

Multi-GPU and multi-node training is supported with Hugging Face Accelerate. You can configure Accelerate by running:

$ accelerate config

on all nodes, then running:

$ accelerate launch train.py --config CONFIG_FILE --name RUN_NAME

on all nodes.

Enhancements/additional features:

  • k-diffusion models support progressive growing.

  • k-diffusion implements a sampler inspired by DPM-Solver and Karras et al. (2022) Algorithm 2 that produces higher quality samples at the same number of function evaluations as Karras Algorithm 2. It also implements a linear multistep sampler (comparable to PLMS). See the sampling sketch after this list.

  • k-diffusion supports CLIP guided sampling from unconditional diffusion models (see sample_clip_guided.py).

  • k-diffusion has wrappers for v-diffusion-pytorch, OpenAI diffusion, and CompVis diffusion models, allowing them to be used with its samplers and ODE/SDE-based routines (see the wrapper sketch after this list).

  • k-diffusion supports log likelihood calculation (not a variational lower bound) for native models and all wrapped models (see the sketch after this list).
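
As a minimal sketch of how the samplers are invoked (get_sigmas_karras and sample_lms are from k_diffusion.sampling; the trained model is replaced by a trivial stand-in here, so check exact signatures against the source):

import torch

from k_diffusion import sampling


def denoiser(x, sigma):
    # Stand-in for a trained model: the samplers expect a callable that takes
    # a noised batch and a per-sample sigma tensor and returns the denoised
    # batch.
    return torch.zeros_like(x)


# Karras et al. (2022) noise schedule: 50 sigmas from sigma_max down to 0.
sigmas = sampling.get_sigmas_karras(50, sigma_min=1e-2, sigma_max=80., device='cpu')

# Sampling starts from pure noise at the highest noise level.
x = torch.randn([4, 3, 32, 32]) * sigmas[0]
samples = sampling.sample_lms(denoiser, x, sigmas)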
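
The wrappers can be sketched as follows, assuming a v-diffusion-pytorch model has already been loaded as inner_model (class names per k_diffusion/external.py; verify against the installed version):

from k_diffusion import external, sampling

# VDenoiser adapts a v-objective model to the denoiser interface the samplers
# expect; OpenAIDenoiser and CompVisDenoiser play the same role for the other
# two codebases.
model = external.VDenoiser(inner_model)

# The wrapped model drops into any of the samplers, e.g.:
# samples = sampling.sample_heun(model, x, sigmas)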
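
Finally, a sketch of the log likelihood computation, reusing the denoiser stand-in from the sampling sketch (the log_likelihood function lives in k_diffusion.sampling; its signature and return values are assumptions to verify against the source):

import torch

from k_diffusion import sampling

# Per-sample log likelihoods come from integrating the probability flow ODE
# between sigma_min and sigma_max; x is a batch of data in the model's input
# space.
x = torch.rand([4, 3, 32, 32]) * 2 - 1
ll, info = sampling.log_likelihood(denoiser, x, sigma_min=1e-2, sigma_max=80.)
print(ll.mean().item())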

To do:

  • Anything except unconditional image diffusion models

  • Latent diffusion