PyTorch Implementation of Denoising Diffusion Probabilistic Models [paper] [official repo]
- Original DDPM[^1] training & sampling (objective sketched below the feature list)
- DDIM[^2] sampler
- Standard evaluation metrics: FID[^3] and Improved Precision & Recall[^4]
- Distributed Data Parallel[^5] (DDP) multi-GPU training
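For orientation, the simple DDPM training objective amounts to predicting the injected noise with an MSE loss. The sketch below is a minimal, generic illustration of that objective; the `model` signature, the `betas` tensor, and the shapes are placeholders rather than this repo's actual API.

```python
import torch
import torch.nn.functional as F

def ddpm_loss(model, x0, betas):
    """Simple DDPM objective: predict the noise injected at a random timestep.

    model: placeholder noise-prediction network eps_theta(x_t, t)
    x0:    clean images, shape (N, C, H, W)
    betas: per-step variances beta_1..beta_T, shape (T,)
    """
    alphas_bar = torch.cumprod(1.0 - betas, dim=0)           # cumulative product alpha-bar_t
    t = torch.randint(0, betas.shape[0], (x0.shape[0],), device=x0.device)
    a_bar = alphas_bar[t].view(-1, 1, 1, 1)
    noise = torch.randn_like(x0)
    x_t = a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * noise   # forward process q(x_t | x_0)
    return F.mse_loss(model(x_t, t), noise)                  # L_simple from Ho et al.
```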
- torch>=1.12.0
- torchvision>=0.13.0
- scipy>=1.7.3
The repository supports toy-data training as well as real-world training, generation, and evaluation; see the examples below.
Examples
- Train a 25-Gaussian toy model on a single GPU (device id: 0) for a total of 100 epochs (a toy-data sketch follows this example):

  ```shell
  python train_toy.py --dataset gaussian25 --device cuda:0 --epochs 100
  ```
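The name `gaussian25` suggests a mixture of 25 Gaussians arranged on a 5×5 grid; the snippet below sketches such a toy dataset for intuition. The grid extent and noise scale are made-up values, not necessarily the settings used by `train_toy.py`.

```python
import numpy as np

def sample_gaussian25(n, grid_extent=4.0, std=0.05, rng=None):
    """Draw n points from a 25-component Gaussian mixture laid out on a 5x5 grid."""
    rng = np.random.default_rng() if rng is None else rng
    xs = np.linspace(-grid_extent, grid_extent, 5)
    centers = np.array([(x, y) for x in xs for y in xs])      # 25 grid centers
    idx = rng.integers(0, len(centers), size=n)               # pick a mixture component per point
    return centers[idx] + std * rng.standard_normal((n, 2))   # add isotropic Gaussian noise
```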
- Train a CIFAR-10 model on a single GPU (device id: 0) for a total of 50 epochs (you can always use `dry-run` for testing/tuning purposes):

  ```shell
  python train.py --dataset cifar10 --train-device cuda:0 --epochs 50
  ```
- Train a CelebA model with an effective batch size of 16 x 2 x 4 = 128 on a four-card machine (single node) using shared file-system initialization (a conceptual sketch of these pieces follows this item):

  ```shell
  python train.py --dataset celeba --use-ema --num-accum 2 --num-gpus 4 --distributed --rigid-launch
  ```

  - `use-ema`: use exponential moving average (0.9999 decay by default)
  - `num-accum 2`: accumulate gradients over 2 mini-batches
  - `num-gpus`: number of GPU(s) to use for training, i.e. the `WORLD_SIZE` of the process group
  - `distributed`: enable multi-GPU DDP training
  - `rigid-launch`: use shared file-system initialization and `torch.multiprocessing`
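A minimal sketch of how these flags fit together conceptually, assuming one process per GPU spawned with `torch.multiprocessing` and a shared-file rendezvous; the file path, function names, and training-loop comments are illustrative and not the repo's actual implementation.

```python
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def ema_update(ema_model, model, decay=0.9999):
    """Exponential moving average of parameters (the idea behind --use-ema)."""
    with torch.no_grad():
        for p_ema, p in zip(ema_model.parameters(), model.parameters()):
            p_ema.mul_(decay).add_(p, alpha=1.0 - decay)

def worker(rank, world_size):
    # Shared file-system initialization: all processes rendezvous through one file
    # (hypothetical path; any path visible to every process works).
    dist.init_process_group(backend="nccl",
                            init_method="file:///tmp/ddpm_shared_init",
                            rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    # Build the model/optimizer here and wrap the model with
    # torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank]).
    #
    # Gradient accumulation (--num-accum 2): scale each loss by 1/num_accum,
    # call backward() on every mini-batch, but step the optimizer (and update
    # the EMA weights) only once every num_accum mini-batches.
    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = 4                                  # --num-gpus 4, i.e. WORLD_SIZE
    mp.spawn(worker, args=(world_size,), nprocs=world_size)
```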
- (Recommended) Train a CelebA model with an effective batch size of 64 x 1 x 2 = 128 using only two GPUs with `torchrun` Elastic Launch[^6] (TCP initialization; see the initialization sketch below):

  ```shell
  export CUDA_VISIBLE_DEVICES=0,1 && torchrun --standalone --nproc_per_node 2 --rdzv_backend c10d train.py --dataset celeba --distributed
  ```
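Under `torchrun`, each worker can join the process group via environment-variable (TCP) initialization. The helper below is a generic sketch of that pattern, not code from this repository.

```python
import os
import torch
import torch.distributed as dist

def init_from_torchrun():
    """Join the process group using the environment that torchrun provides."""
    # torchrun exports RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT,
    # so TCP ("env://") initialization needs no shared file or manual rank bookkeeping.
    dist.init_process_group(backend="nccl", init_method="env://")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    return dist.get_rank(), dist.get_world_size(), local_rank
```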
- Generate 50,000 samples (128 per mini-batch) from the EMA checkpoint located at `./chkpts/train/ddpm_cifar10_2160.pt` in parallel using 4 GPUs and the DDIM sampler. The results are stored in `./images/eval/cifar10_2160` (a sketch of how the skip schedule selects timesteps follows this item):

  ```shell
  python generate.py --dataset cifar10 --chkpt-path ./chkpts/train/ddpm_cifar10_2160.pt --use-ema --use-ddim --skip-schedule quadratic --subseq-size 100 --suffix _2160 --num-gpus 4
  ```

  - `use-ddim`: use the DDIM sampler
  - `skip-schedule quadratic`: use the quadratic skip schedule
  - `subseq-size`: length of the sub-sequence, i.e. the number of DDIM timesteps
  - `suffix`: suffix string appended to the dataset name in the folder name
  - `num-gpus`: number of GPU(s) to use for generation
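A rough sketch of how a skip schedule turns `subseq-size` into the timesteps a DDIM sampler visits; the quadratic spacing mirrors common DDIM implementations and is not necessarily what `generate.py` does verbatim.

```python
import numpy as np

def ddim_timesteps(num_train_steps=1000, subseq_size=100, schedule="quadratic"):
    """Choose which of the T training timesteps a DDIM sampler actually visits."""
    if schedule == "quadratic":
        # Quadratically spaced indices: denser near t = 0, coarser near t = T.
        steps = np.linspace(0, np.sqrt(num_train_steps * 0.8), subseq_size) ** 2
    else:  # linear skip schedule
        steps = np.linspace(0, num_train_steps - 1, subseq_size)
    return np.round(steps).astype(int)
```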
- Evaluate the FID and Precision/Recall of the generated samples in `./images/eval/cifar10_2160` (a FID sketch follows this item):

  ```shell
  python eval.py --dataset cifar10 --folder-name cifar10_2160
  ```
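For reference, FID is the Fréchet distance between two Gaussians fitted to Inception features of the real and generated samples. The sketch below shows only that final distance computation, assuming the feature means and covariances are already available; it is not this repo's `eval.py`.

```python
import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Frechet distance between two Gaussians fitted to Inception features.

    mu*, sigma*: feature mean (D,) and covariance (D, D) of real/generated samples.
    """
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)  # matrix square root of sigma1 @ sigma2
    if np.iscomplexobj(covmean):
        covmean = covmean.real                              # discard numerical imaginary parts
    return float(diff @ diff + np.trace(sigma1 + sigma2 - 2.0 * covmean))
```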
Toy data: true vs. generated samples for the 8 Gaussian, 25 Gaussian, and Swiss Roll datasets.
Table of evaluated metrics

| Dataset  | FID (↓) | Precision (↑) | Recall (↑) | Training steps | Training loss | Checkpoint |
|----------|---------|---------------|------------|----------------|---------------|------------|
| CIFAR-10 | 9.23    | 0.692         | 0.473      | 46.8k          | 0.0302        | -          |
|          | 6.02    | 0.693         | 0.510      | 93.6k          | 0.0291        | -          |
|          | 4.04    | 0.701         | 0.550      | 234.0k         | 0.0298        | -          |
|          | 3.36    | 0.717         | 0.559      | 468.0k         | 0.0284        | -          |
|          | 3.25    | 0.736         | 0.548      | 842.4k         | 0.0277        | [Link]     |
| CelebA   | 4.81    | 0.766         | 0.490      | 189.8k         | 0.0153        | -          |
|          | 3.88    | 0.760         | 0.516      | 379.7k         | 0.0151        | -          |
|          | 3.07    | 0.754         | 0.540      | 949.2k         | 0.0147        | [Link]     |
Generated images for CIFAR-10, CelebA, and CelebA-HQ.
- Simple web app powered by Streamlit: [tqch/diffusion-webapp]
- Classifier-Free Guidance: [tqch/v-diffusion-torch]
Footnotes
[^1]: Ho, Jonathan, Ajay Jain, and Pieter Abbeel. "Denoising Diffusion Probabilistic Models." *Advances in Neural Information Processing Systems* 33 (2020): 6840-6851.

[^2]: Song, Jiaming, Chenlin Meng, and Stefano Ermon. "Denoising Diffusion Implicit Models." *International Conference on Learning Representations*. 2020.

[^3]: Heusel, Martin, et al. "GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium." *Advances in Neural Information Processing Systems* 30 (2017).

[^4]: Kynkäänniemi, Tuomas, et al. "Improved Precision and Recall Metric for Assessing Generative Models." *Advances in Neural Information Processing Systems* 32 (2019).

[^5]: DistributedDataParallel - PyTorch 1.12 documentation, https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html.

[^6]: Torchrun (Elastic Launch) - PyTorch 1.12 documentation, https://pytorch.org/docs/stable/elastic/run.html.