This is the official implementation of DiGIT accepted at NeurIPS 2024. The code will be available soon.
We present DiGIT, an auto-regressive generative model performing next-token prediction in an abstract latent space derived from self-supervised learning (SSL) models. By employing K-Means clustering on the hidden states of the DINOv2 model, we effectively create a novel discrete tokenizer. This method significantly boosts image generation performance on ImageNet dataset, achieving an FID score of 4.59 for class-unconditional tasks and 3.39 for class-conditional tasks. Additionally, the model enhances image understanding, attaining a linear-probe accuracy of 80.3.
Methods | # Tokens | Features | # Params | Top-1 Acc. |
---|---|---|---|---|
iGPT-L | 32 |
1536 | 1362M | 60.3 |
iGPT-XL | 64 |
3072 | 6801M | 68.7 |
VIM+VQGAN | 32 |
1024 | 650M | 61.8 |
VIM+dVAE | 32 |
1024 | 650M | 63.8 |
VIM+ViT-VQGAN | 32 |
1024 | 650M | 65.1 |
VIM+ViT-VQGAN | 32 |
2048 | 1697M | 73.2 |
AIM | 16 |
1536 | 0.6B | 70.5 |
DiGIT (Ours) | 16 |
1024 | 219M | 71.7 |
DiGIT (Ours) | 16 |
1536 | 732M | 80.3 |
Type | Methods | # Param | # Epoch | FID |
IS |
---|---|---|---|---|---|
GAN | BigGAN | 70M | - | 38.6 | 24.70 |
Diff. | LDM | 395M | - | 39.1 | 22.83 |
Diff. | ADM | 554M | - | 26.2 | 39.70 |
MIM | MAGE | 200M | 1600 | 11.1 | 81.17 |
MIM | MAGE | 463M | 1600 | 9.10 | 105.1 |
MIM | MaskGIT | 227M | 300 | 20.7 | 42.08 |
MIM | DiGIT (+MaskGIT) | 219M | 200 | 9.04 | 75.04 |
AR | VQGAN | 214M | 200 | 24.38 | 30.93 |
AR | DiGIT (+VQGAN) | 219M | 400 | 9.13 | 73.85 |
AR | DiGIT (+VQGAN) | 732M | 200 | 4.59 | 141.29 |
Type | Methods | # Param | # Epoch | FID |
IS |
---|---|---|---|---|---|
GAN | BigGAN | 160M | - | 6.95 | 198.2 |
Diff. | ADM | 554M | - | 10.94 | 101.0 |
Diff. | LDM-4 | 400M | - | 10.56 | 103.5 |
Diff. | DiT-XL/2 | 675M | - | 9.62 | 121.50 |
Diff. | L-DiT-7B | 7B | - | 6.09 | 153.32 |
MIM | CQR-Trans | 371M | 300 | 5.45 | 172.6 |
MIM+AR | VAR | 310M | 200 | 4.64 | - |
MIM+AR | VAR | 310M | 200 | 3.60* | 257.5* |
MIM+AR | VAR | 600M | 250 | 2.95* | 306.1* |
MIM | MAGVIT-v2 | 307M | 1080 | 3.65 | 200.5 |
AR | VQVAE-2 | 13.5B | - | 31.11 | 45 |
AR | RQ-Trans | 480M | - | 15.72 | 86.8 |
AR | RQ-Trans | 3.8B | - | 7.55 | 134.0 |
AR | ViTVQGAN | 650M | 360 | 11.20 | 97.2 |
AR | ViTVQGAN | 1.7B | 360 | 5.3 | 149.9 |
MIM | MaskGIT | 227M | 300 | 6.18 | 182.1 |
MIM | DiGIT (+MaskGIT) | 219M | 200 | 4.62 | 146.19 |
AR | VQGAN | 227M | 300 | 18.65 | 80.4 |
AR | DiGIT (+VQGAN) | 219M | 200 | 4.79 | 142.87 |
AR | DiGIT (+VQGAN) | 732M | 200 | 3.39 | 205.96 |
*: VAR is trained with classifier-free guidance while all the other models are not.
- Download the code
git clone https://github.com/DAMO-NLP-SG/DiGIT.git
cd DiGIT
We use the VQGAN checkpoint realsed by MAGE.
Download the DINOv2 SSL model.
- Install
fairseq
viapip install fairseq
.
Download ImageNet dataset, and place it in your dataset dir $PATH_TO_YOUR_WORKSPACE/dataset/ILSVRC2012
. Prepare the ImageNet validation set for FID evaluation against validation set:
python prepare_imgnet_val.py --data_path $PATH_TO_YOUR_WORKSPACE/dataset/ILSVRC2012 --output_dir imagenet-val
To evaluate FID against the validation set, do pip install torch-fidelity
.
Extract the features of SSL and dump as the npy files. Apply K-Mneas alogrithm with faiss to get the centroids.
bash preprocess/run.sh
First train a GPT model with discriminative tokenizer and you can find the training scripts in scripts/train_stage1_ar.sh
and the hyper-params are in config/stage1/dino_base.yaml
. You can enable class conditional generation with scripts/train_stage1_classcond.sh
.
Then train a pixel decoder (either AR model or NAR model) conditioned on discriminative tokens and you can find the autoregressive training scripts in scripts/train_stage2_ar.sh
and NAR training scripts in scripts/train_stage2_nar.sh
.
A folder named outputs/EXP_NAME/checkpoints
will be created to save the checkpoints. The tensorboard files will be save at outputs/EXP_NAME/tb
. The outputs/EXP_NAME/train.log
will record logs.
You can monitor the training process by using tensorboard --logdir=outputs/EXP_NAME/tb
.
First sampling discriminative tokens with scripts/infer_stage1_ar.sh
.
Then we sample VQ tokens conditioned on discriminative tokens sampled before with scripts/infer_stage2_ar.sh
.
A folder named outputs/EXP_NAME/results
will be create to save the generated tokens and synthesized images.
python fairseq_user/eval_fid.py --results-path $IMG_SAVE_DIR --subset $GEN_SUBSET
bash scripts/train_stage1_linearprobe.sh
This project is licensed under the MIT License - see the LICENSE file for details.
If you find our project useful, hope you can star our repo and cite our work as follows.
@inproceedings{
digit,
title={Stabilize the Latent Space for Image Autoregressive Modeling: A Unified Perspective},
author={Yongxin Zhu, Bocheng Li, Hang Zhang, Xin Li, Linli Xu, Lidong Bing},
booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
year={2024},
url={https://openreview.net/forum?id=waQ5X4qc3W}
}
``