This repository implements distillation of diffusion models into fewer sampling steps, following On Distillation of Guided Diffusion Models and Progressive Distillation for Fast Sampling of Diffusion Models. The implementation is built on diffusers and distills a classifier-free guided model on ImageNet down to 1-2 sampling steps.
For more papers on diffusion model distillation, see diffusion models distillation papers.
The following is a sample of images: the left is generated by the 2-step distilled model, and the right by the original diffusion model with 4 DDIM steps.
pip install accelerate einops_exts diffusers datasets transformers
# Download stable-diffusion into third_party
git clone https://github.com/YongfeiYan/diffusion_models_distillation.git
cd diffusion_models_distillation/third_party && git clone https://github.com/CompVis/stable-diffusion.git
The ImageNet data is downloaded through the stable-diffusion code. On the first run, it downloads all images into the cache directory to create the ImageNetTrain dataset. Use the following to download:
PYTHONPATH=.:third_party/stable-diffusion python diffdstl/data/get_imagenet.py
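For reference, the download is triggered simply by constructing the dataset. A minimal sketch, assuming the `ImageNetTrain` class from `ldm.data.imagenet` in the stable-diffusion checkout (run with the same `PYTHONPATH` as above; the exact constructor arguments are not verified here and may differ):

```python
# Constructing the dataset triggers the one-time download into the cache
# dir; later runs reuse the cached files.
from ldm.data.imagenet import ImageNetTrain

dataset = ImageNetTrain()
print(len(dataset))   # number of training images
sample = dataset[0]   # dict containing the image array and its class label
```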
Download pretrained model:
dst=data/ldm/cin256-v2
mkdir -p $dst && cd $dst
wget https://ommer-lab.com/files/latent-diffusion/nitro/cin/model.ckpt
The first step is to finetune the original model to use v_prediction, which stabilizes the distillation process.
# Convert the pretrained model into diffusers pipeline
PYTHONPATH=.:third_party/stable-diffusion/ python scripts/progressdstl/ldm_ckpt_to_pipeline.py configs/imagenet/cin256-v2.yaml data/ldm/cin256-v2/model.ckpt data/test-pipeline
# Finetune
CUDA_VISIBLE_DEVICES=0 bash scripts/progressdstl/finetune_v_prediction.sh
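For intuition, this finetuning swaps the usual epsilon objective for the v objective of Progressive Distillation, v = alpha_t * eps - sigma_t * x_0, which keeps regression targets well behaved as the step count shrinks. A rough sketch in diffusers-style notation (the actual training script may organize this differently; diffusers schedulers also expose this computation as `get_velocity`):

```python
import torch

def v_target(x0, noise, timesteps, alphas_cumprod):
    """v = alpha_t * eps - sigma_t * x0, with alpha_t^2 + sigma_t^2 = 1."""
    a = alphas_cumprod[timesteps].sqrt().view(-1, 1, 1, 1)
    s = (1.0 - alphas_cumprod[timesteps]).sqrt().view(-1, 1, 1, 1)
    return a * noise - s * x0

# Finetuning then minimizes the MSE between the UNet output and this
# target instead of regressing onto the noise itself.
```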
The second step is to remove classifier-free guidance from sampling by distilling the guided teacher into a single student model:
CUDA_VISIBLE_DEVICES=0,1,2,3 bash scripts/progressdstl/stage_one.sh
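Conceptually, this stage (stage one of On Distillation of Guided Diffusion Models) trains a student, conditioned on the guidance scale w, to match the teacher's classifier-free-guided prediction in a single forward pass. A sketch of the objective; the function names and signatures here are illustrative, not the script's actual API:

```python
import torch
import torch.nn.functional as F

def stage_one_loss(student, teacher, x_t, t, labels, w):
    # Teacher's classifier-free-guided prediction: two passes combined.
    with torch.no_grad():
        cond = teacher(x_t, t, labels)   # class-conditional prediction
        uncond = teacher(x_t, t, None)   # unconditional prediction
        target = (1 + w) * cond - w * uncond
    # The student takes w as an extra input, so at sampling time one
    # student pass replaces the teacher's two guided passes per step.
    return F.mse_loss(student(x_t, t, labels, w), target)
```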
The third step is to iteratively halve the number of sampling steps. To reduce training time, the script begins with 64 DDIM sampling steps and runs 5 rounds to distill the student model down to 1 sampling step.
# Convert sampling scheduler
PYTHONPATH=.:third_party/stable-diffusion/ python scripts/progressdstl/convert_pipeline_scheduler.py data/log/imagenet/stage_one/pipeline data/log/imagenet/stage_one/pipeline-converted
# Distill
CUDA_VISIBLE_DEVICES=0,1,2,3 bash scripts/progressdstl/stage_two.sh &> stage_two.log &
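Each round here follows Progressive Distillation: the student learns to reach, in one DDIM step, the point the teacher reaches in two, and the trained student becomes the next round's teacher. A simplified sketch for a v-prediction model, omitting the paper's reparameterized target and SNR loss weighting (all names are illustrative):

```python
import torch
import torch.nn.functional as F

def ddim_step(model, x, t, s, ac):
    """One deterministic DDIM step from time t to s for a v-prediction
    model; ac is the alphas_cumprod table."""
    a_t, s_t = ac[t].sqrt(), (1 - ac[t]).sqrt()
    a_s, s_s = ac[s].sqrt(), (1 - ac[s]).sqrt()
    v = model(x, t)
    x0 = a_t * x - s_t * v    # predicted clean latent
    eps = s_t * x + a_t * v   # predicted noise
    return a_s * x0 + s_s * eps

def distill_loss(student, teacher, x_t, t, t_mid, t_next, ac):
    with torch.no_grad():
        # Teacher: two consecutive DDIM steps, t -> t_mid -> t_next.
        x_mid = ddim_step(teacher, x_t, t, t_mid, ac)
        x_target = ddim_step(teacher, x_mid, t_mid, t_next, ac)
    # Student: a single DDIM step t -> t_next should land on x_target.
    return F.mse_loss(ddim_step(student, x_t, t, t_next, ac), x_target)
```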
@inproceedings{DBLP:conf/cvpr/MengRGKEHS23,
  author    = {Chenlin Meng and Robin Rombach and Ruiqi Gao and Diederik P. Kingma and Stefano Ermon and Jonathan Ho and Tim Salimans},
  title     = {On Distillation of Guided Diffusion Models},
  booktitle = {{IEEE/CVF} Conference on Computer Vision and Pattern Recognition, {CVPR} 2023, Vancouver, BC, Canada, June 17-24, 2023},
  pages     = {14297--14306},
  publisher = {{IEEE}},
  year      = {2023},
  url       = {https://doi.org/10.1109/CVPR52729.2023.01374},
  doi       = {10.1109/CVPR52729.2023.01374}
}

@inproceedings{DBLP:conf/iclr/SalimansH22,
  author    = {Tim Salimans and Jonathan Ho},
  title     = {Progressive Distillation for Fast Sampling of Diffusion Models},
  booktitle = {The Tenth International Conference on Learning Representations, {ICLR} 2022, Virtual Event, April 25-29, 2022},
  publisher = {OpenReview.net},
  year      = {2022},
  url       = {https://openreview.net/forum?id=TIdIXIpzhoI}
}