v-diffusion-torch

PyTorch Implementation of V-objective Diffusion Probabilistic Models with Classifier-free Guidance

Primary language: Python. License: MIT.

PyTorch Implementation of V-objective Diffusion Probabilistic Model (VDPM) and more

Key features

  • improved UNet design (conditioning, resampling, etc.) [1]
  • continuous-time training on log-SNR schedule [2]
  • DDIM sampler [3]
  • MSE loss reweighting (constant, SNR, truncated-SNR) [4]
  • velocity prediction [4]
  • classifier-free guidance [5]
  • others:
    • distributed data parallel (multi-gpu training)
    • gradient accumulation
    • FID/Precision/Recall evaluation
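
To make the log-SNR schedule, velocity prediction, and loss reweighting features more concrete, here is a minimal scalar sketch in plain Python. It is not taken from this codebase. The function names, the default `logsnr_min`/`logsnr_max` of ±20, and the mapping of the weight names onto the `--reweight-type` choices are assumptions for illustration; the formulas follow the variance-preserving parameterization described in [4].

```python
import math


def logsnr_cosine(t, logsnr_min=-20.0, logsnr_max=20.0):
    """Cosine log-SNR schedule truncated to [logsnr_min, logsnr_max], t in [0, 1].

    The +/-20 defaults are assumptions, not values read from this repo.
    """
    b = math.atan(math.exp(-0.5 * logsnr_max))
    a = math.atan(math.exp(-0.5 * logsnr_min)) - b
    return -2.0 * math.log(math.tan(a * t + b))


def alpha_sigma(logsnr):
    """Variance-preserving coefficients: alpha^2 = sigmoid(logsnr), sigma^2 = sigmoid(-logsnr)."""
    alpha = math.sqrt(1.0 / (1.0 + math.exp(-logsnr)))
    sigma = math.sqrt(1.0 / (1.0 + math.exp(logsnr)))
    return alpha, sigma


def v_target(x0, eps, logsnr):
    """Velocity target v = alpha * eps - sigma * x0."""
    alpha, sigma = alpha_sigma(logsnr)
    return alpha * eps - sigma * x0


def x0_eps_from_v(z, v, logsnr):
    """Recover (x0, eps) from the noisy sample z = alpha * x0 + sigma * eps and v."""
    alpha, sigma = alpha_sigma(logsnr)
    return alpha * z - sigma * v, sigma * z + alpha * v


def loss_weight(logsnr, kind="constant"):
    """Weight on the squared error in x0-space, as discussed in [4].

    Which quantity this repo's --reweight-type actually weights is an assumption.
    """
    snr = math.exp(logsnr)
    if kind == "constant":
        return 1.0
    if kind == "snr":
        return snr
    if kind == "truncated_snr":
        return max(snr, 1.0)
    raise ValueError(f"unknown weighting: {kind}")
```

Because alpha^2 + sigma^2 = 1, the map from (x0, eps) to (z, v) is a rotation, which is why `x0_eps_from_v` can invert it exactly.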

Basic usage

usage: train.py [-h] [--dataset {mnist,cifar10,celeba}] [--root ROOT]        
                [--epochs EPOCHS] [--lr LR] [--beta1 BETA1] [--beta2 BETA2]  
                [--weight-decay WEIGHT_DECAY] [--batch-size BATCH_SIZE]      
                [--num-accum NUM_ACCUM] [--train-timesteps TRAIN_TIMESTEPS]  
                [--sample-timesteps SAMPLE_TIMESTEPS]                        
                [--logsnr-schedule {linear,sigmoid,cosine,legacy}]           
                [--logsnr-max LOGSNR_MAX] [--logsnr-min LOGSNR_MIN]          
                [--model-out-type {x_0,eps,both,v}]                          
                [--model-var-type {fixed_small,fixed_large,fixed_medium}]    
                [--reweight-type {constant,snr,truncated_snr,alpha2}]        
                [--loss-type {kl,mse}] [--intp-frac INTP_FRAC] [--use-cfg]   
                [--w-guide W_GUIDE] [--p-uncond P_UNCOND]                    
                [--num-workers NUM_WORKERS] [--train-device TRAIN_DEVICE]    
                [--eval-device EVAL_DEVICE] [--image-dir IMAGE_DIR]          
                [--image-intv IMAGE_INTV] [--num-save-images NUM_SAVE_IMAGES]
                [--sample-bsz SAMPLE_BSZ] [--config-dir CONFIG_DIR]          
                [--ckpt-dir CHKPT_DIR] [--ckpt-name CHKPT_NAME]            
                [--ckpt-intv CHKPT_INTV] [--seed SEED] [--resume] [--eval]  
                [--use-ema] [--use-ddim] [--ema-decay EMA_DECAY]             
                [--distributed]                                    
optional arguments:                                                          
  -h, --help            show this help message and exit                      
  --dataset {mnist,cifar10,celeba}                                           
  --root ROOT           root directory of datasets                           
  --epochs EPOCHS       total number of training epochs                      
  --lr LR               learning rate                                        
  --beta1 BETA1         beta_1 in Adam                                       
  --beta2 BETA2         beta_2 in Adam
  --weight-decay WEIGHT_DECAY
                        decoupled weight_decay factor in Adam
  --batch-size BATCH_SIZE
  --num-accum NUM_ACCUM
                        number of batches before weight update, a.k.a.
                        gradient accumulation
  --train-timesteps TRAIN_TIMESTEPS
                        number of diffusion steps for training (0 indicates
                        continuous training)
  --sample-timesteps SAMPLE_TIMESTEPS
                        number of diffusion steps for sampling
  --logsnr-schedule {linear,sigmoid,cosine,legacy}
  --logsnr-max LOGSNR_MAX
  --logsnr-min LOGSNR_MIN
  --model-out-type {x_0,eps,both,v}
  --model-var-type {fixed_small,fixed_large,fixed_medium}
  --reweight-type {constant,snr,truncated_snr,alpha2}
  --loss-type {kl,mse}
  --intp-frac INTP_FRAC
  --use-cfg             whether to use classifier-free guidance
  --w-guide W_GUIDE     classifier-free guidance strength
  --p-uncond P_UNCOND   probability of unconditional training
  --num-workers NUM_WORKERS
                        number of workers for data loading
  --train-device TRAIN_DEVICE
  --eval-device EVAL_DEVICE
  --image-dir IMAGE_DIR
  --image-intv IMAGE_INTV
  --num-save-images NUM_SAVE_IMAGES
                        number of images to generate & save
  --sample-bsz SAMPLE_BSZ
                        batch size for sampling
  --config-dir CONFIG_DIR
  --ckpt-dir CHKPT_DIR
  --ckpt-name CHKPT_NAME
  --ckpt-intv CHKPT_INTV
                        frequency of saving a checkpoint
  --seed SEED           random seed
  --resume              to resume training from a checkpoint
  --eval                whether to evaluate fid during training
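
For readers wondering what `--use-ddim` and `--sample-timesteps` control, the following is a rough sketch of a deterministic DDIM update [3] written for a velocity-predicting model. It is a scalar toy in plain Python, not this repository's sampler. The names `ddim_step`, `ddim_sample`, `linear_logsnr`, and the `model(z, logsnr)` calling convention are hypothetical.

```python
import math


def alpha_sigma(logsnr):
    """Variance-preserving coefficients derived from the log-SNR."""
    alpha = math.sqrt(1.0 / (1.0 + math.exp(-logsnr)))
    sigma = math.sqrt(1.0 / (1.0 + math.exp(logsnr)))
    return alpha, sigma


def ddim_step(z_t, v_pred, logsnr_t, logsnr_s):
    """One deterministic (eta = 0) DDIM step from noise level t to a less noisy level s."""
    alpha_t, sigma_t = alpha_sigma(logsnr_t)
    alpha_s, sigma_s = alpha_sigma(logsnr_s)
    x0_hat = alpha_t * z_t - sigma_t * v_pred   # predicted clean sample
    eps_hat = sigma_t * z_t + alpha_t * v_pred  # predicted noise
    return alpha_s * x0_hat + sigma_s * eps_hat


def linear_logsnr(t, logsnr_min=-20.0, logsnr_max=20.0):
    """Toy linear log-SNR schedule (assumed endpoints): t=0 is clean, t=1 is noise."""
    return logsnr_max + (logsnr_min - logsnr_max) * t


def ddim_sample(model, z_T, num_steps, logsnr_fn=linear_logsnr):
    """Run num_steps DDIM updates from t=1 down to t=0.

    `model(z, logsnr)` is a hypothetical callable returning a v-prediction.
    """
    z = z_T
    for i in range(num_steps, 0, -1):
        logsnr_t = logsnr_fn(i / num_steps)
        logsnr_s = logsnr_fn((i - 1) / num_steps)
        z = ddim_step(z, model(z, logsnr_t), logsnr_t, logsnr_s)
    return z
```

With fewer steps the loop simply takes larger jumps along the same schedule, which is the usual reason to pair a DDIM sampler with a reduced number of sampling steps.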

Examples

# train cifar10 with one gpu
python train.py --dataset cifar10 --use-ema --use-ddim --num-save-images 80 --use-cfg --epochs 600 --ckpt-intv 120 --image-intv 10

# train cifar10 with two gpus
python -m torch.distributed.run --standalone --nproc_per_node 2 --rdzv_backend c10d train.py --dataset cifar10 --use-ema --use-ddim --num-save-images 80 --use-cfg --epochs 600 --ckpt-intv 120 --image-intv 10 --distributed

# train celeba with one gpu with effective batch_size 128
python train.py --dataset celeba --use-ema --use-ddim --num-save-images 64 --use-cfg --epochs 240 --ckpt-intv 120 --image-intv 10 --num-accum 8 --sample-bsz 32

# train celeba with two gpus
python -m torch.distributed.run --standalone --nproc_per_node 2 --rdzv_backend c10d train.py --dataset celeba --use-ema --use-ddim --num-save-images 64 --use-cfg --epochs 240 --ckpt-intv 120 --image-intv 10 --distributed --num-accum 4 --sample-bsz 32
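
The "effective batch_size 128" comment above combines the per-step batch size with `--num-accum`. As a small worked example, the helper below spells out that arithmetic. The function name is hypothetical, and the micro-batch size of 16 per device is an assumption inferred from 128 / 8; the actual default of `--batch-size`, and whether it is per GPU or global under `--distributed`, is not stated in this README.

```python
def effective_batch_size(micro_batch_per_device, num_accum, num_devices=1):
    """Number of samples contributing to one optimizer update."""
    return micro_batch_per_device * num_accum * num_devices


# One GPU, --num-accum 8, assumed micro-batch of 16 per device.
single_gpu = effective_batch_size(16, num_accum=8, num_devices=1)

# Two GPUs, --num-accum 4, same assumed micro-batch per device.
two_gpus = effective_batch_size(16, num_accum=4, num_devices=2)
```

Under that assumption, halving `--num-accum` when doubling the number of GPUs keeps the number of samples per weight update unchanged.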

Conditional generation

CIFAR-10

| guidance strength | FID   | IS    |
|-------------------|-------|-------|
| w=0               | 2.58  | 9.76  |
| w=0.1             | 3.12  | 10.01 |
| w=1               | 21.35 | 9.92  |

Each image grid (not reproduced here) has one row per class: airplanes, cars, birds, cats, deer, dogs, frogs, horses, ships, trucks.
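
The guidance strength w in the results above follows the convention of classifier-free guidance [5], where w=0 corresponds to plain conditional sampling. A minimal sketch of the two ingredients is given below: label dropout during training and the guided combination at sampling time. It is written in plain Python on lists and is not taken from this codebase; `cfg_combine`, `maybe_drop_label`, and the `null_label` placeholder are hypothetical names.

```python
import random


def cfg_combine(pred_cond, pred_uncond, w):
    """Guided prediction: (1 + w) * conditional - w * unconditional, elementwise."""
    return [(1.0 + w) * c - w * u for c, u in zip(pred_cond, pred_uncond)]


def maybe_drop_label(label, p_uncond, null_label=None, rng=random):
    """With probability p_uncond, replace the label by a null token.

    Training on both cases lets a single network serve as the conditional
    and the unconditional model.
    """
    return null_label if rng.random() < p_uncond else label
```

Since the combination weights sum to one and the noise and clean-sample estimates are linear in v for a fixed noisy input, applying this formula to a v-prediction or to the derived noise prediction gives the same guided result.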

CelebA

Image grids (not reproduced here) are shown for guidance strengths w=0, w=1, and w=3. In each grid, columns correspond to the tags Black_Hair, Blond_Hair, Brown_Hair, Gray_Hair, and rows correspond to the tags Receding_Hairline, Straight_Hair, Wavy_Hair, Bald, Bangs.

More variants (animated)

Animated image grids (not reproduced here) are shown for guidance strengths w=0, w=1, and w=3. In each grid, columns correspond to the tags Black_Hair, Blond_Hair, Brown_Hair, Gray_Hair, and rows correspond to the tags Receding_Hairline, Straight_Hair, Wavy_Hair, Bald, Bangs.

Acknowledgement

The development of this codebase is largely based on the official JAX implementation open-sourced by Google Research and my previous PyTorch implementation of DDPM, which are available at [google-research/diffusion_distillation] and [tqch/ddpm-torch] respectively.

References

  1. Ho, Jonathan, Ajay Jain, and Pieter Abbeel. "Denoising diffusion probabilistic models." Advances in Neural Information Processing Systems 33 (2020): 6840-6851.

  2. Kingma, Diederik, Tim Salimans, Ben Poole, and Jonathan Ho. "Variational diffusion models." Advances in Neural Information Processing Systems 34 (2021): 21696-21707.

  3. Song, Jiaming, Chenlin Meng, and Stefano Ermon. "Denoising diffusion implicit models." arXiv preprint arXiv:2010.02502 (2020).

  4. Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models." arXiv preprint arXiv:2202.00512 (2022).

  5. Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance." NeurIPS 2021 Workshop on Deep Generative Models and Downstream Applications (2021). https://openreview.net/forum?id=qw8AKxfYbI