This repo features an improved PyTorch implementation of the paper Scalable Diffusion Models with Transformers (DiT).
It contains:
- 🪐 An improved PyTorch implementation and the original implementation of DiT
- ⚡️ Pre-trained class-conditional DiT models trained on ImageNet (512x512 and 256x256)
- 💥 A self-contained Hugging Face Space and Colab notebook for running pre-trained DiT-XL/2 models
- 🛸 An improved DiT training script and several training options
First, download and set up the repo:
git clone https://github.com/chuanyangjin/fast-DiT.git
cd fast-DiT
We provide an environment.yml file that can be used to create a Conda environment. If you only want to run pre-trained models locally on CPU, you can remove the cudatoolkit and pytorch-cuda requirements from the file.
conda env create -f environment.yml
conda activate DiT
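If you take the CPU-only route, the edit to environment.yml can be scripted before running conda env create. The helper below, strip_cuda_requirements, is hypothetical and not part of the repo; it is a minimal sketch that assumes each dependency sits on its own "- name[=version]" line, possibly with a channel prefix such as "nvidia::". It does not parse YAML, so check the result by eye.

```python
import re

# Packages the README says can be removed for a CPU-only install.
CUDA_ONLY_PACKAGES = {"cudatoolkit", "pytorch-cuda"}


def strip_cuda_requirements(env_yaml: str) -> str:
    """Hypothetical helper: drop list entries for CUDA-only packages from an environment file."""
    kept = []
    for line in env_yaml.splitlines():
        entry = line.strip()
        if entry.startswith("- "):
            spec = entry[2:].strip()
            # Remove an optional channel prefix ("nvidia::") and any version pin ("=11.7", ">=").
            name = re.split(r"[=<>!\s]", spec.split("::")[-1], maxsplit=1)[0]
            if name in CUDA_ONLY_PACKAGES:
                continue
        kept.append(line)
    return "\n".join(kept) + "\n"
```

One way to apply it in place is `Path("environment.yml").write_text(strip_cuda_requirements(Path("environment.yml").read_text()))` with `from pathlib import Path`; keep a copy of the original file first.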
Pre-trained DiT checkpoints. You can sample from our pre-trained DiT models with sample.py. The weights for the model you select are downloaded automatically. The script has various arguments to switch between the 256x256 and 512x512 models, adjust the number of sampling steps, change the classifier-free guidance scale, etc. For example, to sample from our 512x512 DiT-XL/2 model, you can use:
python sample.py --image-size 512 --seed 1
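The classifier-free guidance scale mentioned above controls how far each denoising step is pushed from the unconditional noise prediction toward the class-conditional one. The sketch below shows the standard guidance formula; plain Python lists stand in for the model's noise-prediction tensors, and apply_cfg is a hypothetical name, not a function from sample.py.

```python
from typing import List, Sequence


def apply_cfg(eps_cond: Sequence[float], eps_uncond: Sequence[float], cfg_scale: float) -> List[float]:
    """Classifier-free guidance: eps = eps_uncond + s * (eps_cond - eps_uncond)."""
    return [u + cfg_scale * (c - u) for c, u in zip(eps_cond, eps_uncond)]
```

With cfg_scale=1 the result is exactly the conditional prediction (no guidance); larger scales extrapolate further away from the unconditional prediction.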
For convenience, our pre-trained DiT models can be downloaded directly here as well:
DiT Model | Image Resolution | FID-50K | Inception Score | GFLOPs |
---|---|---|---|---|
XL/2 | 256x256 | 2.27 | 278.24 | 119 |
XL/2 | 512x512 | 3.04 | 240.82 | 525 |
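The compute gap between the two resolutions follows from the token count. The sketch below assumes DiT's setup of a VAE that downsamples images by 8x, with the "/2" in XL/2 being the patch size on that latent; num_tokens is a hypothetical helper.

```python
def num_tokens(image_size: int, patch_size: int, vae_downsample: int = 8) -> int:
    """Number of patch tokens the transformer processes for a square image."""
    latent_side = image_size // vae_downsample  # e.g. 256 px -> 32x32 latent
    return (latent_side // patch_size) ** 2
```

num_tokens(256, 2) gives 256 tokens and num_tokens(512, 2) gives 1024, a 4x increase. The GFLOPs ratio in the table (525 / 119, about 4.4x) is somewhat larger, which is consistent with self-attention cost growing quadratically in the number of tokens.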
Custom DiT checkpoints. If you've trained a new DiT model with train.py (see below), you can add the --ckpt argument to use your own checkpoint instead. For example, to sample from the EMA weights of a custom 256x256 DiT-L/4 model, run:
python sample.py --model DiT-L/4 --image-size 256 --ckpt /path/to/model.pt
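The EMA weights are an exponential moving average of the model parameters, updated after each optimizer step. A minimal sketch of the update rule, with a dict of floats standing in for the model's parameter tensors; update_ema here is illustrative, and the decay of 0.9999 is the value the original DiT training code uses, assumed rather than checked for this repo.

```python
from typing import Dict


def update_ema(ema: Dict[str, float], params: Dict[str, float], decay: float = 0.9999) -> None:
    """In place: ema <- decay * ema + (1 - decay) * params, for every parameter name."""
    for name, value in params.items():
        ema[name] = decay * ema[name] + (1.0 - decay) * value
```

A decay close to 1 averages over roughly 1 / (1 - decay) = 10,000 recent steps, which smooths out step-to-step noise in the raw weights.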
To extract ImageNet features with 1 GPU on one node:
torchrun --nnodes=1 --nproc_per_node=1 extract_features.py --model DiT-XL/2 --data-path /path/to/imagenet/train --features-path /path/to/store/features
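Before extracting, it can help to estimate the disk space the features will take. This is a rough sketch that assumes each feature is a VAE latent of 4 channels at 1/8 the image resolution, stored as float32; the storage format is an assumption, so check extract_features.py. The helper names are hypothetical.

```python
IMAGENET_TRAIN_IMAGES = 1_281_167  # ImageNet-1k training set size


def latent_bytes(image_size: int = 256, channels: int = 4, downsample: int = 8, bytes_per_value: int = 4) -> int:
    """Bytes for one latent of shape (channels, side, side), assuming float32 storage."""
    side = image_size // downsample
    return channels * side * side * bytes_per_value


def dataset_gib(image_size: int = 256) -> float:
    """Approximate total feature size in GiB, ignoring file headers and label files."""
    return IMAGENET_TRAIN_IMAGES * latent_bytes(image_size) / 2**30
```

Under these assumptions each 256x256 image becomes a 16 KiB latent, or roughly 20 GiB for the full training set.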
We provide a training script for DiT in train.py. This script can be used to train class-conditional DiT models, but it can be easily modified to support other types of conditioning.
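Class-conditional training is what makes classifier-free guidance possible at sampling time: during training, a fraction of labels is replaced by an extra "null" class so the same network also learns an unconditional prediction. A minimal sketch, assuming (as in the original DiT code) a drop probability of 0.1 and the null class id num_classes; drop_labels is a hypothetical helper operating on a Python list instead of a label tensor.

```python
import random
from typing import List, Optional


def drop_labels(labels: List[int], num_classes: int, dropout_prob: float = 0.1,
                rng: Optional[random.Random] = None) -> List[int]:
    """Replace each label by the null class id (num_classes) with probability dropout_prob."""
    rng = rng if rng is not None else random.Random()
    return [num_classes if rng.random() < dropout_prob else y for y in labels]
```

For ImageNet, num_classes is 1000, so the label embedding needs 1001 entries. Swapping in another kind of conditioning would mean replacing both the label embedding and this "null" token with their counterparts for the new condition.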
To launch DiT-XL/2 (256x256) training with 1 GPU on one node:
accelerate launch --mixed_precision fp16 train.py --model DiT-XL/2 --features-path /path/to/store/features
To launch DiT-XL/2 (256x256) training with N GPUs on one node:
accelerate launch --multi_gpu --num_processes N --mixed_precision fp16 train.py --model DiT-XL/2 --features-path /path/to/store/features
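When scaling to N processes, it is worth checking how the batch is split. The sketch below assumes train.py follows the original DiT convention of a fixed global batch size divided evenly across processes; verify this against the script's arguments. per_process_batch_size is a hypothetical helper.

```python
def per_process_batch_size(global_batch_size: int, num_processes: int) -> int:
    """Split a global batch evenly across processes, refusing uneven splits."""
    if global_batch_size % num_processes != 0:
        raise ValueError(
            f"global batch size {global_batch_size} is not divisible by {num_processes} processes"
        )
    return global_batch_size // num_processes
```

Under this convention, a global batch of 256 on 8 GPUs gives 32 samples per GPU, and per-GPU memory drops as N grows while the effective batch seen by the optimizer stays the same.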
Alternatively, you can use the feature extraction and training scripts located in the training options folder.
We've trained DiT-XL/2 and DiT-B/4 models from scratch with the PyTorch training script to verify that it reproduces the original JAX results up to several hundred thousand training iterations. Across our experiments, the PyTorch-trained models give similar (and sometimes slightly better) results compared to the JAX-trained models, within reasonable random variation. Some data points:
DiT Model | Train Steps | FID-50K (JAX Training) | FID-50K (PyTorch Training) | PyTorch Global Training Seed |
---|---|---|---|---|
XL/2 | 400K | 19.5 | 18.1 | 42 |
B/4 | 400K | 68.4 | 68.9 | 42 |
B/4 | 400K | 68.4 | 68.3 | 100 |
These models were trained at 256x256 resolution; we used 8x A100s to train XL/2 and 4x A100s to train B/4. Note that FID here is computed with 250 DDPM sampling steps, with the mse VAE decoder and without guidance (cfg-scale=1).
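Sampling with 250 DDPM steps means running only a subset of the diffusion timesteps used in training. A minimal sketch of picking evenly spaced timesteps, assuming 1000 training timesteps and an ADM-style respacing that keeps both endpoints (DiT's diffusion code derives from ADM, but this simplified single-section version is an illustration, not the repo's function):

```python
from typing import List


def evenly_spaced_timesteps(num_timesteps: int, num_steps: int) -> List[int]:
    """Pick num_steps of the num_timesteps training steps, evenly spaced, including 0 and the last step."""
    if not 1 <= num_steps <= num_timesteps:
        raise ValueError("num_steps must be between 1 and num_timesteps")
    if num_steps == 1:
        return [0]
    stride = (num_timesteps - 1) / (num_steps - 1)
    return [round(k * stride) for k in range(num_steps)]
```

With 1000 timesteps and 250 steps, the stride is about 4, so roughly every fourth timestep is kept; the noise schedule is then recomputed over only those steps.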
Compared to the original implementation, we add several training speed and memory optimizations: gradient checkpointing, mixed precision training, and pre-extracted VAE features. Together these give a 95% speed increase and a 60% memory reduction on DiT-XL/2. Some data points using a global batch size of 128 on one A100:
gradient checkpointing | mixed precision training | feature pre-extraction | training speed | memory |
---|---|---|---|---|
❌ | ❌ | ❌ | - | out of memory |
✔ | ❌ | ❌ | 0.43 steps/sec | 44045 MB |
✔ | ✔ | ❌ | 0.56 steps/sec | 40461 MB |
✔ | ✔ | ✔ | 0.84 steps/sec | 27485 MB |
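The relative figures can be checked directly against the measured rows. This worked arithmetic compares row 2 (gradient checkpointing only) with row 4 (all three features); relative_change is a hypothetical helper.

```python
def relative_change(before: float, after: float) -> float:
    """Fractional change from before to after (positive means an increase)."""
    return after / before - 1.0


# Row 2 (checkpointing only) vs. row 4 (checkpointing + mixed precision + pre-extracted features).
speedup = relative_change(0.43, 0.84)            # steps/sec
memory_change = relative_change(44045, 27485)    # MB
print(f"speed: {speedup:+.1%}, memory: {memory_change:+.1%}")
```

The speed change comes out at about +95%, matching the text. Between these two measured rows the memory drops by about 38%; the 60% figure is presumably measured against the unoptimized configuration, which runs out of memory in this table and so cannot be checked from it.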
We include a sample_ddp.py script which samples a large number of images from a DiT model in parallel. This script generates a folder of samples as well as a .npz file which can be directly used with ADM's TensorFlow evaluation suite to compute FID, Inception Score and other metrics. For example, to sample 50K images from our pre-trained DiT-XL/2 model over N GPUs, run:
torchrun --nnodes=1 --nproc_per_node=N sample_ddp.py --model DiT-XL/2 --num-fid-samples 50000
There are several additional options; see sample_ddp.py for details.
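Because every GPU samples in fixed-size batches, the number of generated images is typically rounded up to a whole number of global batches. The sketch below assumes that behavior (as in the original DiT sample_ddp.py) and a per-process batch size of 32; both are assumptions to check against the script. plan_sampling is a hypothetical helper.

```python
import math
from typing import Tuple


def plan_sampling(num_fid_samples: int, per_proc_batch_size: int, world_size: int) -> Tuple[int, int, int]:
    """Return (total samples generated, samples per process, iterations per process)."""
    global_batch = per_proc_batch_size * world_size
    # Round up so every process runs the same number of full batches.
    total = math.ceil(num_fid_samples / global_batch) * global_batch
    per_process = total // world_size
    iterations = per_process // per_proc_batch_size
    return total, per_process, iterations
```

For 50,000 samples on 8 GPUs with 32 images per GPU per batch, this gives 50,176 images in 196 iterations per GPU; presumably only the first 50,000 go into the .npz file.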