
A JAX implementation of the continuous-time formulation of Consistency Models

Primary language: Python. License: MIT.

consistency-models

consistency-models is a JAX implementation of the continuous-time formulation of Consistency Models, which makes it possible to distill a diffusion model into a single-step generative model.
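For background, the Consistency Models paper parameterizes the consistency function as f(x, σ) = c_skip(σ) x + c_out(σ) F(x, σ). The scalings are chosen so that f(x, σ_min) = x, and a one-step sample is f(σ_max z, σ_max) for Gaussian noise z. The sketch below writes this out in JAX. The constants (σ_data = 0.5, σ_min = 0.002, σ_max = 80) are the paper's defaults. It is an assumption, not something this README states, that this repository uses the same scalings, and toy_net is a hypothetical placeholder for the trained network.

```python
import jax
import jax.numpy as jnp

SIGMA_DATA = 0.5   # assumed data standard deviation (paper default)
SIGMA_MIN = 0.002  # assumed smallest noise level (epsilon in the paper)
SIGMA_MAX = 80.0   # assumed largest noise level (T in the paper)


def c_skip(sigma):
    # Equals 1 at sigma == SIGMA_MIN, so f(x, SIGMA_MIN) = x holds exactly.
    return SIGMA_DATA**2 / ((sigma - SIGMA_MIN) ** 2 + SIGMA_DATA**2)


def c_out(sigma):
    # Equals 0 at sigma == SIGMA_MIN, switching off the network output there.
    return SIGMA_DATA * (sigma - SIGMA_MIN) / jnp.sqrt(SIGMA_DATA**2 + sigma**2)


def consistency_fn(net, x, sigma):
    """f(x, sigma) = c_skip(sigma) * x + c_out(sigma) * F(x, sigma)."""
    return c_skip(sigma) * x + c_out(sigma) * net(x, sigma)


def sample_one_step(net, key, shape):
    # One-step generation: map pure noise at SIGMA_MAX directly to a sample.
    z = jax.random.normal(key, shape)
    return consistency_fn(net, z * SIGMA_MAX, SIGMA_MAX)


def toy_net(x, sigma):
    # Hypothetical stand-in for the trained network F (a real model is a neural net).
    return jnp.tanh(x / jnp.sqrt(sigma**2 + SIGMA_DATA**2))


samples = sample_one_step(toy_net, jax.random.PRNGKey(0), (2, 3))
```

Building the boundary condition into the parameterization, rather than learning it, means a single network evaluation at σ_max is enough for generation.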

This code is a WORK IN PROGRESS: it is not finished and does not yet produce high-quality results. I am releasing it because of general interest in consistency model implementations.

Requirements

pip install git+https://github.com/crowsonkb/jax-wavelets
pip install -r requirements.txt

Notes

train.py trains a diffusion model and a consistency model at the same time, using the consistency distillation loss L_CD to continuously distill the EMA (exponential moving average) copy of the diffusion model into the consistency model. The consistency model is then used to generate samples in a single step. This seems to work better than training the consistency model directly with the consistency training loss L_CT.
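In the continuous-time limit of L_CD with a squared L2 metric, the gap between the consistency model's outputs at adjacent noise levels becomes the total derivative of f along the teacher's probability-flow ODE. JAX can compute this in one forward pass with jax.jvp. The sketch below is a minimal illustration of that idea, not the code in train.py. It assumes the EDM-style ODE dx/dσ = (x − D(x, σ))/σ, a unit loss weighting, and a stop-gradient target. teacher_denoiser and the linear toy model are hypothetical stand-ins for the EMA diffusion model and the consistency network.

```python
import jax
import jax.numpy as jnp

SIGMA_DATA = 0.5   # assumed data standard deviation
SIGMA_MIN = 0.002  # assumed smallest noise level


def consistency_fn(params, x, sigma):
    # Hypothetical toy consistency model: a linear "network" F(x) = w * x + b,
    # wrapped in the skip/output scalings so that f(x, SIGMA_MIN) = x.
    c_skip = SIGMA_DATA**2 / ((sigma - SIGMA_MIN) ** 2 + SIGMA_DATA**2)
    c_out = SIGMA_DATA * (sigma - SIGMA_MIN) / jnp.sqrt(SIGMA_DATA**2 + sigma**2)
    return c_skip * x + c_out * (params["w"] * x + params["b"])


def teacher_denoiser(x, sigma):
    # Hypothetical stand-in for the EMA diffusion model D(x, sigma): the exact
    # denoiser for zero-mean Gaussian data with standard deviation SIGMA_DATA.
    return x * SIGMA_DATA**2 / (sigma**2 + SIGMA_DATA**2)


def cd_pseudo_loss(params, x, sigma):
    # Teacher's probability-flow ODE velocity, assuming sigma(t) = t.
    dx_dsigma = (x - teacher_denoiser(x, sigma)) / sigma
    # Total derivative of f along that trajectory, in a single forward pass:
    # (Jacobian of f w.r.t. x) @ dx/dsigma + (partial derivative of f w.r.t. sigma).
    f, df_dsigma = jax.jvp(
        lambda x_, s_: consistency_fn(params, x_, s_),
        (x, sigma),
        (dx_dsigma, jnp.ones_like(sigma)),
    )
    # With a stop-gradient on the tangent, the gradient of this surrogate w.r.t.
    # params is proportional to the leading-order gradient of a squared-L2 L_CD.
    return jnp.mean(f * jax.lax.stop_gradient(df_dsigma))


# Usage with toy data and parameters.
key = jax.random.PRNGKey(0)
params = {"w": jnp.array(0.5), "b": jnp.array(0.1)}
x = jax.random.normal(key, (8,))
sigma = jnp.array(1.0)
loss, grads = jax.value_and_grad(cd_pseudo_loss)(params, x, sigma)
```

Gradient descent on this surrogate pulls the output at noise level σ toward the output at a slightly lower noise level on the same ODE trajectory, which is what L_CD does in discrete time. Computing the tangent with jax.jvp avoids having to choose a discretization step.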