Min-SNR-Diffusion-Training

[ICCV 2023] Efficient Diffusion Training via Min-SNR Weighting Strategy

Efficient Diffusion Training via Min-SNR Weighting Strategy

By Tiankai Hang, Shuyang Gu, Chen Li, Jianmin Bao, Dong Chen, Han Hu, Xin Geng, Baining Guo.

Abstract. Denoising diffusion models have become a mainstream approach for image generation; however, training them often suffers from slow convergence. In this paper, we discover that the slow convergence is partly due to conflicting optimization directions between timesteps. To address this issue, we treat diffusion training as a multi-task learning problem and introduce a simple yet effective approach called Min-SNR-$\gamma$. This method adapts the loss weight of each timestep based on its clamped signal-to-noise ratio, which effectively balances the conflicts among timesteps. Our results demonstrate a significant improvement in convergence speed: training converges 3.4x faster than with previous weighting strategies. The method is also more effective, achieving a new record FID score of 2.06 on the ImageNet 256x256 benchmark with smaller architectures than those employed by the previous state of the art.
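A minimal sketch of the weighting described in the abstract, assuming the standard variance-preserving forward process x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps, so that SNR(t) = alpha_bar_t / (1 - alpha_bar_t). Min-SNR-gamma clamps this ratio at gamma and uses min(SNR(t), gamma) as the weight on an x0-space MSE; for an eps- or v-space MSE, the equivalent weights are min(SNR, gamma) / SNR and min(SNR, gamma) / (SNR + 1). The linear beta schedule (1e-4 to 0.02 over 1000 steps, DDPM-style), gamma = 5 (suggested by "min_snr_5" in the training config name), and all function names are illustrative assumptions, not this repository's code.

```python
# Minimal sketch of Min-SNR-gamma loss weights in pure Python.
# Assumptions: DDPM-style linear beta schedule, gamma = 5, hypothetical helper names.


def linear_alphas_cumprod(num_steps=1000, beta_start=1e-4, beta_end=0.02):
    """Return alpha_bar_t = prod_{s<=t} (1 - beta_s) for a linear beta schedule."""
    alphas_cumprod = []
    running = 1.0
    for i in range(num_steps):
        beta = beta_start + (beta_end - beta_start) * i / (num_steps - 1)
        running *= 1.0 - beta
        alphas_cumprod.append(running)
    return alphas_cumprod


def snr(alpha_bar):
    """Signal-to-noise ratio of x_t = sqrt(alpha_bar) * x0 + sqrt(1 - alpha_bar) * eps."""
    return alpha_bar / (1.0 - alpha_bar)


def min_snr_weight(alpha_bar, gamma=5.0, prediction="x0"):
    """Per-timestep loss weight under Min-SNR-gamma.

    min(SNR, gamma) weights the x0-space MSE. Because
    ||x0 - x0_hat||^2 = ||eps - eps_hat||^2 / SNR and
    ||x0 - x0_hat||^2 = ||v - v_hat||^2 / (SNR + 1),
    dividing converts it to the same objective expressed in eps or v space.
    """
    s = snr(alpha_bar)
    clamped = min(s, gamma)
    if prediction == "x0":
        return clamped
    if prediction == "eps":
        return clamped / s
    if prediction == "v":
        return clamped / (s + 1.0)
    raise ValueError(f"unknown prediction type: {prediction!r}")
```

For example, at t = 0 the SNR is about 10^4, so the x0 weight is clamped to 5, while at t = 999 the SNR is far below 5 and the weight is simply the SNR itself. In a training loop, the per-sample weight would multiply the per-sample MSE before averaging over the batch.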

Data Preparation

For the CelebA dataset, we follow ScoreSDE for data processing.

The ImageNet dataset is downloaded from the official website. For ImageNet-64, no offline pre-processing is applied. For ImageNet-256, the images are cropped to 256x256 and compressed into latent codes with AutoencoderKL from Diffusers. These latent codes are then handled exactly like images; only the file extension differs.
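To illustrate what "handled exactly like images" can mean in practice, here is a minimal sketch of a file lister in which the accepted extension set is the only thing that changes between raw images and latents. Class labels come from the file-name prefix (e.g. the WordNet ID in "n01440764_10026.JPEG"), mirroring the convention in guided-diffusion's image loader. The ".npy" latent extension and all helper names are hypothetical. If the commonly used Stable Diffusion VAE was the AutoencoderKL checkpoint (8x downsampling, 4 latent channels), each 256x256 image would become a 4x32x32 latent; the exact checkpoint is an assumption here.

```python
import os

# Hypothetical extension sets: common image formats vs. latent codes saved as .npy.
IMAGE_EXTS = {"jpg", "jpeg", "png", "gif"}
LATENT_EXTS = {"npy"}


def list_files_recursively(data_dir, exts):
    """Return sorted file paths under data_dir whose lower-cased extension is in exts."""
    results = []
    for entry in sorted(os.listdir(data_dir)):
        full_path = os.path.join(data_dir, entry)
        if os.path.isdir(full_path):
            results.extend(list_files_recursively(full_path, exts))
        elif "." in entry and entry.rsplit(".", 1)[-1].lower() in exts:
            results.append(full_path)
    return results


def class_names_from_paths(paths):
    """ImageNet-style label from the file-name prefix: 'n01440764_10026.JPEG' -> 'n01440764'."""
    return [os.path.basename(p).split("_")[0] for p in paths]


def list_training_files(data_dir, use_latents):
    """Identical listing and labeling for images and latents; only the extension set differs."""
    exts = LATENT_EXTS if use_latents else IMAGE_EXTS
    paths = list_files_recursively(data_dir, exts)
    return paths, class_names_from_paths(paths)
```

Because labels depend only on the file-name prefix, an image and its latent produce the same class label, so the rest of the data pipeline does not need to know which kind of file it is reading.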

Training

To train the ViT-B model, first place the downloaded or processed data described above at a path of your choice, and set DATA_DIR in the config file vit-b_layer12_lr1e-4_099_099_pred_x0__min_snr_5__fp16_bs8x32.sh. Then run, for example:

GPUS=8
BATCH_SIZE_PER_GPU=32
bash configs/in256/vit-b_layer12_lr1e-4_099_099_pred_x0__min_snr_5__fp16_bs8x32.sh $GPUS $BATCH_SIZE_PER_GPU
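The "bs8x32" suffix in the config name matches the command above: 8 GPUs with 32 samples each, for a global batch of 256. Below is a minimal, hypothetical Python wrapper that computes that product and builds the same command, assuming plain data-parallel training without gradient accumulation. With fewer GPUs, raising BATCH_SIZE_PER_GPU so that the product stays at 256 (memory permitting) is presumably the closest match to the setting the file name describes.

```python
import shlex

# Config path taken from the README; the wrapper functions are hypothetical conveniences.
CONFIG = "configs/in256/vit-b_layer12_lr1e-4_099_099_pred_x0__min_snr_5__fp16_bs8x32.sh"


def global_batch_size(gpus, batch_size_per_gpu):
    """Samples per optimizer step, assuming data parallelism with no gradient accumulation."""
    return gpus * batch_size_per_gpu


def build_train_command(gpus, batch_size_per_gpu, config=CONFIG):
    """argv for the training script, which takes GPUS and BATCH_SIZE_PER_GPU positionally."""
    if gpus < 1 or batch_size_per_gpu < 1:
        raise ValueError("gpus and batch_size_per_gpu must be positive integers")
    return ["bash", config, str(gpus), str(batch_size_per_gpu)]


def format_command(argv):
    """Shell-quoted string form, useful for logging the exact command that was run."""
    return shlex.join(argv)
```

From the repository root, `subprocess.run(build_train_command(8, 32), check=True)` would launch the same job as the shell snippet.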

Sampling with Pre-trained Models

To sample from the pre-trained ImageNet-256 model, run

bash configs/in256/inference.sh

To sample from the pre-trained ImageNet-64 model, run

bash configs/in64/inference.sh
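Both inference scripts sample in parallel across GPUs (8 by default) to produce the 50,000 images that FID-50K requires. One way to split that work, loosely following the seed-sharding pattern in NVlabs/edm's generate.py (round the batch count up to a multiple of the world size, split seeds into near-equal batches, then deal batches to ranks round-robin), is sketched below in pure Python. The function name, the maximum batch size of 64, and the one-seed-per-image convention are assumptions, not this repository's code.

```python
def shard_seeds(seeds, world_size, max_batch_size):
    """Split seeds into batches and assign them round-robin to ranks.

    The batch count is rounded up to a multiple of world_size so every rank
    runs the same number of batches; batch sizes differ by at most one.
    """
    per_round = max_batch_size * world_size
    num_batches = ((len(seeds) - 1) // per_round + 1) * world_size
    # Near-equal contiguous split, like numpy.array_split / torch.tensor_split.
    base, extra = divmod(len(seeds), num_batches)
    batches, start = [], 0
    for i in range(num_batches):
        size = base + (1 if i < extra else 0)
        batches.append(seeds[start:start + size])
        start += size
    # Rank r takes batches r, r + world_size, r + 2 * world_size, ...
    return [batches[rank::world_size] for rank in range(world_size)]
```

With 50,000 seeds, 8 ranks, and a maximum batch of 64, each rank ends up with 98 batches covering exactly 6,250 seeds. Setting the world size to 1 gives the single-GPU case, which produces the same set of samples, just more slowly.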

By default, sampling uses 8 GPUs. For single-GPU evaluation, change GPUS=8 to GPUS=1 in configs/in256/inference.sh. The pre-trained models are downloaded automatically, and FID-50K is then computed.
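For reference, this is the standard FID definition (not specific to this repository). Inception-v3 features are extracted from real and generated images; with feature means $\mu_r, \mu_g$ and covariances $\Sigma_r, \Sigma_g$, the Fréchet distance between the two fitted Gaussians is given below. FID-50K evaluates it on 50,000 generated samples, and lower is better.

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_r + \Sigma_g - 2\left( \Sigma_r \Sigma_g \right)^{1/2} \right)
```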

Citing Min-SNR Diffusion Training

If you find our work useful for your research, please consider citing our paper. 😊

@article{hang2023efficient,
      title={Efficient Diffusion Training via Min-SNR Weighting Strategy}, 
      author={Tiankai Hang and Shuyang Gu and Chen Li and Jianmin Bao and Dong Chen and Han Hu and Xin Geng and Baining Guo},
      year={2023},
      eprint={2303.09556},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Acknowledgements

This repository is based on openai/guided-diffusion. Our sampling and FID evaluation code is adapted from NVlabs/edm.