Scaling Diffusion Transformers with Mixture of Experts

Official PyTorch Implementation

arXiv

This repo contains PyTorch model definitions, pre-trained weights, and training/sampling code for our paper on scaling Diffusion Transformers to 16 billion parameters (DiT-MoE). DiT-MoE, a sparse version of the Diffusion Transformer, is scalable and competitive with dense networks while offering highly optimized inference.

DiT-MoE framework

  • 🪐 A PyTorch implementation of DiT-MoE and the pre-trained checkpoints from the paper
  • 🌋 Rectified flow-based training and sampling scripts
  • 💥 A sampling script for running pre-trained DiT-MoE
  • 🛸 A DiT-MoE training script using PyTorch DDP and DeepSpeed
  • ⚡️ An upcycling script to convert dense checkpoints to MoE checkpoints, following link
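
The upcycling step in the last item can be illustrated with a minimal sketch of sparse upcycling in general: each expert starts as a copy of the dense MLP, and the router, which has no dense counterpart, is freshly initialized. The parameter names (`fc1.weight`, `experts.{e}.…`, `gate.weight`), the initialization scale, and the use of nested lists in place of tensors are all assumptions made for illustration; the repository's script may work differently.

```python
import copy
import random


def upcycle_mlp(dense_state, num_experts, hidden_dim, seed=0):
    """Build a toy MoE state dict from one dense MLP's weights.

    dense_state maps hypothetical parameter names to nested lists of floats.
    """
    rng = random.Random(seed)
    moe_state = {}
    for e in range(num_experts):
        for name, weight in dense_state.items():
            # Every expert starts as an independent copy of the dense MLP.
            moe_state[f"experts.{e}.{name}"] = copy.deepcopy(weight)
    # The router is new, so it gets small random values (assumed scale 0.02).
    moe_state["gate.weight"] = [
        [rng.gauss(0.0, 0.02) for _ in range(hidden_dim)]
        for _ in range(num_experts)
    ]
    return moe_state


dense = {
    "fc1.weight": [[0.1, 0.2], [0.3, 0.4]],
    "fc2.weight": [[0.5, 0.6], [0.7, 0.8]],
}
moe = upcycle_mlp(dense, num_experts=4, hidden_dim=2)
print(len(moe), "parameters in the upcycled MoE layer")
```

Because the experts are deep copies, they can diverge during later training without affecting one another.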

To-do list

  • training / inference scripts
  • experts routing analysis
  • huggingface ckpts

1. Training

Refer to the link to set up the runtime environment.

To launch latent-space training of DiT-MoE-S/2 (256x256) with N GPUs on one node using PyTorch DDP:

torchrun --nnodes=1 --nproc_per_node=N train.py \
--model DiT-S/2 \
--num_experts 8 \
--num_experts_per_tok 2 \
--data-path /path/to/imagenet/train \
--image-size 256 \
--global-batch-size 256 \
--vae-path /path/to/vae
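
The `--num_experts` and `--num_experts_per_tok` flags set how many experts each MoE layer holds and how many are activated per token. The following is a minimal sketch of top-k routing in plain Python. It assumes softmax gating with the selected weights renormalized to sum to one, and it uses toy scalar experts; the repository's gate may differ in both respects.

```python
import math


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def route_token(gate_logits, num_experts_per_tok):
    """Return [(expert_id, weight), ...] for the top-k experts of one token."""
    probs = softmax(gate_logits)
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    top = top[:num_experts_per_tok]
    norm = sum(probs[i] for i in top)  # renormalize over selected experts
    return [(i, probs[i] / norm) for i in top]


def moe_forward(x, gate_logits, experts, num_experts_per_tok):
    """Weighted sum of the selected experts' outputs for a scalar token x."""
    routes = route_token(gate_logits, num_experts_per_tok)
    return sum(w * experts[i](x) for i, w in routes)


# Toy setup: 8 experts, expert i multiplies its input by (i + 1).
experts = [lambda x, s=i + 1: s * x for i in range(8)]
gate_logits = [0.1, 2.0, -1.0, 0.5, 1.5, 0.0, -0.5, 0.3]

selected = route_token(gate_logits, num_experts_per_tok=2)
print(selected)
print(moe_forward(1.0, gate_logits, experts, 2))
```

Only the two selected experts are evaluated for this token, which is where the inference savings of a sparse model come from.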

For multi-node training, we fixed a bug in the original DiT repository. You can run on 8 nodes as follows:

torchrun --nnodes=8 \
    --node_rank=0 \
    --nproc_per_node=8 \
    --master_addr="10.0.0.0" \
    --master_port=1234 \
    train.py \
    --model DiT-B/2 \
    --num_experts 8 \
    --num_experts_per_tok 2 \
    --global-batch-size 1024 \
    --data-path /path/to/imagenet/train \
    --vae-path /path/to/vae
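
With torchrun, the same command is started on every node and only `--node_rank` changes. The sketch below generates the per-node commands and computes the per-GPU batch size, assuming `train.py` keeps the original DiT convention of dividing `--global-batch-size` evenly by the world size. The helper names are hypothetical.

```python
def per_gpu_batch_size(global_batch_size, nnodes, nproc_per_node):
    """Split the global batch evenly across all processes."""
    world_size = nnodes * nproc_per_node
    if global_batch_size % world_size != 0:
        raise ValueError("global batch size must be divisible by world size")
    return global_batch_size // world_size


def node_commands(nnodes, nproc_per_node, master_addr, master_port, train_args):
    """One torchrun command per node; only --node_rank differs between them."""
    return [
        f"torchrun --nnodes={nnodes} --node_rank={rank} "
        f"--nproc_per_node={nproc_per_node} "
        f"--master_addr={master_addr} --master_port={master_port} "
        f"train.py {train_args}"
        for rank in range(nnodes)
    ]


print(per_gpu_batch_size(1024, nnodes=8, nproc_per_node=8))  # → 16
commands = node_commands(8, 8, "10.0.0.0", 1234, "--model DiT-B/2 --num_experts 8")
for cmd in commands[:2]:
    print(cmd)
```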

For training larger models, we recommend using the DeepSpeed scripts with flash attention. The different stage settings, including ZeRO-2 and ZeRO-3, can be found in the config file. You can run:

python -m torch.distributed.launch --nnodes=1 --nproc_per_node=8 train_deepspeed.py \
--deepspeed_config config/zero2.json \
--model DiT-XL/2 \
--num_experts 8 \
--num_experts_per_tok 2 \
--data-path /path/to/imagenet/train \
--vae-path /path/to/vae \
--train_batch_size 32
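
For orientation, a DeepSpeed ZeRO config is a JSON file whose `zero_optimization.stage` field selects the stage. The sketch below builds a minimal one in Python. The numeric values are illustrative assumptions and the repository's `config/zero2.json` may contain different or additional fields. DeepSpeed expects `train_batch_size` to equal the micro batch size per GPU times the gradient accumulation steps times the number of GPUs, which the helper enforces by construction.

```python
import json


def make_zero_config(stage, micro_batch_per_gpu, grad_accum_steps, world_size):
    """Build a minimal DeepSpeed-style config dict (illustrative values only)."""
    return {
        "train_micro_batch_size_per_gpu": micro_batch_per_gpu,
        "gradient_accumulation_steps": grad_accum_steps,
        # Must equal micro batch * accumulation steps * number of GPUs.
        "train_batch_size": micro_batch_per_gpu * grad_accum_steps * world_size,
        "zero_optimization": {"stage": stage},
    }


cfg = make_zero_config(stage=2, micro_batch_per_gpu=4, grad_accum_steps=1, world_size=8)
print(json.dumps(cfg, indent=2))
```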

For rectified flow training, as used in FLUX and SD3, you can run:

python -m torch.distributed.launch --nnodes=1 --nproc_per_node=8 train_deepspeed.py \
--deepspeed_config config/zero2.json \
--model DiT-XL/2 \
--rf True \
--num_experts 8 \
--num_experts_per_tok 2 \
--data-path /path/to/imagenet/train \
--vae-path /path/to/vae \
--train_batch_size 32

Our experiments show that rectified flow training leads to better performance and faster convergence.
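
As a rough illustration of what rectified flow training optimizes, the sketch below uses one common convention: a sample is linearly interpolated toward noise, `x_t = (1 - t) * x0 + t * noise`, and the model regresses the constant velocity `noise - x0`. Sampling then integrates that velocity from `t = 1` back to `t = 0` with Euler steps. The function names and the convention are assumptions; the repository's `--rf` code path may use a different parameterization or time sampling.

```python
def interpolate(x0, noise, t):
    """Point on the straight line between data (t=0) and noise (t=1)."""
    return [(1.0 - t) * a + t * b for a, b in zip(x0, noise)]


def velocity_target(x0, noise):
    """Regression target: d x_t / d t along the straight path."""
    return [b - a for a, b in zip(x0, noise)]


def rf_loss(pred_v, x0, noise):
    """Mean squared error between predicted and target velocity."""
    target = velocity_target(x0, noise)
    return sum((p - q) ** 2 for p, q in zip(pred_v, target)) / len(target)


def euler_sample(velocity_fn, noise, num_steps):
    """Integrate from t=1 (noise) to t=0 (data) with fixed-size Euler steps."""
    x, dt = list(noise), 1.0 / num_steps
    for step in range(num_steps):
        t = 1.0 - step * dt
        v = velocity_fn(x, t)
        x = [xi - dt * vi for xi, vi in zip(x, v)]
    return x


x0, noise = [1.0, -2.0], [0.5, 0.3]
# An oracle that returns the true velocity stands in for the trained model.
oracle = lambda x, t: velocity_target(x0, noise)
print(rf_loss(oracle(None, 0.5), x0, noise))
print(euler_sample(oracle, noise, num_steps=4))
```

With a perfect velocity field the path is a straight line, so a handful of Euler steps is enough to recover the data point up to floating-point error.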

We also provide shell scripts for training each model size in the scripts folder.

2. Inference

We include a sample.py script that samples images from a DiT-MoE model. Note that we use torch.float16 for large-model inference.

python sample.py \
--model DiT-XL/2 \
--ckpt /path/to/model \
--vae-path /path/to/vae \
--image-size 256 \
--cfg-scale 1.5
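
The `--cfg-scale` flag controls classifier-free guidance. A minimal sketch of the standard combination rule, assuming sample.py follows the usual DiT formulation in which the guided output is the unconditional prediction plus the scaled difference between the conditional and unconditional predictions (plain lists stand in for tensors):

```python
def apply_cfg(cond_out, uncond_out, cfg_scale):
    """Classifier-free guidance: uncond + scale * (cond - uncond)."""
    return [u + cfg_scale * (c - u) for c, u in zip(cond_out, uncond_out)]


cond = [1.0, 2.0]    # prediction with the class label
uncond = [0.0, 1.0]  # prediction with the null label
print(apply_cfg(cond, uncond, cfg_scale=1.5))  # → [1.5, 2.5]
```

A scale of 1.0 reproduces the conditional prediction, and larger values push the sample further toward the class condition.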

3. Download Models and Data

The model weights, data, and scripts used to reproduce the results are listed below.

We use the SD VAE available at this link.

| DiT-MoE Model | Image Resolution | URL | Scripts | Loss curve |
|---|---|---|---|---|
| DiT-MoE-S/2-8E2A | 256x256 | link | DDIM | - |
| DiT-MoE-S/2-16E2A | 256x256 | link | DDIM | - |
| DiT-MoE-B/2-8E2A | 256x256 | link | DDIM | - |
| DiT-MoE-XL/2-8E2A | 256x256 | link | RF | - |
| DiT-MoE-G/2-16E2A | 256x256 | link | RF | - |

4. Expert Specialization Analysis Tools

We provide all the analysis scripts used in the paper.
You can use expert_data.py to sample data points and record the expert ids selected for them under different class conditions.
The headmap_xx.py scripts then visualize the frequency of expert selection for different scenarios.
For quick validation, adjust the number of sampled data points and the save path.
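
The kind of statistic such a heatmap plots can be sketched as follows: for each class, count how often every expert is selected and normalize to frequencies. The record format (a class id paired with the expert ids chosen for one token) is a hypothetical stand-in for whatever expert_data.py actually saves.

```python
from collections import Counter, defaultdict


def expert_frequencies(records, num_experts):
    """Map class id -> list of selection frequencies, one entry per expert."""
    counts = defaultdict(Counter)
    for class_id, expert_ids in records:
        counts[class_id].update(expert_ids)
    table = {}
    for class_id, counter in counts.items():
        total = sum(counter.values())
        table[class_id] = [counter[e] / total for e in range(num_experts)]
    return table


# Hypothetical records: (class id, experts chosen for one token with top-2 routing).
records = [(0, [1, 4]), (0, [1, 2]), (1, [3, 4]), (1, [3, 0])]
freq = expert_frequencies(records, num_experts=8)
for class_id in sorted(freq):
    print(class_id, freq[class_id])
```

Each row of the resulting table corresponds to one row of a class-by-expert heatmap.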

5. BibTeX

@article{FeiDiTMoE2024,
  title={Scaling Diffusion Transformers to 16 Billion Parameters},
  author={Zhengcong Fei and Mingyuan Fan and Changqian Yu and Debang Li and Junshi Huang},
  year={2024},
  journal={arXiv preprint},
}

6. Acknowledgments

The codebase is based on the awesome DiT and DeepSeek-MoE repos.