
Diffusion Transformer

Implementation of the Diffusion Transformer (DiT) model from the paper:

Scalable Diffusion Models with Transformers.

See here for the official PyTorch implementation.

Dependencies

  • Python 3.9
  • PyTorch 2.1.1

Training Diffusion Transformer

Use --data_dir=<data_dir> to specify the dataset path.

python train.py --data_dir=./data/

Samples

Sample output from minDiT (39.89M parameters) on CIFAR-10:

Sample output from minDiT on CelebA:

Hyperparameter settings

Adjust hyperparameters in the config.py file.
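
The contents of config.py are not reproduced in this README; the sketch below only illustrates the kind of hyperparameters a DiT training config typically exposes. All field names and default values here are hypothetical, not the repository's actual settings.

from dataclasses import dataclass

@dataclass
class Config:
    # Hypothetical example; the real config.py may use different names/values.
    image_size: int = 32        # CIFAR-10 resolution
    patch_size: int = 2         # DiT patchify size
    hidden_dim: int = 384       # transformer width
    depth: int = 12             # number of DiT blocks
    num_heads: int = 6
    batch_size: int = 128
    lr: float = 1e-4
    train_steps: int = 400_000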

Implementation notes:

  • minDiT is designed to offer reasonable performance on a single GPU (RTX 3080 Ti).
  • minDiT largely follows the original DiT architecture.
  • DiT blocks use adaLN-Zero conditioning (see the first sketch after this list).
  • Self-attention is replaced with Linformer attention (sketched below).
  • Sampling uses the EDM sampler (sketched below).
  • FID evaluation is included (the core distance computation is sketched below).
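
adaLN-Zero, introduced in the DiT paper, regresses shift, scale, and gate parameters for each block from the conditioning embedding, and initializes that projection to zero so every block starts as the identity function. A minimal PyTorch sketch of the idea; the class and argument names are illustrative, not this repository's exact code:

import torch
import torch.nn as nn

class DiTBlock(nn.Module):
    """Transformer block with adaLN-Zero conditioning (illustrative sketch)."""
    def __init__(self, dim, num_heads, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio)),
            nn.GELU(),
            nn.Linear(int(dim * mlp_ratio), dim),
        )
        # Regress 6 modulation vectors (shift/scale/gate for attention and MLP)
        # from the conditioning embedding c (e.g. timestep + class embedding).
        self.adaLN = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim))
        # "Zero" init: gates start at zero, so each block begins as the identity.
        nn.init.zeros_(self.adaLN[1].weight)
        nn.init.zeros_(self.adaLN[1].bias)

    def forward(self, x, c):          # x: (B, n, d) tokens, c: (B, d) conditioning
        shift1, scale1, gate1, shift2, scale2, gate2 = self.adaLN(c).chunk(6, dim=-1)
        h = self.norm1(x) * (1 + scale1.unsqueeze(1)) + shift1.unsqueeze(1)
        x = x + gate1.unsqueeze(1) * self.attn(h, h, h, need_weights=False)[0]
        h = self.norm2(x) * (1 + scale2.unsqueeze(1)) + shift2.unsqueeze(1)
        x = x + gate2.unsqueeze(1) * self.mlp(h)
        return x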
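
Linformer attention replaces the O(n²) attention matrix with an O(n·k) one by linearly projecting keys and values down to a fixed length k along the sequence axis. A single-head sketch for brevity; this module is an assumption for illustration, not the repository's implementation:

import math
import torch
import torch.nn as nn

class LinformerSelfAttention(nn.Module):
    """Single-head self-attention with low-rank K/V projection (Linformer sketch)."""
    def __init__(self, dim, seq_len, k=64):
        super().__init__()
        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(dim, dim)
        self.to_v = nn.Linear(dim, dim)
        # Learned projections compressing the sequence axis from n down to k.
        self.E = nn.Parameter(torch.randn(seq_len, k) / math.sqrt(k))
        self.F = nn.Parameter(torch.randn(seq_len, k) / math.sqrt(k))
        self.out = nn.Linear(dim, dim)
        self.scale = dim ** -0.5

    def forward(self, x):                                      # x: (B, n, d)
        q = self.to_q(x)                                       # (B, n, d)
        k = torch.einsum('bnd,nk->bkd', self.to_k(x), self.E)  # (B, k, d)
        v = torch.einsum('bnd,nk->bkd', self.to_v(x), self.F)  # (B, k, d)
        attn = (q @ k.transpose(-2, -1) * self.scale).softmax(dim=-1)  # (B, n, k)
        return self.out(attn @ v)                              # (B, n, d)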
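
The EDM sampler (Karras et al., 2022) integrates the probability-flow ODE over a decaying noise schedule using Heun's second-order method. A minimal deterministic sketch, assuming a denoiser(x, sigma) callable that returns the denoised image D(x; σ):

import torch

@torch.no_grad()
def edm_sample(denoiser, shape, steps=18, sigma_min=0.002, sigma_max=80.0,
               rho=7.0, device='cpu'):
    """Deterministic EDM (Heun) sampler sketch; `denoiser(x, sigma)` is assumed."""
    # Karras noise schedule: sigma_i for i = 0..steps-1, ending at 0.
    step_idx = torch.arange(steps, device=device)
    sigmas = (sigma_max ** (1 / rho) + step_idx / (steps - 1)
              * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    sigmas = torch.cat([sigmas, torch.zeros(1, device=device)])

    x = torch.randn(shape, device=device) * sigmas[0]
    for i in range(steps):
        s, s_next = sigmas[i], sigmas[i + 1]
        d = (x - denoiser(x, s)) / s           # ODE derivative dx/dsigma
        x_next = x + (s_next - s) * d          # Euler step
        if s_next > 0:                         # Heun second-order correction
            d_next = (x_next - denoiser(x_next, s_next)) / s_next
            x_next = x + (s_next - s) * 0.5 * (d + d_next)
        x = x_next
    return x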
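
FID fits a Gaussian to Inception features of the real and generated images and reports the Fréchet distance between the two fits. Given precomputed means and covariances, the distance itself reduces to a few lines (a sketch of the standard computation, not necessarily this repository's code):

import numpy as np
from scipy import linalg

def frechet_distance(mu1, cov1, mu2, cov2):
    """Frechet distance between two Gaussian fits (the core of FID)."""
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(cov1 @ cov2, disp=False)
    covmean = covmean.real  # sqrtm can return tiny imaginary parts
    return float(diff @ diff + np.trace(cov1 + cov2 - 2.0 * covmean))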

TODO

  • Add Classifier-Free Diffusion Guidance and a conditional pipeline.
  • Add Latent Diffusion and autoencoder training.
  • Add a generate.py file.

License

MIT