Diffusion Models Distillation

This repository provides an implementation of distilling diffusion models into fewer sampling steps, following On Distillation of Guided Diffusion Models and Progressive Distillation for Fast Sampling of Diffusion Models. The implementation is built on diffusers and distills a classifier-free guided model on ImageNet down to 1–2 sampling steps.

For more distillation papers on diffusion models, see the list of diffusion models distillation papers.

The following is a sample of images: the left is generated by the 2-step distilled model, and the right by the original diffusion model with 4 DDIM steps.

[sample images]

Requirements

pip install accelerate einops_exts diffusers datasets transformers
# Clone this repository and download stable-diffusion into third_party
git clone https://github.com/YongfeiYan/diffusion_models_distillation.git
cd diffusion_models_distillation/third_party && git clone https://github.com/CompVis/stable-diffusion.git

Run

Download imagenet and pretrained model

The ImageNet data is downloaded through the stable-diffusion code. On first use it downloads all images into the cache directory to create the ImageNetTrain dataset. Use the following to download:

PYTHONPATH=.:third_party/stable-diffusion python diffdstl/data/get_imagenet.py
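
After the script finishes, the dataset can be loaded directly from the stable-diffusion code as a quick sanity check. The snippet below is a minimal sketch, assuming the ldm.data.imagenet.ImageNetTrain class reuses the cache on instantiation and returns dict samples with an "image" key (these names come from the stable-diffusion repository, not from this project):

import sys
sys.path.append("third_party/stable-diffusion")
from ldm.data.imagenet import ImageNetTrain

ds = ImageNetTrain()           # reuses the cache populated by get_imagenet.py
print(len(ds))                 # number of training images
sample = ds[0]
print(sample["image"].shape)   # preprocessed image array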

Download pretrained model:

dst=data/ldm/cin256-v2
mkdir -p $dst && cd $dst
wget https://ommer-lab.com/files/latent-diffusion/nitro/cin/model.ckpt 
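
A quick way to verify the download is to load the checkpoint and inspect its state dict. This is only a hedged sanity-check sketch; the nesting under "state_dict" assumed below is the usual latent-diffusion checkpoint layout:

import torch
ckpt = torch.load("data/ldm/cin256-v2/model.ckpt", map_location="cpu")
state_dict = ckpt.get("state_dict", ckpt)  # latent-diffusion checkpoints usually nest weights here
print(len(state_dict), "tensors loaded")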

Finetune v_prediction

The first step is to finetune the original model into the v_prediction parameterization to stabilize the distillation process.

# Convert the pretrained model into diffusers pipeline
PYTHONPATH=.:third_party/stable-diffusion/ python scripts/progressdstl/ldm_ckpt_to_pipeline.py configs/imagenet/cin256-v2.yaml data/ldm/cin256-v2/model.ckpt data/test-pipeline
# Finetune
CUDA_VISIBLE_DEVICES=0 bash scripts/progressdstl/finetune_v_prediction.sh
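
For reference, the v-parameterization of Salimans & Ho predicts v = sqrt(alpha_bar_t) * eps - sqrt(1 - alpha_bar_t) * x0 instead of the noise eps, which keeps the training target well-behaved as the number of steps shrinks. The snippet below is a minimal sketch of that target (the schedule tensor alphas_cumprod and variable names are illustrative, not the repository's actual training code); in diffusers the same behaviour is selected by setting the scheduler's prediction_type to "v_prediction".

import torch

def v_target(x0, noise, alphas_cumprod, t):
    # v = sqrt(alpha_bar_t) * eps - sqrt(1 - alpha_bar_t) * x0
    alpha = alphas_cumprod[t].sqrt().view(-1, 1, 1, 1)
    sigma = (1.0 - alphas_cumprod[t]).sqrt().view(-1, 1, 1, 1)
    return alpha * noise - sigma * x0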

Stage one: classifier-free guidance removal

The second step is to remove the need for classifier-free guidance at sampling time by distilling the guided teacher into a single student model:

CUDA_VISIBLE_DEVICES=0,1,2,3 bash scripts/progressdstl/stage_one.sh
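
In stage one of Meng et al. (2023), the student learns to reproduce the teacher's guided prediction, so sampling no longer needs two forward passes per step (in the paper the student is additionally conditioned on the guidance weight w; the sketch below fixes w for simplicity). The helper is an illustrative sketch with assumed tensor arguments, not the repository's training code:

import torch

@torch.no_grad()
def guided_teacher_eps(teacher_unet, z_t, t, class_emb, null_emb, w=2.0):
    eps_cond = teacher_unet(z_t, t, class_emb)     # conditional prediction
    eps_uncond = teacher_unet(z_t, t, null_emb)    # unconditional prediction
    return eps_uncond + w * (eps_cond - eps_uncond)  # classifier-free guidance

# The student is then regressed onto this target, e.g.
# loss = F.mse_loss(student_unet(z_t, t, class_emb), guided_teacher_eps(...))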

Stage two: distilling to fewer sampling steps

The third step is to iteratively halve the number of sampling steps. To reduce training time, the script begins with 64 DDIM sampling steps and halves them 5 times, distilling the student model down to 2 sampling steps.

# Convert sampling scheduler
PYTHONPATH=.:third_party/stable-diffusion/ python scripts/progressdstl/convert_pipeline_scheduler.py data/log/imagenet/stage_one/pipeline data/log/imagenet/stage_one/pipeline-converted
# Distill
CUDA_VISIBLE_DEVICES=0,1,2,3 bash scripts/progressdstl/stage_two.sh &> stage_two.log & 
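
Each halving trains the student so that one of its DDIM steps matches two consecutive teacher DDIM steps. The snippet below sketches how the per-example target is formed, under an assumed continuous-time cosine schedule and an x0-predicting teacher; it is illustrative and does not mirror the repository's exact code.

import math
import torch

def alpha_sigma(t):                       # continuous-time cosine schedule (assumption)
    a = torch.cos(t * math.pi / 2)
    return a, (1 - a ** 2).sqrt()

def ddim_step(x_pred, z_t, t, s):         # deterministic DDIM update z_t -> z_s
    a_t, sig_t = alpha_sigma(t)
    a_s, sig_s = alpha_sigma(s)
    eps = (z_t - a_t * x_pred) / sig_t
    return a_s * x_pred + sig_s * eps

def two_step_teacher_target(teacher_x, z_t, t, dt):
    # Run the teacher for two DDIM steps of size dt, then solve for the x0 that
    # would take the student from z_t to the same point in a single step of 2*dt.
    z_mid = ddim_step(teacher_x(z_t, t), z_t, t, t - dt)
    z_end = ddim_step(teacher_x(z_mid, t - dt), z_mid, t - dt, t - 2 * dt)
    a_t, sig_t = alpha_sigma(t)
    a_e, sig_e = alpha_sigma(t - 2 * dt)
    return (z_end - (sig_e / sig_t) * z_t) / (a_e - (sig_e / sig_t) * a_t)

# Usage with a dummy x0-predicting teacher (hypothetical, for illustration only):
z = torch.randn(1, 3, 8, 8)
x_target = two_step_teacher_target(lambda z_, t_: torch.zeros_like(z_), z, torch.tensor(0.8), 1 / 64)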

Reference

@inproceedings{DBLP:conf/cvpr/MengRGKEHS23,
  author       = {Chenlin Meng and
                  Robin Rombach and
                  Ruiqi Gao and
                  Diederik P. Kingma and
                  Stefano Ermon and
                  Jonathan Ho and
                  Tim Salimans},
  title        = {On Distillation of Guided Diffusion Models},
  booktitle    = {{IEEE/CVF} Conference on Computer Vision and Pattern Recognition,
                  {CVPR} 2023, Vancouver, BC, Canada, June 17-24, 2023},
  pages        = {14297--14306},
  publisher    = {{IEEE}},
  year         = {2023},
  url          = {https://doi.org/10.1109/CVPR52729.2023.01374},
  doi          = {10.1109/CVPR52729.2023.01374},
  timestamp    = {Tue, 29 Aug 2023 15:44:40 +0200},
  biburl       = {https://dblp.org/rec/conf/cvpr/MengRGKEHS23.bib},
  bibsource    = {dblp computer science bibliography, https://dblp.org}
}
@inproceedings{DBLP:conf/iclr/SalimansH22,
  author       = {Tim Salimans and
                  Jonathan Ho},
  title        = {Progressive Distillation for Fast Sampling of Diffusion Models},
  booktitle    = {The Tenth International Conference on Learning Representations, {ICLR}
                  2022, Virtual Event, April 25-29, 2022},
  publisher    = {OpenReview.net},
  year         = {2022},
  url          = {https://openreview.net/forum?id=TIdIXIpzhoI},
  timestamp    = {Sat, 20 Aug 2022 01:15:42 +0200},
  biburl       = {https://dblp.org/rec/conf/iclr/SalimansH22.bib},
  bibsource    = {dblp computer science bibliography, https://dblp.org}
}