
A Toolkit for OpenAI's Consistency Models.


Consistency Models 🌃

Single-step image generation with Consistency Models.



Consistency Models are a new family of generative models that achieve high sample quality without adversarial training. They support fast one-step generation by design, while still allowing for few-step sampling to trade compute for sample quality.
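For context, the core property behind single-step generation (from the paper cited under Reference): a consistency model f_θ maps any noisy point on a probability-flow ODE trajectory back to that trajectory's origin, so a single model evaluation already yields a clean sample. A minimal sketch in LaTeX notation:

% self-consistency: outputs agree for all points on the same trajectory
f_\theta(x_t, t) = f_\theta(x_{t'}, t'), \quad t, t' \in [\epsilon, T]
% boundary condition: identity at the smallest noise level
f_\theta(x_\epsilon, \epsilon) = x_\epsilon
% one-step generation: map pure noise back in a single evaluation
x \approx f_\theta(x_T, T), \qquad x_T \sim \mathcal{N}(0, T^2 I)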


Installation

$ pip install consistency

Note

You don't need to install consistency just to try things out:

from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "consistency/cifar10-32-demo",
    custom_pipeline="consistency/pipeline",
)

pipeline().images[0]  # Super Fast Generation! 🤯

Try it yourself!
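The returned images should be standard PIL objects (as with other diffusers image pipelines), so you can save or display them directly. A minimal sketch, reusing the pipeline from above:

image = pipeline().images[0]  # a PIL.Image.Image, assuming the usual diffusers output
image.save("sample.png")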


Quickstart

Basic

Just wrap your favorite U-Net with Consistency.

import torch
from diffusers import UNet2DModel
from consistency import Consistency
from consistency.loss import PerceptualLoss

consistency = Consistency(
    model=UNet2DModel(sample_size=224),
    loss_fn=PerceptualLoss(net_type=("vgg", "squeeze"))
)

# single-step sampling
samples = consistency.sample(16)

# multi-step sampling from the EMA model (trade compute for quality)
samples = consistency.sample(16, steps=5, use_ema=True)

Consistency is self-contained: it bundles the training logic and all the necessary schedules.

You can train it with PyTorch Lightning's Trainer 🚀

from pytorch_lightning import Trainer

trainer = Trainer(max_epochs=8000, accelerator="auto")
trainer.fit(consistency, some_dataloader)
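For completeness, here is a minimal sketch of what some_dataloader could look like, using torchvision's CIFAR-10 purely as an illustration; the exact batch format Consistency expects may differ, so check the complete example linked below.

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# resize images to match the U-Net's sample_size and scale pixels to [-1, 1]
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

# CIFAR-10 yields (image, label) pairs; used here only as a stand-in dataset
dataset = datasets.CIFAR10("data", train=True, download=True, transform=transform)
some_dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)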

Push to HF Hub

Provide your model_id and token to Consistency.

consistency = Consistency(
    model=UNet2DModel(sample_size=224),
    loss_fn=PerceptualLoss(net_type=("vgg", "squeeze")),
    model_id="your_model_id",
    token="your_token",  # Not needed if logged in via huggingface-cli
    push_every_n_steps=10000,
)

Afterwards, you can safely drop consistency: the pushed model can be loaded with diffusers alone. Good luck! 🤞

from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "your_name/your_model_id",
    custom_pipeline="consistency/pipeline",
)

pipeline().images[0]

A complete example can be found here or in this Colab notebook.

Check out this Wandb workspace for some experiment results.


Available Models

model_id                        dataset
consistency/cifar10-32-demo    cifar10

If you've trained some checkpoints using consistency, share them with us! 🤗


Documentation

In progress... 🛠


Reference

@misc{https://doi.org/10.48550/arxiv.2303.01469,
  doi       = {10.48550/ARXIV.2303.01469},
  url       = {https://arxiv.org/abs/2303.01469},
  author    = {Song, Yang and Dhariwal, Prafulla and Chen, Mark and Sutskever, Ilya},
  keywords  = {Machine Learning (cs.LG), Computer Vision and Pattern Recognition (cs.CV), Machine Learning (stat.ML), FOS: Computer and information sciences},
  title     = {Consistency Models},
  publisher = {arXiv},
  year      = {2023},
  copyright = {arXiv.org perpetual, non-exclusive license}
}