[CVPR24] CoDi: Conditional Diffusion Distillation for Higher-Fidelity and Faster Image Generation

CoDi: Conditional Diffusion Distillation (CVPR24)

CoDi: Conditional Diffusion Distillation for Higher-Fidelity and Faster Image Generation
Kangfu Mei 1, 2, Mauricio Delbracio 2, Hossein Talebi 2, Zhengzhong Tu 2, Vishal M. Patel 1, Peyman Milanfar 2
1Johns Hopkins University
2Google Research

[Paper] [Project Page]

Disclaimer: This is not an official Google product. This repository contains an unofficial implementation. Please refer to the official implementation at https://github.com/google-research/google-research/tree/master/CoDi.

Disclaimer: All models in this repository were trained using publicly available data.

Introduction

CoDi can efficiently distill the sampling steps of a conditional diffusion model from an unconditional one (e.g., Stable Diffusion), enabling rapid generation of high-quality images (i.e., in 1-4 steps) under various conditional settings (e.g., inpainting, InstructPix2Pix, etc.).

teaser

On the standard real-world image super-resolution benchmark, we show that CoDi can achieve the performance of 50-step sampling, in terms of FID and LPIPS, with only 4 steps. It substantially outperforms previous guided-distillation and consistency-model methods. On less challenging tasks such as text-guided inpainting, we show that a new parameter-efficient distillation, first proposed by us, can even beat the original 50-step sampling in the FID and LPIPS metrics.

performance

News

  • Feb-27-2024 We release the canny-image-to-image checkpoint and demo with parameter-efficient CoDi. 🏁

  • Feb-26-2024 CoDi is accepted by CVPR24 🏁

  • Feb-22-2024 We release the checkpoint and demo for parameter-efficient CoDi. 🏁

  • Dec-02-2023 We release the training script of CoDi. 🏁

Detailed Contents

  1. Training CoDi on HuggingFace Data
  2. Training CoDi on Your Own Data
  3. Testing CoDi on Canny Images
  4. Citations
  5. Acknowledgement

Note: The following instructions are modified from https://github.com/huggingface/community-events/blob/main/jax-controlnet-sprint/README.md

Training CoDi on HuggingFace Data

All you need to do is set DATASET_NAME to the HuggingFace Hub dataset you want to train on (it can be your own, possibly private, dataset). The datasets listed at https://huggingface.co/spaces/jax-diffusers-event/leaderboard are a good place to start.

export HF_HOME="/data/kmei1/huggingface/"
export DISK_DIR="/data/kmei1/huggingface/cache"
export MODEL_DIR="stabilityai/stable-diffusion-2-1"
export OUTPUT_DIR="canny_model"
export DATASET_NAME="jax-diffusers-event/canny_diffusiondb"
export NCCL_P2P_DISABLE=1
export CUDA_VISIBLE_DEVICES=5
# export XLA_FLAGS="--xla_force_host_platform_device_count=4 --xla_dump_to=/tmp/foo"

python3 train_codi_flax.py \
 --pretrained_model_name_or_path $MODEL_DIR \
 --output_dir $OUTPUT_DIR \
 --dataset_name $DATASET_NAME \
 --load_from_disk \
 --cache_dir $DISK_DIR \
 --resolution 512 \
 --learning_rate 8e-6 \
 --train_batch_size 2 \
 --gradient_accumulation_steps 2 \
 --revision main \
 --from_pt \
 --mixed_precision bf16 \
 --max_train_steps 200_000 \
 --checkpointing_steps 10_000 \
 --validation_steps 100 \
 --dataloader_num_workers 8 \
 --distill_learning_steps 20 \
 --ema_decay 0.99995 \
 --onestepode uncontrol \
 --onestepode_control_params target \
 --onestepode_sample_eps vprediction \
 --cfg_aware_distill \
 --distill_loss consistency_x \
 --distill_type conditional \
 --image_column original_image \
 --caption_column prompt \
 --conditioning_image transformed_image \
 --report_to wandb \
 --validation_image "figs/control_bird_canny.png" \
 --validation_prompt "birds"
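
To get a feel for the --ema_decay value in the command above, the following sketch computes how many updates it takes for an old weight's contribution to an exponential moving average to halve. It assumes the textbook EMA rule (ema = decay * ema + (1 - decay) * new); whether train_codi_flax.py applies exactly this rule is an assumption, and both function names are hypothetical.

```python
import math


def ema_update(ema_value, new_value, decay):
    """One step of the textbook EMA rule (assumed, not taken from the training script)."""
    return decay * ema_value + (1.0 - decay) * new_value


def ema_half_life(decay):
    """Number of updates after which an old value's weight in the average drops to 1/2."""
    return math.log(0.5) / math.log(decay)


half_life = ema_half_life(0.99995)
print(round(half_life))                  # roughly 13863 updates
print(round(half_life / 200_000, 3))     # fraction of the 200,000 training steps
```

Under this assumption, the averaged weights mostly reflect the last few tens of thousands of updates, which is a small part of the 200,000-step run.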

Note that you may need to change --image_column, --caption_column, and --conditioning_image to match your selected dataset. For example, the values used above are the ones required for the jax-diffusers-event/canny_diffusiondb dataset, as described at https://huggingface.co/datasets/jax-diffusers-event/canny_diffusiondb.
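
As a minimal sketch of adapting these three flags to another dataset, the helper below checks that the chosen column names exist and then builds the matching command-line arguments. The function column_flags is hypothetical and not part of this repository; only the flag names and the canny_diffusiondb column names come from the command above.

```python
# Hypothetical helper: the flag names come from the training command, but this
# function is an illustration and does not exist in train_codi_flax.py.
FLAG_ORDER = ("image_column", "caption_column", "conditioning_image")


def column_flags(available_columns, image_column, caption_column, conditioning_image):
    """Return CLI arguments for the three column flags, or raise if a column is missing."""
    chosen = {
        "image_column": image_column,
        "caption_column": caption_column,
        "conditioning_image": conditioning_image,
    }
    missing = [name for name in chosen.values() if name not in available_columns]
    if missing:
        raise ValueError(f"columns not found in dataset: {missing}")
    args = []
    for flag in FLAG_ORDER:
        args += [f"--{flag}", chosen[flag]]
    return args


# The three columns of jax-diffusers-event/canny_diffusiondb used in the command above.
# For another dataset, the names could be read with the `datasets` library, e.g.
# (assuming the dataset has a "train" split):
#   from datasets import load_dataset
#   columns = next(iter(load_dataset(name, split="train", streaming=True))).keys()
canny_columns = ["original_image", "prompt", "transformed_image"]
print(" ".join(column_flags(canny_columns, "original_image", "prompt", "transformed_image")))
```

Checking the names before launching avoids finding out about a misspelled column only after the model and dataset have been loaded.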

Training CoDi on Your Own Data

Data preprocessing

Here we demonstrate how to prepare a large dataset to train a ControlNet model that generates images conditioned on an image representation that only has edge information (using Canny edge detection).

More specifically, we use an example script defined in https://github.com/huggingface/community-events/blob/main/jax-controlnet-sprint/dataset_tools/coyo_1m_dataset_preprocess.py:

  • Selects 1 million image-text pairs from the existing COYO-700M dataset, downloads each image, and uses the Canny edge detector to generate the conditioning image. It then creates a metafile that links all the images and processed images to their text captions.

  • Use the following command to run the example data preprocessing script. If you've mounted a disk to your TPU, you should place your train_data_dir and cache_dir on the mounted disk.

python3 coyo_1m_dataset_preprocess.py \
 --train_data_dir="/data/dataset" \
 --cache_dir="/data" \
 --max_train_samples=1000000 \
 --num_proc=32
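
The metafile step described above can be sketched roughly as follows, using only the standard library. The helper names and the JSON keys "image", "conditioning_image", and "caption" are assumptions made for illustration; check the preprocessing script for the exact schema it writes.

```python
import json
import tempfile
from pathlib import Path


def write_meta(data_dir, records):
    """Write one JSON object per line linking each image to its edge map and caption.

    `records` is an iterable of (image_name, caption) pairs. The key names below
    are assumed, not confirmed by coyo_1m_dataset_preprocess.py.
    """
    meta_path = Path(data_dir) / "meta.jsonl"
    with meta_path.open("w", encoding="utf-8") as f:
        for image_name, caption in records:
            entry = {
                "image": f"images/{image_name}",
                "conditioning_image": f"processed_images/{image_name}",
                "caption": caption,
            }
            f.write(json.dumps(entry, ensure_ascii=False) + "\n")
    return meta_path


def read_meta(meta_path):
    """Read the metafile back into a list of dictionaries, skipping blank lines."""
    with Path(meta_path).open(encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


# Usage on a throwaway directory with two made-up entries.
with tempfile.TemporaryDirectory() as tmp:
    path = write_meta(tmp, [("image_1.png", "a bird"), ("image_2.png", "two birds")])
    for row in read_meta(path):
        print(row["image"], "->", row["conditioning_image"], "|", row["caption"])
```

A line-per-record file like this can be appended to and streamed without loading a million entries into memory at once.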

Once the script finishes running, you can find a data folder at the specified train_data_dir with the below folder structure:

data
├── images
│   ├── image_1.png
│   ├── .......
│   └── image_1000000.jpeg
├── processed_images
│   ├── image_1.png
│   ├── .......
│   └── image_1000000.jpeg
└── meta.jsonl
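
Before starting a long training run, it may be worth confirming that the folder matches the layout above. The sketch below is a hypothetical check (check_layout is not part of this repository) and assumes that each file in images has a counterpart with the same file name in processed_images, as the tree suggests.

```python
from pathlib import Path


def check_layout(data_dir):
    """Return the file names present in images/ but absent from processed_images/.

    Raises FileNotFoundError if either folder or meta.jsonl is missing.
    """
    data_dir = Path(data_dir)
    for folder in ("images", "processed_images"):
        if not (data_dir / folder).is_dir():
            raise FileNotFoundError(f"missing folder: {data_dir / folder}")
    meta_path = data_dir / "meta.jsonl"
    if not meta_path.is_file():
        raise FileNotFoundError(f"missing file: {meta_path}")
    images = {p.name for p in (data_dir / "images").iterdir() if p.is_file()}
    processed = {p.name for p in (data_dir / "processed_images").iterdir() if p.is_file()}
    return sorted(images - processed)


if __name__ == "__main__":
    # "/data/dataset" is the train_data_dir used in the commands in this README;
    # depending on where the data folder was written, the path may need adjusting.
    unpaired = check_layout("/data/dataset")
    print(f"{len(unpaired)} images without a processed counterpart")
```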

Training

All you need to do is update DATASET_DIR with the correct path to your data folder.

Here is an example that runs the training script and loads the dataset from disk:

export HF_HOME="/data/huggingface/"
export DISK_DIR="/data/huggingface/cache"
export MODEL_DIR="runwayml/stable-diffusion-v1-5"
export OUTPUT_DIR="/data/canny_model"
export DATASET_DIR="/data/dataset"

python3 train_codi_flax.py \
 --pretrained_model_name_or_path=$MODEL_DIR \
 --output_dir=$OUTPUT_DIR \
 --train_data_dir=$DATASET_DIR \
 --load_from_disk \
 --cache_dir=$DISK_DIR \
 --resolution=512 \
 --learning_rate=1e-5 \
 --train_batch_size=2 \
 --revision="non-ema" \
 --from_pt \
 --max_train_steps=500000 \
 --checkpointing_steps=10000 \
 --dataloader_num_workers=16 \
 --distill_learning_steps 50 \
 --distill_timestep_scaling 10 \
 --onestepode control \
 --onestepode_control_params target \
 --onestepode_sample_eps v_prediction \
 --distill_loss consistency_x \

Testing CoDi on Canny Images

Prompt: birds
Canny Image Ours w. 4-step sampling

We provide a pretrained canny-edge-to-image model following the ControlNet experiments at https://huggingface.co/lllyasviel/sd-controlnet-canny. Note that we use open-source data, i.e., jax-diffusers-event/canny_diffusiondb, so there are differences in style between ControlNet's results and ours.

export HF_HOME="/data/kmei1/huggingface/"
export DISK_DIR="/data/kmei1/huggingface/cache"
export MODEL_DIR="stabilityai/stable-diffusion-2-1"
export NCCL_P2P_DISABLE=1
export CUDA_VISIBLE_DEVICES=5

# Download the pretrained checkpoint and extract it into experiments/.
mkdir -p experiments
wget https://www.cis.jhu.edu/~kmei1/publics/codi/canny_99000.tar.fz && tar -xzvf canny_99000.tar.fz -C experiments

python test_canny.py

# or launch the Gradio user interface
python gradio_canny_to_image.py

The user interface looks like this 👇 demo

Citations

You may want to cite:

@article{mei2023conditional,
  title={CoDi: Conditional Diffusion Distillation for Higher-Fidelity and Faster Image Generation},
  author={Mei, Kangfu and Delbracio, Mauricio and Talebi, Hossein and Tu, Zhengzhong and Patel, Vishal M and Milanfar, Peyman},
  journal={arXiv preprint arXiv:2310.01407},
  year={2023}
}

Acknowledgement

The code is based on Diffusers and HuggingFace. Please also follow their licenses. Thanks for their awesome work.