
The official implementation of Distribution Backtracking Distillation for One-step Diffusion Models


🔥DisBack🔥

by Shengyuan Zhang¹, Ling Yang², Zejian Li*¹, An Zhao¹, Chenye Meng¹, Changyuan Yang³, Guang Yang³, Zhiyuan Yang³, Lingyun Sun¹

¹Zhejiang University  ²Peking University  ³Alibaba Group

Abstract

Accelerating the sampling speed of diffusion models remains a significant challenge. Recent score distillation methods distill a heavy teacher model into an efficient student generator, which is optimized by calculating the difference between two score functions on the samples generated by the student model. However, there is a score mismatch issue in the early stage of the distillation process, because existing methods mainly focus on using the endpoint of pre-trained diffusion models as teacher models, overlooking the importance of the convergence trajectory between the one-step generator and the teacher model. To address this issue, we extend the score distillation process with the entire convergence trajectory of teacher models and propose Distribution Backtracking Distillation (DisBack) for distilling one-step generators. DisBack is composed of two stages: Degradation Recording and Distribution Backtracking. Degradation Recording obtains the convergence trajectory of the teacher model by recording its degradation path from the trained teacher model to the untrained initial student generator. The degradation path implicitly represents the intermediate distributions of the teacher model. Distribution Backtracking then trains the student generator to backtrack these intermediate distributions and thereby approximate the convergence trajectory of the teacher model. Extensive experiments show that DisBack achieves faster and better convergence than existing distillation methods while attaining comparable generation performance. Notably, DisBack is easy to implement and can be generalized to existing distillation methods to boost performance.
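In score distillation, the one-step generator is updated with the difference between two score functions evaluated at noised student samples; DisBack swaps the fixed teacher score for intermediate targets along the backtracked trajectory. The following is a rough sketch of this gradient in our own notation (not the paper's exact equations), matching the pseudocode further below:

$$
\nabla_\theta \mathcal{L} \;\approx\; \mathbb{E}_{z,\,\epsilon,\,t}\left[\left(s_\phi(x_t, t) - s^{(k)}_{\mathrm{target}}(x_t, t)\right)\frac{\partial x_t}{\partial \theta}\right], \qquad x_t = \alpha_t\, G_\theta(z) + \sigma_t\, \epsilon,
$$

where $s_\phi$ is the online score model of the student distribution and $s^{(k)}_{\mathrm{target}}$ is the $k$-th checkpoint along the reversed degradation path, with the teacher endpoint reached at the last index.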

The structure of DisBack

Samples of DisBack

Using DisBack

Environment setup

conda create -n disback python=3.8 -y 
conda activate disback 

pip install --upgrade anyio
pip install -r requirements.txt
python setup.py develop

Inference

The distilled SDXL model is available on Hugging Face (SYZhang0805/DisBack).

One-step text-to-image generation

import torch
from diffusers import DiffusionPipeline, UNet2DConditionModel, LCMScheduler
from huggingface_hub import hf_hub_download

base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
repo_name = "SYZhang0805/DisBack"
ckpt_name = "SDXL_DisBack.bin"

unet = UNet2DConditionModel.from_config(base_model_id, subfolder="unet").to("cuda", torch.float16)
unet.load_state_dict(torch.load(hf_hub_download(repo_name, ckpt_name), map_location="cuda"))

pipe = DiffusionPipeline.from_pretrained(base_model_id, unet=unet, torch_dtype=torch.float16, use_safetensors=True, variant="fp16").to("cuda")
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
prompt="A photo of a dog." 
image=pipe(prompt=prompt, num_inference_steps=1, guidance_scale=0, timesteps=[399], height=1024, width=1024).images[0]
image.save('output.png', 'PNG')
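For reproducible outputs, a seeded generator can be passed to the pipeline. The snippet below is a small sketch that continues from the code above and only uses stock diffusers call arguments (the generator and prompt list are standard pipeline features, not DisBack-specific):

# Reproducible, batched one-step generation (sketch; reuses the `pipe` built above).
generator = torch.Generator("cuda").manual_seed(0)
prompts = ["A photo of a dog.", "A watercolor painting of a lighthouse at dusk."]
images = pipe(prompt=prompts, num_inference_steps=1, guidance_scale=0, timesteps=[399],
              height=1024, width=1024, generator=generator).images
for i, img in enumerate(images):
    img.save(f"output_{i}.png")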

Training and Evaluation

In the text-to-image scenario, DisBack is trained on top of DMD2. The pre-trained DMD2 model can be downloaded from here.

Download Base Diffusion Models and Training Data

export CHECKPOINT_PATH="" # change this to your own checkpoint folder (this should be a central directory shared across nodes)
export WANDB_ENTITY="" # change this to your own wandb entity
export WANDB_PROJECT="" # change this to your own wandb project
export MASTER_IP=""  # change this to your own master ip

# We are not sure why, but we found the following line necessary for the accelerate package to work on our system.
# Change YOUR_MASTER_IP / YOUR_MASTER_NODE_NAME to the correct values.
echo "YOUR_MASTER_IP 	YOUR_MASTER_NODE_NAME" | sudo tee -a /etc/hosts

# Create FSDP configs for accelerate launch. Change EXP_NAME to your own experiment name.
python main/sdxl/create_sdxl_fsdp_configs.py --folder fsdp_configs/EXP_NAME  --master_ip $MASTER_IP --num_machines 8  --sharding_strategy 4

mkdir $CHECKPOINT_PATH
mkdir $CHECKPOINT_PATH/sdxl_cond399_8node_lr5e-7_1step_diffusion1000_gan5e-3_guidance8_noinit_noode_checkpoint_model_024000/

bash scripts/download_sdxl.sh $CHECKPOINT_PATH

Degradation recording stage

bash experiments/sdxl/degradation_sdxl_cond399_8node_lr5e-7_1step_diffusion1000_gan5e-3_guidance8_noinit_noode.sh

The degradation path is saved as follows:

path_directory/
│
├── checkpoint_model_024100/
│   └── pytorch_model_1.bin
├── checkpoint_model_024200/
│   └── pytorch_model_1.bin
...
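The reversed list of these checkpoints forms the convergence trajectory used in the next stage. Below is a minimal sketch for collecting them in order, assuming the checkpoint_model_{step}/pytorch_model_1.bin layout shown above (the helper name is our own, not part of the training scripts):

import os
import re

def load_degradation_path(path_directory):
    # Collect checkpoints in recording order by sorting on the numeric step in the folder name.
    ckpt_dirs = [d for d in os.listdir(path_directory) if d.startswith("checkpoint_model_")]
    ckpt_dirs.sort(key=lambda d: int(re.search(r"\d+", d).group()))
    return [os.path.join(path_directory, d, "pytorch_model_1.bin") for d in ckpt_dirs]

degradation_path = load_degradation_path("path_directory")
backtracking_path = degradation_path[::-1]  # reversed path = convergence trajectory for stage two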

Distribution backtracking stage

bash experiments/sdxl/sdxl_cond399_8node_lr5e-7_1step_diffusion1000_gan5e-3_guidance8_noinit_noode.sh

Compatibility with other models

DisBack can be applied to the score distillation process using the following pseudocode.

# Degradation
s_theta = UNet()  # Pre-trained diffusion model (teacher).
s_theta_prime, G_stu = s_theta.clone(), s_theta.clone()  # Initialize the generator and the starting point of the degradation path.
path_degradation = []
for idx in range(num_iter_1st_stage):
    x_0 = one_step_sample(G_stu)
    x_t, t, epsilon = addnoise(x_0)
    ckpt = train_score_model(s_theta_prime, x_t, t, epsilon)  # Training strategy depends on the type of pre-trained model used. Eq.(7) in the paper.
    if idx % interval_1st == 0:
        path_degradation.append(ckpt)  # Add an intermediate checkpoint to the degradation path.
path_degradation.append(ckpt)  # Also record the final (most degraded) checkpoint.

# Backtracking
path_backtracking = path_degradation[::-1]  # The reverse of the degradation path is viewed as the convergence trajectory.
s_phi = path_backtracking[0].clone()  # Use the first checkpoint of the convergence trajectory as the initial s_phi.
target = 1
for idx in range(num_iter_2nd_stage):
    s_target = path_backtracking[target]
    x_0 = one_step_sample(G_stu)  # One-step generation.
    x_t, t, epsilon = addnoise(x_0)
    x_t.backward(gradient=(s_phi(x_t, t) - s_target(x_t, t)))  # VSD loss. Eq.(8) in the paper.
    update(G_stu)  # Optimize G_stu by gradient descent.
    train_score_model(s_phi, x_t, t, epsilon)  # Eq.(5) in the paper.
    if idx % interval_2nd == 0 and idx > 0 and target < len(path_backtracking) - 1:  # Switch to the next target along the trajectory.
        target += 1
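The addnoise helper above is the standard forward-diffusion (noising) step. A minimal sketch under a DDPM-style schedule is given below; the real implementation should follow the teacher model's noise schedule, and the function signature here is our own illustration rather than code from this repository:

import torch

def addnoise(x_0, alphas_cumprod, num_train_timesteps=1000):
    # Sample a random timestep and noise, then form x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps.
    t = torch.randint(0, num_train_timesteps, (x_0.shape[0],), device=x_0.device)
    epsilon = torch.randn_like(x_0)
    alpha_bar_t = alphas_cumprod.to(x_0.device)[t].view(-1, 1, 1, 1)
    x_t = alpha_bar_t.sqrt() * x_0 + (1.0 - alpha_bar_t).sqrt() * epsilon
    return x_t, t, epsilon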

Citation

If you find our paper useful or relevant to your research, please kindly cite our paper:

@article{zhang2024distributionbacktrackingbuildsfaster,
    title={Distribution Backtracking Builds A Faster Convergence Trajectory for One-step Diffusion Distillation}, 
    author={Shengyuan Zhang and Ling Yang and Zejian Li and An Zhao and Chenye Meng and Changyuan Yang and Guang Yang and Zhiyuan Yang and Lingyun Sun},
    journal={arXiv preprint arXiv:2408.15991},
    year={2024}
}

Credits

DisBack builds heavily on the following amazing open-source projects:

DMD2: Improved Distribution Matching Distillation for Fast Image Synthesis

Diff-Instruct: A Universal Approach for Transferring Knowledge From Pre-trained Diffusion Models

ScoreGAN: Unifying GANs and Score-Based Diffusion as Generative Particle Models

Thanks to the maintainers of these projects for their contributions to this project!