This repo contains PyTorch model definitions, pre-trained weights, and training/sampling code for our paper Scaling Diffusion Transformers to 16 Billion Parameters (DiT-MoE). DiT-MoE, a sparse version of the diffusion Transformer, is scalable and competitive with dense networks while offering highly optimized inference.
- 🪐 A PyTorch implementation of DiT-MoE and the pre-trained checkpoints from the paper
- 🌋 Rectified flow-based training and sampling scripts
- 💥 A sampling script for running pre-trained DiT-MoE
- 🛸 A DiT-MoE training script using PyTorch DDP and DeepSpeed
- ⚡️ Upcycling scripts to convert dense checkpoints to MoE checkpoints (see the linked reference; a minimal sketch is included below)
- training / inference scripts
- expert routing analysis
- Hugging Face checkpoints
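The upcycling conversion amounts to copying the dense MLP weights into every expert and letting the new router start from a fresh initialization. Below is a minimal sketch of that idea; the state-dict key names (".mlp.", ".moe.experts.{i}.") are illustrative and should be matched to the actual DiT-MoE layout before use.

```python
import torch

def upcycle_dense_to_moe(dense_state: dict, num_experts: int = 8) -> dict:
    """Replicate each dense MLP into `num_experts` expert copies (sketch).

    Key names (".mlp.", ".moe.experts.{i}.") are illustrative; match them to the
    actual DiT-MoE state dict. Router/gate weights are omitted so they keep the
    fresh initialization from the MoE model (load with strict=False).
    """
    moe_state = {}
    for name, tensor in dense_state.items():
        if ".mlp." in name:
            # Every expert starts as an exact copy of the dense MLP.
            for i in range(num_experts):
                moe_state[name.replace(".mlp.", f".moe.experts.{i}.")] = tensor.clone()
        else:
            moe_state[name] = tensor.clone()
    return moe_state

# Usage (paths are placeholders):
# dense = torch.load("dit_dense.pt", map_location="cpu")
# moe_model.load_state_dict(upcycle_dense_to_moe(dense, num_experts=8), strict=False)
```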
You can refer to the linked instructions to set up the running environment.
To launch DiT-MoE-S/2 (256x256) latent-space training with N GPUs on one node using PyTorch DDP:
torchrun --nnodes=1 --nproc_per_node=N train.py \
--model DiT-S/2 \
--num_experts 8 \
--num_experts_per_tok 2 \
--data-path /path/to/imagenet/train \
--image-size 256 \
--global-batch-size 256 \
--vae-path /path/to/vae
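The --num_experts and --num_experts_per_tok flags correspond to top-k routed MLPs inside the transformer blocks: each token is scored by a small gate, only its top-k experts are evaluated, and their outputs are mixed with the softmax-normalized gate weights. The sketch below illustrates this routing pattern under our own module and argument names, not necessarily the exact layer in this repo:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SparseMoEMLP(nn.Module):
    """Top-k MoE MLP sketch: num_experts experts, num_experts_per_tok active per token."""

    def __init__(self, dim, hidden_dim, num_experts=8, num_experts_per_tok=2):
        super().__init__()
        self.num_experts_per_tok = num_experts_per_tok
        self.gate = nn.Linear(dim, num_experts, bias=False)  # token router
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(dim, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, dim))
            for _ in range(num_experts)
        )

    def forward(self, x):  # x: (batch, tokens, dim)
        tokens = x.reshape(-1, x.shape[-1])
        scores = self.gate(tokens)                                    # (N, num_experts)
        weights, idx = torch.topk(scores, self.num_experts_per_tok)   # route each token
        weights = F.softmax(weights, dim=-1, dtype=torch.float).to(x.dtype)
        out = torch.zeros_like(tokens)
        for e, expert in enumerate(self.experts):
            hit = idx == e                       # (N, k): tokens that picked expert e
            rows = hit.any(dim=-1)
            if rows.any():
                w = (weights * hit).sum(dim=-1, keepdim=True)[rows]
                out[rows] += w * expert(tokens[rows])
        return out.reshape(x.shape)
```

With --num_experts 8 and --num_experts_per_tok 2, only two of the eight expert MLPs run per token, which keeps inference cost close to that of a much smaller dense model.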
For multi-node training, we fixed the bug from the original DiT repository; you can run with 8 nodes as follows, setting --node_rank to each node's rank (0 through 7):
torchrun --nnodes=8 \
--node_rank=0 \
--nproc_per_node=8 \
--master_addr="10.0.0.0" \
--master_port=1234 \
train.py \
--model DiT-B/2 \
--num_experts 8 \
--num_experts_per_tok 2 \
--global-batch-size 1024 \
--data-path /path/to/imagenet/train \
--vae-path /path/to/vae
For training larger models, we recommend the DeepSpeed scripts with flash attention; the different ZeRO stage settings (zero2 and zero3) can be found in the config files. You can run:
python -m torch.distributed.launch --nnodes=1 --nproc_per_node=8 train_deepspeed.py \
--deepspeed_config config/zero2.json \
--model DiT-XL/2 \
--num_experts 8 \
--num_experts_per_tok 2 \
--data-path /path/to/imagenet/train \
--vae-path /path/to/vae \
--train_batch_size 32
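As a rough sketch of how such a script typically wires the model to the --deepspeed_config file (the argument handling and stand-in model here are illustrative, not the exact contents of train_deepspeed.py):

```python
import argparse
import deepspeed
import torch
import torch.nn as nn

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=0)  # injected by torch.distributed.launch
parser = deepspeed.add_config_arguments(parser)           # adds --deepspeed_config and friends
args = parser.parse_args()

# Stand-in module; the real script constructs DiT-MoE here.
model = nn.Linear(1024, 1024)

# DeepSpeed reads the ZeRO stage (zero2/zero3), batch sizes, optimizer, and fp16
# settings from the JSON file passed via --deepspeed_config.
model_engine, optimizer, _, _ = deepspeed.initialize(
    args=args,
    model=model,
    model_parameters=model.parameters(),
)

# A training step with a DeepSpeed engine replaces loss.backward()/optimizer.step():
x = torch.randn(8, 1024, device=model_engine.device)
loss = model_engine(x).pow(2).mean()
model_engine.backward(loss)
model_engine.step()
```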
For rectified flow training, as used in FLUX and SD3, you can run:
python -m torch.distributed.launch --nnodes=1 --nproc_per_node=8 train_deepspeed.py \
--deepspeed_config config/zero2.json \
--model DiT-XL/2 \
--rf True \
--num_experts 8 \
--num_experts_per_tok 2 \
--data-path /path/to/imagenet/train \
--vae-path /path/to/vae \
--train_batch_size 32
Our experiments show that rectified flow training leads to better performance as well as faster convergence.
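The --rf flag switches the objective to rectified flow. A minimal sketch of the standard formulation, assuming the usual DiT forward signature model(x, t, y) and our own function names (not necessarily the repo's exact loss code): sample a point on the straight line between data and noise and regress the constant velocity along that line.

```python
import torch

def rectified_flow_loss(model, x0, y):
    """Standard rectified flow objective (sketch; argument names are ours).

    x0: clean latents of shape (B, C, H, W); y: class labels for conditioning.
    """
    b = x0.shape[0]
    noise = torch.randn_like(x0)
    t = torch.rand(b, device=x0.device)      # uniform timesteps in [0, 1]
    t_ = t.view(b, 1, 1, 1)
    x_t = (1.0 - t_) * x0 + t_ * noise       # linear interpolation between data and noise
    target = noise - x0                      # constant velocity of the straight path
    v_pred = model(x_t, t, y)                # network predicts the velocity
    return ((v_pred - target) ** 2).mean()
```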
We also provide shell scripts for training the different model sizes in the scripts folder.
We include a sample.py script that samples images from a DiT-MoE model. Note that we use torch.float16 for large-model inference.
python sample.py \
--model DiT-XL/2 \
--ckpt /path/to/model \
--vae-path /path/to/vae \
--image-size 256 \
--cfg-scale 1.5
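The --cfg-scale flag controls classifier-free guidance during sampling. A minimal sketch of the standard guidance rule (function and argument names are ours; null_y stands for the learned "no class" label):

```python
import torch

def cfg_predict(model, x, t, y, null_y, cfg_scale=1.5):
    """Classifier-free guidance sketch: mix conditional and unconditional predictions."""
    cond = model(x, t, y)          # class-conditional prediction
    uncond = model(x, t, null_y)   # unconditional prediction (null label)
    return uncond + cfg_scale * (cond - uncond)
```

A cfg_scale of 1.0 disables guidance; larger values trade sample diversity for fidelity.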
The model weights, data, and scripts used to reproduce the results are listed below.
We use the SD VAE from this link.
| DiT-MoE Model | Image Resolution | URL | Scripts | Loss curve |
|---|---|---|---|---|
| DiT-MoE-S/2-8E2A | 256x256 | link | DDIM | - |
| DiT-MoE-S/2-16E2A | 256x256 | link | DDIM | - |
| DiT-MoE-B/2-8E2A | 256x256 | link | DDIM | - |
| DiT-MoE-XL/2-8E2A | 256x256 | link | RF | - |
| DiT-MoE-G/2-16E2A | 256x256 | link | RF | - |
We provide all the analysis scripts used in the paper.
You can use expert_data.py to sample data points together with their routed expert IDs across different class-conditional settings (a rough sketch of the downstream analysis is given below).
Then, a series of headmap_xx.py files visualize the frequency of expert selection in different scenarios.
Quick validation can be achieved by adjusting the number of sampled data points and the save path.
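As a rough illustration of what the heatmap scripts compute, the sketch below counts how often each expert is selected per class and plots the frequency matrix; the input file layout is an assumption, not the actual output format of expert_data.py:

```python
import numpy as np
import matplotlib.pyplot as plt

# Assumed input: one integer (class_id, expert_id) pair per routed token.
pairs = np.load("expert_selections.npy")            # shape (N, 2)
num_classes, num_experts = 1000, 8

counts = np.zeros((num_classes, num_experts))
np.add.at(counts, (pairs[:, 0], pairs[:, 1]), 1)
freq = counts / np.clip(counts.sum(axis=1, keepdims=True), 1, None)  # per-class frequency

plt.imshow(freq, aspect="auto", cmap="viridis")
plt.xlabel("expert id")
plt.ylabel("class id")
plt.colorbar(label="selection frequency")
plt.savefig("expert_heatmap.png", dpi=200)
```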
@article{FeiDiTMoE2024,
title={Scaling Diffusion Transformers to 16 Billion Parameters},
author={Zhengcong Fei and Mingyuan Fan and Changqian Yu and Debang Li and Junshi Huang},
year={2024},
journal={arXiv preprint},
}
The codebase is based on the awesome DiT and DeepSeek-MoE repos.