Distribution Backtracking Builds A Faster Convergence Trajectory for One-step Diffusion Distillation
by Shengyuan Zhang¹, Ling Yang², Zejian Li*¹, An Zhao¹, Chenye Meng¹, Changyuan Yang³, Guang Yang³, Zhiyuan Yang³, Lingyun Sun¹
¹Zhejiang University, ²Peking University, ³Alibaba Group
Accelerating the sampling speed of diffusion models remains a significant challenge. Recent score distillation methods distill a heavy teacher model into an efficient student generator, which is optimized using the difference between two score functions (the teacher's score and a score fitted to the student's outputs), both evaluated on samples generated by the student. However, these methods suffer from a score mismatch in the early stage of distillation, because they use only the endpoint of the pre-trained diffusion model as the teacher and overlook the convergence trajectory between the one-step generator and the teacher model. To address this issue, we extend the score distillation process with the entire convergence trajectory of the teacher model and propose Distribution Backtracking Distillation (DisBack) for distilling one-step generators. DisBack consists of two stages: Degradation Recording and Distribution Backtracking. Degradation Recording obtains the convergence trajectory of the teacher model by recording the degradation path from the trained teacher model to the untrained initial student. This degradation path implicitly represents the intermediate distributions of the teacher model. Distribution Backtracking then trains a student generator to backtrack these intermediate distributions, approximating the teacher's convergence trajectory. Extensive experiments show that DisBack achieves faster and better convergence than the existing distillation method while reaching comparable generation performance. Notably, DisBack is easy to implement and can be generalized to existing distillation methods to boost performance.
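For orientation, the generator update shared by score distillation methods such as VSD and Diff-Instruct can be written as below. This is a generic sketch, not a transcription of the paper's Eq.(8): the generator parameters η, the weighting w(t), and the noise schedule α_t, σ_t are notation introduced here, while s_φ (the score model fitted to the student's samples) and s_target follow the pseudocode further down. In standard score distillation s_target is the pre-trained teacher throughout; DisBack instead moves s_target along the reversed degradation path.

$$
\nabla_\eta \mathcal{L}(\eta) = \mathbb{E}_{z,\,t,\,\epsilon}\Big[\, w(t)\,\Big(\frac{\partial x_t}{\partial \eta}\Big)^{\!\top} \big(s_\phi(x_t, t) - s_{\mathrm{target}}(x_t, t)\big) \Big],
\qquad x_0 = G_\eta(z),\quad x_t = \alpha_t\, x_0 + \sigma_t\, \epsilon .
$$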
conda create -n disback python=3.8 -y
conda activate disback
pip install --upgrade anyio
pip install -r requirements.txt
python setup.py develop
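After installation, a quick pre-flight check can catch a wrong interpreter or a missing dependency before a long multi-node run. The sketch below is hypothetical: the package list is an assumption (it is not read from requirements.txt), and the expected Python version simply mirrors the conda command above.

```python
import importlib.util
import sys

# Assumed package names; adjust to match requirements.txt.
ASSUMED_PACKAGES = ("torch", "diffusers", "accelerate", "huggingface_hub")


def missing_packages(names):
    """Return the names the import system cannot find."""
    return [name for name in names if importlib.util.find_spec(name) is None]


def check_environment(names=ASSUMED_PACKAGES, expected_python=(3, 8)):
    """Collect human-readable problems with the current environment."""
    problems = []
    found = sys.version_info[:2]
    if found != expected_python:
        problems.append(
            f"Python {found[0]}.{found[1]} found; the environment above uses "
            f"{expected_python[0]}.{expected_python[1]}"
        )
    problems.extend("missing package: " + name for name in missing_packages(names))
    return problems


if __name__ == "__main__":
    for problem in check_environment():
        print(problem)
```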
The distilled SDXL model has been uploaded to Hugging Face and can be used for one-step generation as follows:
import torch
from diffusers import DiffusionPipeline, UNet2DConditionModel, LCMScheduler
from huggingface_hub import hf_hub_download
base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
repo_name = "SYZhang0805/DisBack"
ckpt_name = "SDXL_DisBack.bin"
# Build an SDXL UNet from the base model's config and load the distilled DisBack weights into it.
unet = UNet2DConditionModel.from_config(UNet2DConditionModel.load_config(base_model_id, subfolder="unet")).to("cuda", torch.float16)
unet.load_state_dict(torch.load(hf_hub_download(repo_name, ckpt_name), map_location="cuda"))
# Plug the distilled UNet into the standard SDXL pipeline.
pipe = DiffusionPipeline.from_pretrained(base_model_id, unet=unet, torch_dtype=torch.float16, use_safetensors=True, variant="fp16").to("cuda")
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
prompt = "A photo of a dog."
# One-step sampling at timestep 399, without classifier-free guidance.
image = pipe(prompt=prompt, num_inference_steps=1, guidance_scale=0, timesteps=[399], height=1024, width=1024).images[0]
image.save("output.png", "PNG")
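To make the timesteps=[399] argument less opaque, here is a simplified, stdlib-only sketch of the arithmetic behind a single epsilon-prediction step. It assumes SDXL's default scaled_linear schedule (β from 0.00085 to 0.012 over 1000 training steps) and leaves out LCMScheduler's boundary-condition scaling, which we assume is close to identity at this timestep; it is not the scheduler's actual code. Roughly, the generator receives Gaussian latents labeled as timestep 399, predicts the noise, and the clean-latent estimate follows by inverting x_t = √ᾱ_t x_0 + √(1 − ᾱ_t) ε.

```python
import math


def alphas_cumprod(beta_start=0.00085, beta_end=0.012, num_steps=1000):
    """Cumulative products of (1 - beta_t) for a scaled_linear beta schedule (assumed SDXL default)."""
    lo, hi = math.sqrt(beta_start), math.sqrt(beta_end)
    out, prod = [], 1.0
    for i in range(num_steps):
        beta = (lo + (hi - lo) * i / (num_steps - 1)) ** 2  # linspace in sqrt-space, then squared
        prod *= 1.0 - beta
        out.append(prod)
    return out


def predict_x0(x_t, eps_pred, alpha_bar):
    """Invert x_t = sqrt(a) * x0 + sqrt(1 - a) * eps for x0, given a noise prediction."""
    return (x_t - math.sqrt(1.0 - alpha_bar) * eps_pred) / math.sqrt(alpha_bar)


a_399 = alphas_cumprod()[399]
# Round trip on a single scalar "latent": noising x0 and removing the same noise recovers x0.
x0, eps = 0.7, -1.3
x_t = math.sqrt(a_399) * x0 + math.sqrt(1.0 - a_399) * eps
print(round(a_399, 4), round(predict_x0(x_t, eps, a_399), 6))
```

In the real pipeline the noise prediction comes from the distilled UNet rather than being known, so the quality of the one-step image depends entirely on how well the generator has been distilled.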
In the text-to-image setting, DisBack is trained on top of DMD2. The pre-trained DMD2 model can be downloaded from here.
export CHECKPOINT_PATH="" # change this to your own checkpoint folder (this should be a central directory shared across nodes)
export WANDB_ENTITY="" # change this to your own wandb entity
export WANDB_PROJECT="" # change this to your own wandb project
export MASTER_IP="" # change this to your own master ip
# We are not sure why, but we found the following line necessary for the accelerate package to work on our system.
# Change YOUR_MASTER_IP/YOUR_MASTER_NODE_NAME to the correct values.
echo "YOUR_MASTER_IP YOUR_MASTER_NODE_NAME" | sudo tee -a /etc/hosts
# Create the FSDP configs for accelerate launch. Change EXP_NAME to your own experiment name.
python main/sdxl/create_sdxl_fsdp_configs.py --folder fsdp_configs/EXP_NAME --master_ip $MASTER_IP --num_machines 8 --sharding_strategy 4
mkdir $CHECKPOINT_PATH
mkdir $CHECKPOINT_PATH/sdxl_cond399_8node_lr5e-7_1step_diffusion1000_gan5e-3_guidance8_noinit_noode_checkpoint_model_024000/
bash scripts/download_sdxl.sh $CHECKPOINT_PATH
# Stage 1 (Degradation Recording): record the degradation path.
bash experiments/sdxl/degradation_sdxl_cond399_8node_lr5e-7_1step_diffusion1000_gan5e-3_guidance8_noinit_noode.sh
The degradation path is saved as follows:
path_directory/
│
├── checkpoint_model_024100/
│   └── pytorch_model_1.bin
├── checkpoint_model_024200/
│   └── pytorch_model_1.bin
...
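The training scripts handle the ordering themselves; as a hedged illustration of how this layout maps onto the convergence trajectory, the sketch below collects the checkpoint files by training step and reverses them. The helper names are hypothetical, and the folder and file names are taken from the layout above.

```python
import os
import re

# Folder names follow the "checkpoint_model_<step>" pattern shown above.
CKPT_DIR = re.compile(r"^checkpoint_model_(\d+)$")


def degradation_path(path_directory, weight_name="pytorch_model_1.bin"):
    """Return checkpoint files ordered by training step (teacher side first)."""
    entries = []
    for name in os.listdir(path_directory):
        match = CKPT_DIR.match(name)
        weight = os.path.join(path_directory, name, weight_name)
        if match and os.path.isfile(weight):
            entries.append((int(match.group(1)), weight))  # sort numerically, not lexically
    return [weight for _, weight in sorted(entries)]


def convergence_trajectory(path_directory):
    """Reverse the degradation path: from the degraded end back towards the teacher."""
    return degradation_path(path_directory)[::-1]
```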
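# Stage 2 (Distribution Backtracking): train the one-step generator along the reversed degradation path.
bash experiments/sdxl/sdxl_cond399_8node_lr5e-7_1step_diffusion1000_gan5e-3_guidance8_noinit_noode.sh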
DisBack can be applied to the score distillation process using the following pseudocode.
# Degradation Recording
s_theta = UNet()  # Pre-trained diffusion model (teacher).
s_theta_prime, G_stu = s_theta.clone(), s_theta.clone()  # Initialize the generator and the beginning of the degradation path.
path_degradation = []
for idx in range(num_iter_1st_stage):
    x_0 = one_step_sample(G_stu)  # G_stu is not updated in this stage.
    x_t, t, epsilon = addnoise(x_0)
    ckpt = train_score_model(s_theta_prime, x_t, t, epsilon)  # Training strategy depends on the type of pre-trained model used. Eq.(7) in the paper.
    if idx % interval_1st == 0:
        path_degradation.append(ckpt)  # Add an intermediate checkpoint to the degradation path every interval_1st iterations.
# Distribution Backtracking
path_backtracking = path_degradation[::-1]  # The reversed degradation path is viewed as the convergence trajectory.
s_phi = path_backtracking[0].clone()  # Use the first checkpoint of the convergence trajectory as the initial s_phi.
target = 1
for idx in range(num_iter_2nd_stage):
    s_target = path_backtracking[target]
    x_0 = one_step_sample(G_stu)  # One-step generation.
    x_t, t, epsilon = addnoise(x_0)
    score_diff = (s_phi(x_t, t) - s_target(x_t, t)).detach()  # Treat the score difference as a constant.
    x_t.backward(score_diff)  # VSD loss gradient. Eq.(8) in the paper.
    update(G_stu)  # Optimize G by gradient descent.
    train_score_model(s_phi, x_t.detach(), t, epsilon)  # Eq.(5) in the paper.
    if idx > 0 and idx % interval_2nd == 0 and target < len(path_backtracking) - 1:
        target += 1  # Switch to the next intermediate target.
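To see both stages end to end without a GPU, here is a hedged toy sketch in plain Python. It is not the authors' implementation and every modeling choice is an assumption: each score model is the score of a unit-variance Gaussian N(m, 1), stored only as its mean m; the one-step generator shifts standard normal noise by θ; the noise level t is dropped; and the degradation path is assumed to start with the teacher itself, so the last backtracking target is the teacher.

```python
import random

# Toy 1-D sketch of the two DisBack stages (illustrative assumptions, not the authors' code).


def score(m, x):
    """Score of the unit-variance Gaussian N(m, 1) at x."""
    return -(x - m)


def sample_generator(theta, n, rng):
    """Draw n one-step samples x = theta + z, z ~ N(0, 1)."""
    return [theta + rng.gauss(0.0, 1.0) for _ in range(n)]


def train_score_model(m, xs, lr):
    """One score-matching gradient step for s_m on samples xs (d loss / dm = -mean(x - m))."""
    return m + lr * sum(x - m for x in xs) / len(xs)


def degradation_recording(m_teacher, theta_init, iters, interval, lr, n, rng):
    """Stage 1: fine-tune a copy of the teacher on samples of the fixed initial generator."""
    m = m_teacher
    path = [m_teacher]  # assumption: the path starts with the teacher itself
    for idx in range(1, iters + 1):
        m = train_score_model(m, sample_generator(theta_init, n, rng), lr)
        if idx % interval == 0:
            path.append(m)
    return path


def distribution_backtracking(path, theta, iters, interval, lr_g, lr_s, n, rng):
    """Stage 2: move the generator along the reversed degradation path."""
    trajectory = path[::-1]  # from the degraded end back to the teacher
    m_phi = trajectory[0]    # initial score model of the generator distribution
    target = 1
    for idx in range(1, iters + 1):
        xs = sample_generator(theta, n, rng)
        # Generator gradient: mean of s_phi(x) - s_target(x) times dx/dtheta (= 1 here).
        grad = sum(score(m_phi, x) - score(trajectory[target], x) for x in xs) / len(xs)
        theta -= lr_g * grad
        m_phi = train_score_model(m_phi, xs, lr_s)  # keep s_phi fitted to the generator
        if idx % interval == 0 and target < len(trajectory) - 1:
            target += 1  # switch to the next intermediate target
    return theta


rng = random.Random(0)
path = degradation_recording(m_teacher=3.0, theta_init=0.0, iters=200, interval=20, lr=0.05, n=64, rng=rng)
theta = distribution_backtracking(path, theta=0.0, iters=400, interval=20, lr_g=0.1, lr_s=0.2, n=64, rng=rng)
print("teacher mean:", path[0], "degraded end:", round(path[-1], 3), "generator mean:", round(theta, 3))
```

With these settings the degraded end of the path sits near the initial generator (mean 0) and the generator mean ends close to the teacher mean of 3.0. Because the initial s_phi already matches the initial generator, the early score differences stay small instead of jumping straight to the teacher, which is the mismatch the backtracking schedule is meant to avoid.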
If you find our paper useful or relevant to your research, please cite our paper:
@article{zhang2024distributionbacktrackingbuildsfaster,
title={Distribution Backtracking Builds A Faster Convergence Trajectory for One-step Diffusion Distillation},
author={Shengyuan Zhang and Ling Yang and Zejian Li and An Zhao and Chenye Meng and Changyuan Yang and Guang Yang and Zhiyuan Yang and Lingyun Sun},
journal={arXiv 2408.15991},
year={2024}
}
DisBack is built heavily on the following amazing open-source projects:
DMD2: Improved Distribution Matching Distillation for Fast Image Synthesis
Diff-Instruct: A Universal Approach for Transferring Knowledge From Pre-trained Diffusion Models
ScoreGAN: Unifying GANs and Score-Based Diffusion as Generative Particle Models
We thank the maintainers of these projects for their contributions!