
Fast Training of Diffusion Transformer for Photorealistic Text-to-Image Synthesis

Primary language: Python. License: GNU Affero General Public License v3.0 (AGPL-3.0).

👉 PixArt-α: Fast Training of Diffusion Transformer for Photorealistic Text-to-Image Synthesis

Huggingface: PixArt-alpha | Project page: PixArt-alpha | arXiv


This repo contains PyTorch model definitions, pre-trained weights, and inference/sampling code for our paper exploring fast training of diffusion models with transformers. You can find more visualizations on our project page.

PixArt-α: Fast Training of Diffusion Transformer for Photorealistic Text-to-Image Synthesis
Junsong Chen*, Jincheng Yu*, Chongjian Ge*, Lewei Yao*, Enze Xie†, Yue Wu, Zhongdao Wang, James Kwok, Ping Luo, Huchuan Lu, Zhenguo Li
Huawei Noah’s Ark Lab, Dalian University of Technology, HKU, HKUST


🚩 New Features/Updates

  • ✅ Oct. 27, 2023. Release the training & feature extraction code.
  • ✅ Oct. 20, 2023. Collaborate with the Huggingface & Diffusers team to co-release the code and weights. (Please stay tuned.)
  • ✅ Oct. 15, 2023. Release the inference code.

🐱 Abstract

TL;DR: PixArt-α is a Transformer-based T2I diffusion model whose image generation quality is competitive with state-of-the-art image generators (e.g., Imagen, SDXL, and even Midjourney), and whose training speed markedly surpasses existing large-scale T2I models: PixArt-α takes only 10.8% of Stable Diffusion v1.5's training time (675 vs. 6,250 A100 GPU days).

Full abstract: The most advanced text-to-image (T2I) models require significant training costs (e.g., millions of GPU hours), seriously hindering the fundamental innovation for the AIGC community while increasing CO2 emissions. This paper introduces PixArt-α, a Transformer-based T2I diffusion model whose image generation quality is competitive with state-of-the-art image generators (e.g., Imagen, SDXL, and even Midjourney), reaching near-commercial application standards. Additionally, it supports high-resolution image synthesis up to 1024px resolution with low training cost. To achieve this goal, three core designs are proposed: (1) Training strategy decomposition: We devise three distinct training steps that separately optimize pixel dependency, text-image alignment, and image aesthetic quality; (2) Efficient T2I Transformer: We incorporate cross-attention modules into Diffusion Transformer (DiT) to inject text conditions and streamline the computation-intensive class-condition branch; (3) High-informative data: We emphasize the significance of concept density in text-image pairs and leverage a large Vision-Language model to auto-label dense pseudo-captions to assist text-image alignment learning. As a result, PixArt-α's training speed markedly surpasses existing large-scale T2I models, e.g., PixArt-α takes only 10.8% of Stable Diffusion v1.5's training time (675 vs. 6,250 A100 GPU days), saving nearly $300,000 ($26,000 vs. $320,000) and reducing CO2 emissions by 90%. Moreover, compared with a larger SOTA model, RAPHAEL, our training cost is merely 1%. Extensive experiments demonstrate that PixArt-α excels in image quality, artistry, and semantic control. We hope PixArt-α will provide new insights to the AIGC community and startups to accelerate building their own high-quality yet low-cost generative models from scratch.

A small cactus with a happy face in the Sahara desert.


🔥🔥🔥 Why PixArt-α?

Training Efficiency

PixArt-α takes only 10.8% of Stable Diffusion v1.5's training time (675 vs. 6,250 A100 GPU days), saving nearly $300,000 ($26,000 vs. $320,000) and reducing CO2 emissions by 90%. Moreover, compared with a larger SOTA model, RAPHAEL, our training cost is merely 1%.

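| Method   | Type | #Params | #Images | A100 GPU days |
|----------|------|---------|---------|---------------|
| DALL·E   | Diff | 12.0B   | 1.54B   |               |
| GLIDE    | Diff | 5.0B    | 5.94B   |               |
| LDM      | Diff | 1.4B    | 0.27B   |               |
| DALL·E 2 | Diff | 6.5B    | 5.63B   | 41,66         |
| SDv1.5   | Diff | 0.9B    | 3.16B   | 6,250         |
| GigaGAN  | GAN  | 0.9B    | 0.98B   | 4,783         |
| Imagen   | Diff | 3.0B    | 15.36B  | 7,132         |
| RAPHAEL  | Diff | 3.0B    | 5.0B    | 60,000        |
| PixArt-α | Diff | 0.6B    | 0.025B  | 675           |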
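
The ratios quoted above can be re-derived from the A100 GPU days column of the table. This is a minimal sketch for checking the arithmetic; the dictionary and function names are illustrative and not part of this repo.

```python
# A100 GPU days copied from the table above (only the rows cited in the text).
GPU_DAYS = {
    "SDv1.5": 6250,
    "RAPHAEL": 60000,
    "PixArt-α": 675,
}


def cost_ratio(model: str, baseline: str) -> float:
    """Return the model's training cost as a percentage of the baseline's."""
    return 100.0 * GPU_DAYS[model] / GPU_DAYS[baseline]


if __name__ == "__main__":
    print(f"vs. SDv1.5:  {cost_ratio('PixArt-α', 'SDv1.5'):.1f}%")
    print(f"vs. RAPHAEL: {cost_ratio('PixArt-α', 'RAPHAEL'):.1f}%")
```

675 / 6,250 gives the 10.8% figure, and 675 / 60,000 is about 1.1%, which the text rounds to 1%.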

High-quality Generation from PixArt-α

  • More samples

🔧 Dependencies and Installation

conda create -n pixart python==3.9.0
conda activate pixart
cd path/to/pixart
pip install torch==2.0.0+cu117 torchvision==0.15.1+cu117 torchaudio==2.0.1 --index-url https://download.pytorch.org/whl/cu117
pip install -r requirements.txt
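
After installing, it can help to confirm that the pinned versions are the ones actually present in the environment. The sketch below uses only the Python standard library; the `PINS` dictionary and `check_pins` helper are hypothetical names, not part of this repo. It assumes that a pin without a local label (such as `2.0.1`) should also accept a build like `2.0.1+cu117`, which is what the cu117 index typically serves.

```python
from importlib import metadata

# Version pins copied from the pip command above.
PINS = {
    "torch": "2.0.0+cu117",
    "torchvision": "0.15.1+cu117",
    "torchaudio": "2.0.1",
}


def check_pins(pins: dict) -> dict:
    """Map each package name to 'ok', 'missing', or 'mismatch (found X)'."""
    report = {}
    for name, wanted in pins.items():
        try:
            found = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = "missing"
            continue
        # Accept an exact match, or a match on the public part of the version
        # when the pin carries no local label such as "+cu117".
        if found == wanted or found.split("+")[0] == wanted:
            report[name] = "ok"
        else:
            report[name] = f"mismatch (found {found})"
    return report


if __name__ == "__main__":
    for name, status in check_pins(PINS).items():
        print(f"{name}: {status}")
```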

⏬ Download Models

All models will be automatically downloaded. You can also choose to download manually from this url.

| Model            | #Params | url  |
|------------------|---------|------|
| T5               | 4.3B    | T5   |
| VAE              | 80M     | VAE  |
| PixArt-α-SAM-256 | 0.6B    | 256  |
| PixArt-α-512     | 0.6B    | 512  |
| PixArt-α-1024    | 0.6B    | 1024 |

🔥 How to Train

Here we take the SAM dataset training config as an example; you can prepare your own dataset following the same method.

You only need to change the config file in config and the dataloader in dataset.

python -m torch.distributed.launch --nproc_per_node=2 --master_port=12345 scripts/train.py configs/pixart_config/PixArt_xl2_img256_SAM.py --work-dir output/train_SAM_256
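
If you launch runs from a Python script (for example, to sweep GPU counts or output directories), the command above can be assembled as an argument list. This is a minimal sketch: `build_train_command` is a hypothetical helper, not part of this repo, and it only reuses the flags that appear in the command above.

```python
import shlex
import sys


def build_train_command(config, work_dir, nproc_per_node=2, master_port=12345):
    """Assemble the launch command shown above as an argument list."""
    return [
        sys.executable, "-m", "torch.distributed.launch",
        f"--nproc_per_node={nproc_per_node}",
        f"--master_port={master_port}",
        "scripts/train.py", config,
        "--work-dir", work_dir,
    ]


if __name__ == "__main__":
    cmd = build_train_command(
        "configs/pixart_config/PixArt_xl2_img256_SAM.py",
        "output/train_SAM_256",
    )
    # Print instead of running; pass `cmd` to subprocess.run(cmd, check=True)
    # from the repo root to actually start training.
    print(shlex.join(cmd))
```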

The directory structure for SAM dataset is:

cd ./data

SA1B
├──images/
│  ├──sa_xxxxx.jpg
│  ├──sa_xxxxx.jpg
│  ├──......
├──partition/
│  ├──part0.txt
│  ├──part1.txt
│  ├──......
├──caption_feature_wmask/
│  ├──sa_xxxxx.npz
│  ├──sa_xxxxx.npz
│  ├──......
├──img_vae_feature/
│  ├──train_vae_256/
│  │  ├──noflip/
│  │  │  ├──sa_xxxxx.npy
│  │  │  ├──sa_xxxxx.npy
│  │  │  ├──......
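
Before starting a run, you may want to check that a dataset folder matches the layout above. The following is a minimal sketch using only the standard library; the helper names and the `data/SA1B` root are assumptions for illustration, and it assumes that each feature file shares its image's file stem, as the `sa_xxxxx` names suggest.

```python
from pathlib import Path

# Sub-directories taken from the tree above, relative to the SA1B root.
EXPECTED_DIRS = [
    "images",
    "partition",
    "caption_feature_wmask",
    "img_vae_feature/train_vae_256/noflip",
]


def missing_dirs(root):
    """Return the expected sub-directories that do not exist under root."""
    root = Path(root)
    return [d for d in EXPECTED_DIRS if not (root / d).is_dir()]


def unpaired_images(root):
    """Return stems of images lacking a caption .npz or a VAE .npy file."""
    root = Path(root)
    images = {p.stem for p in (root / "images").glob("*.jpg")}
    captions = {p.stem for p in (root / "caption_feature_wmask").glob("*.npz")}
    latents = {
        p.stem
        for p in (root / "img_vae_feature/train_vae_256/noflip").glob("*.npy")
    }
    # An image is usable only if both of its feature files are present.
    return sorted(images - (captions & latents))


if __name__ == "__main__":
    root = Path("data/SA1B")  # assumed location, matching `cd ./data` above
    print("missing directories:", missing_dirs(root))
    print("images without features:", unpaired_images(root))
```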

💻 How to Test

Inference requires at least 23GB of GPU memory. Currently supported:

Quick start with Gradio

To get started, first install the required dependencies, then run:

python scripts/interface.py --model_path path/to/model.pth --image_size=1024 --port=12345

Then open http://your-server-ip:port in a browser to try a simple example.

Online Demo: Huggingface PixArt


🔥To-Do List

  • Inference code
  • T2ICompBench code
  • Training code
  • T5 & VAE feature extraction code
  • Model zoo
  • Diffusers version

📖BibTeX

@misc{chen2023pixartalpha,
      title={PixArt-$\alpha$: Fast Training of Diffusion Transformer for Photorealistic Text-to-Image Synthesis}, 
      author={Junsong Chen and Jincheng Yu and Chongjian Ge and Lewei Yao and Enze Xie and Yue Wu and Zhongdao Wang and James Kwok and Ping Luo and Huchuan Lu and Zhenguo Li},
      year={2023},
      eprint={2310.00426},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

🤗Acknowledgements

  • Thanks to DiT for their wonderful work and codebase.