/1d-tokenizer

This repo contains the code for our paper An Image is Worth 32 Tokens for Reconstruction and Generation

(NeurIPS 2024) Compact and Mighty - Image Tokenization with Only 32 Tokens for both Reconstruction and Generation!

We present a compact 1D tokenizer that can represent an image with as few as 32 discrete tokens. As a result, it leads to a substantial speed-up in the sampling process (e.g., 410× faster than DiT-XL/2) while obtaining competitive generation quality.

Updates

  • 09/25/2024: TiTok is accepted by NeurIPS 2024.
  • 09/11/2024: Released the training code for the generator based on TiTok.
  • 08/28/2024: Released the training code for TiTok.
  • 08/09/2024: Better support for loading pretrained weights from HuggingFace models, thanks to help from @NielsRogge.
  • 07/03/2024: Evaluation scripts for reproducing the results reported in the paper and checkpoints of TiTok-B64 and TiTok-S128 are available.
  • 06/21/2024: Released demo code and TiTok-L-32 checkpoints.
  • 06/11/2024: The tech report of this project is available.

🚀 Contributions

We introduce a novel 1D image tokenization framework that breaks grid constraints existing in 2D tokenization methods, leading to a much more flexible and compact image latent representation.

The proposed 1D tokenizer can tokenize a 256 × 256 image into as few as 32 discrete tokens, leading to a significant speed-up (hundreds of times faster than diffusion models) in the generation process, while maintaining state-of-the-art generation quality.

We conduct a series of experiments to probe the properties of rarely studied 1D image tokenization, paving the path towards compact latent space for efficient and effective image representation.
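
To make the compactness claim concrete, the following back-of-the-envelope sketch compares token counts for a 256 × 256 image. The 16× downsampling factor for a 2D grid tokenizer and the 4096-entry codebook are assumptions for illustration (the latter is only suggested by the example token ids in Get Started all being below 4096); neither value is stated in this README.

```python
import math

IMAGE_SIZE = 256          # from the text: 256 x 256 input
NUM_1D_TOKENS = 32        # from the text: as few as 32 tokens
DOWNSAMPLE_FACTOR = 16    # assumption: stride of a typical 2D VQ tokenizer
CODEBOOK_SIZE = 4096      # assumption: not stated in this README


def grid_token_count(image_size: int, factor: int) -> int:
    """Tokens produced by a 2D tokenizer that keeps one token per patch."""
    side = image_size // factor
    return side * side


def bits_per_image(num_tokens: int, codebook_size: int) -> int:
    """Bits needed to store the token ids of one image."""
    return num_tokens * math.ceil(math.log2(codebook_size))


tokens_2d = grid_token_count(IMAGE_SIZE, DOWNSAMPLE_FACTOR)
print(f"2D grid tokens: {tokens_2d}")
print(f"1D tokens:      {NUM_1D_TOKENS}")
print(f"reduction:      {tokens_2d // NUM_1D_TOKENS}x fewer tokens")
print(f"1D code size:   {bits_per_image(NUM_1D_TOKENS, CODEBOOK_SIZE)} bits")
```

Under these assumptions the generator has to predict a sequence that is 8 times shorter than a 16 × 16 grid, which is consistent with the sampling speed-up described above.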

Model Zoo

Dataset  | Model                 | Link       | FID
ImageNet | TiTok-L-32 Tokenizer  | checkpoint | 2.21 (reconstruction)
ImageNet | TiTok-B-64 Tokenizer  | checkpoint | 1.70 (reconstruction)
ImageNet | TiTok-S-128 Tokenizer | checkpoint | 1.71 (reconstruction)
ImageNet | TiTok-L-32 Generator  | checkpoint | 2.77 (generation)
ImageNet | TiTok-B-64 Generator  | checkpoint | 2.48 (generation)
ImageNet | TiTok-S-128 Generator | checkpoint | 1.97 (generation)

Please note that these models are trained only on ImageNet, a limited academic dataset, and are intended for research purposes only.

Installation

pip3 install -r requirements.txt

Get Started

import torch
from PIL import Image
import numpy as np
import demo_util
from huggingface_hub import hf_hub_download
from modeling.maskgit import ImageBert
from modeling.titok import TiTok

titok_tokenizer = TiTok.from_pretrained("yucornetto/tokenizer_titok_l32_imagenet")
titok_tokenizer.eval()
titok_tokenizer.requires_grad_(False)
titok_generator = ImageBert.from_pretrained("yucornetto/generator_titok_l32_imagenet")
titok_generator.eval()
titok_generator.requires_grad_(False)

# or alternatively, download the weights from hf
# hf_hub_download(repo_id="fun-research/TiTok", filename="tokenizer_titok_l32.bin", local_dir="./")
# hf_hub_download(repo_id="fun-research/TiTok", filename="generator_titok_l32.bin", local_dir="./")

# load config
# config = demo_util.get_config("configs/titok_l32.yaml")
# titok_tokenizer = demo_util.get_titok_tokenizer(config)
# titok_generator = demo_util.get_titok_generator(config)

device = "cuda"
titok_tokenizer = titok_tokenizer.to(device)
titok_generator = titok_generator.to(device)

# reconstruct an image. I.e., image -> 32 tokens -> image
img_path = "assets/ILSVRC2012_val_00010240.png"
image = torch.from_numpy(np.array(Image.open(img_path).convert("RGB")).astype(np.float32)).permute(2, 0, 1).unsqueeze(0) / 255.0
# tokenization
encoded_tokens = titok_tokenizer.encode(image.to(device))[1]["min_encoding_indices"]
# image assets/ILSVRC2012_val_00010240.png is encoded into tokens tensor([[[ 887, 3979,  349,  720, 2809, 2743, 2101,  603, 2205, 1508, 1891, 4015, 1317, 2956, 3774, 2296,  484, 2612, 3472, 2330, 3140, 3113, 1056, 3779,  654, 2360, 1901, 2908, 2169,  953, 1326, 2598]]], device='cuda:0'), with shape torch.Size([1, 1, 32])
print(f"image {img_path} is encoded into tokens {encoded_tokens}, with shape {encoded_tokens.shape}")
# de-tokenization
reconstructed_image = titok_tokenizer.decode_tokens(encoded_tokens)
reconstructed_image = torch.clamp(reconstructed_image, 0.0, 1.0)
reconstructed_image = (reconstructed_image * 255.0).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()[0]
# save() returns None, so do not assign its result back to reconstructed_image
Image.fromarray(reconstructed_image).save("assets/ILSVRC2012_val_00010240_recon.png")

# generate an image
sample_labels = [torch.randint(0, 1000, size=(1,)).item()] # random IN-1k class; the upper bound is exclusive
generated_image = demo_util.sample_fn(
    generator=titok_generator,
    tokenizer=titok_tokenizer,
    labels=sample_labels,
    guidance_scale=4.5,
    randomize_temperature=1.0,
    num_sample_steps=8,
    device=device
)
Image.fromarray(generated_image[0]).save(f"assets/generated_{sample_labels[0]}.png")
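
The generator above is loaded from modeling.maskgit and is sampled with num_sample_steps=8. As a rough illustration of what iterative masked decoding looks like for a 32-token sequence, the sketch below computes how many tokens remain masked after each step under the cosine schedule described in the MaskGIT paper. Whether demo_util.sample_fn follows exactly this schedule is an assumption, and the function name masked_tokens_per_step is hypothetical and not part of the repo.

```python
import math


def masked_tokens_per_step(num_tokens: int, num_steps: int) -> list[int]:
    """Number of tokens still masked after each decoding step (cosine schedule)."""
    schedule = []
    for step in range(1, num_steps + 1):
        # Mask ratio decays from ~1 to 0 following a quarter cosine wave.
        ratio = math.cos(math.pi / 2 * step / num_steps)
        schedule.append(math.floor(num_tokens * ratio))
    return schedule


schedule = masked_tokens_per_step(num_tokens=32, num_steps=8)
print(schedule)  # [31, 29, 26, 22, 17, 12, 6, 0]
```

With this schedule only one token is committed in the first step and progressively more in later steps, so all 32 tokens are fixed after 8 forward passes of the generator.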

We also provide a Jupyter notebook with a quick tutorial on reconstructing and generating images with TiTok-L-32.

TiTok is also available as a HuggingFace 🤗 demo.

Testing on ImageNet-1K Benchmark

We provide a sampling script for reproducing the generation results on the ImageNet-1K benchmark.

# Prepare ADM evaluation script
git clone https://github.com/openai/guided-diffusion.git

wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz
# Reproducing TiTok-L-32
torchrun --nnodes=1 --nproc_per_node=8 --rdzv-endpoint=localhost:9999 sample_imagenet.py config=configs/titok_l32.yaml experiment.output_dir="titok_l_32"
# Run eval script. The result FID should be ~2.77
python3 guided-diffusion/evaluations/evaluator.py VIRTUAL_imagenet256_labeled.npz titok_l_32.npz

# Reproducing TiTok-B-64
torchrun --nnodes=1 --nproc_per_node=8 --rdzv-endpoint=localhost:9999 sample_imagenet.py config=configs/titok_b64.yaml experiment.output_dir="titok_b_64"
# Run eval script. The result FID should be ~2.48
python3 guided-diffusion/evaluations/evaluator.py VIRTUAL_imagenet256_labeled.npz titok_b_64.npz

# Reproducing TiTok-S-128
torchrun --nnodes=1 --nproc_per_node=8 --rdzv-endpoint=localhost:9999 sample_imagenet.py config=configs/titok_s128.yaml experiment.output_dir="titok_s_128"
# Run eval script. The result FID should be ~1.97
python3 guided-diffusion/evaluations/evaluator.py VIRTUAL_imagenet256_labeled.npz titok_s_128.npz
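
The three runs above differ only in the config file and the output name, so they can be scripted. The sketch below only builds the command lines from the values shown above and prints them; it does not execute anything. The helper name build_commands is hypothetical.

```python
import shlex

# (config stem, output name, expected generation FID), taken from the commands above
VARIANTS = [
    ("titok_l32", "titok_l_32", 2.77),
    ("titok_b64", "titok_b_64", 2.48),
    ("titok_s128", "titok_s_128", 1.97),
]
REF_BATCH = "VIRTUAL_imagenet256_labeled.npz"


def build_commands(config: str, output: str, num_gpus: int = 8) -> tuple[list[str], list[str]]:
    """Return (sampling command, evaluation command) as argument lists."""
    sample = [
        "torchrun", "--nnodes=1", f"--nproc_per_node={num_gpus}",
        "--rdzv-endpoint=localhost:9999", "sample_imagenet.py",
        f"config=configs/{config}.yaml",
        f"experiment.output_dir={output}",
    ]
    evaluate = [
        "python3", "guided-diffusion/evaluations/evaluator.py",
        REF_BATCH, f"{output}.npz",
    ]
    return sample, evaluate


for config, output, expected_fid in VARIANTS:
    sample_cmd, eval_cmd = build_commands(config, output)
    print(shlex.join(sample_cmd))
    print(shlex.join(eval_cmd), f"  # expect FID ~{expected_fid}")
```

Passing each argument list to subprocess.run(cmd, check=True) would run sampling and evaluation in sequence and stop at the first failure.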

Training Preparation

We use the webdataset format for data loading, so the dataset must first be converted into webdataset format. An example script to convert ImageNet to wds format is provided here.
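
For orientation, a webdataset shard is an ordinary tar archive in which files that share a basename form one sample. The sketch below writes such a shard using only Python's standard library. The key names (jpg, cls), the shard filename, the sample key, and the helpers write_shard and list_shard are assumptions for illustration; the conversion script referenced above may use different names.

```python
import io
import tarfile
import tempfile
from pathlib import Path


def write_shard(shard_path: Path, samples: list[tuple[str, bytes, int]]) -> None:
    """Write (key, encoded image bytes, class id) samples into one tar shard."""
    with tarfile.open(shard_path, "w") as tar:
        for key, image_bytes, class_id in samples:
            members = {
                f"{key}.jpg": image_bytes,             # encoded image bytes
                f"{key}.cls": str(class_id).encode(),  # label stored as text
            }
            for name, payload in members.items():
                info = tarfile.TarInfo(name=name)
                info.size = len(payload)
                tar.addfile(info, io.BytesIO(payload))


def list_shard(shard_path: Path) -> list[str]:
    """Return the member names stored in a shard, in archive order."""
    with tarfile.open(shard_path, "r") as tar:
        return tar.getnames()


# Demo with placeholder bytes instead of a real JPEG.
with tempfile.TemporaryDirectory() as tmp:
    shard = Path(tmp) / "imagenet-train-000000.tar"  # hypothetical shard name
    write_shard(shard, [("sample_000000", b"<jpeg bytes>", 0)])
    print(list_shard(shard))
```

Keeping each sample's files adjacent in the archive is what allows shards to be streamed sequentially during training.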

Furthermore, stage 1 training relies on a pre-trained MaskGIT-VQGAN to generate proxy codes as learning targets. You can convert the official JAX weights to a PyTorch version using this script. Alternatively, we provide a converted version at HuggingFace and Google Drive.

Training

We provide example commands to train TiTok as follows:

# Training for TiTok-B64
# Stage 1
WANDB_MODE=offline accelerate launch --num_machines=1 --num_processes=8 --machine_rank=0 --main_process_ip=127.0.0.1 --main_process_port=9999 --same_network scripts/train_titok.py config=configs/training/stage1/titok_b64.yaml \
    experiment.project="titok_b64_stage1" \
    experiment.name="titok_b64_stage1_run1" \
    experiment.output_dir="titok_b64_stage1_run1" \
    training.per_gpu_batch_size=32

# Stage 2
WANDB_MODE=offline accelerate launch --num_machines=1 --num_processes=8 --machine_rank=0 --main_process_ip=127.0.0.1 --main_process_port=9999 --same_network scripts/train_titok.py config=configs/training/stage2/titok_b64.yaml \
    experiment.project="titok_b64_stage2" \
    experiment.name="titok_b64_stage2_run1" \
    experiment.output_dir="titok_b64_stage2_run1" \
    training.per_gpu_batch_size=32 \
    experiment.init_weight=${PATH_TO_STAGE1_WEIGHT}

# Train Generator (TiTok-B64 as example)
WANDB_MODE=offline accelerate launch --num_machines=4 --num_processes=32 --machine_rank=${MACHINE_RANK} --main_process_ip=${ROOT_IP} --main_process_port=${ROOT_PORT} --same_network scripts/train_maskgit.py config=configs/training/generator/maskgit.yaml \
    experiment.project="titok_generation" \
    experiment.name="titok_b64_maskgit" \
    experiment.output_dir="titok_b64_maskgit" \
    experiment.tokenizer_checkpoint=${PATH_TO_STAGE1_or_STAGE2_WEIGHT}

The config titok_b64.yaml can be replaced with titok_s128.yaml or titok_l32.yaml for other TiTok variants.
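
The commands above override config values with dotted key=value arguments such as training.per_gpu_batch_size=32. As a minimal sketch of how such overrides map onto a nested config, the pure-Python helper below applies them to a dictionary. It is not the repo's config loader, which is not shown in this README; apply_overrides and the base config values are hypothetical.

```python
import copy


def apply_overrides(config: dict, overrides: list[str]) -> dict:
    """Return a copy of config with dotted key=value overrides applied."""
    result = copy.deepcopy(config)
    for item in overrides:
        dotted_key, _, raw_value = item.partition("=")
        *parents, leaf = dotted_key.split(".")
        node = result
        for name in parents:
            node = node.setdefault(name, {})
        # Keep integers as integers; everything else stays a string.
        node[leaf] = int(raw_value) if raw_value.isdigit() else raw_value
    return result


# Hypothetical defaults standing in for the contents of a YAML config file.
base = {"experiment": {"project": "default"}, "training": {"per_gpu_batch_size": 16}}
config = apply_overrides(base, [
    "experiment.project=titok_b64_stage1",
    "experiment.output_dir=titok_b64_stage1_run1",
    "training.per_gpu_batch_size=32",
])
print(config["training"]["per_gpu_batch_size"])  # 32
print(config["experiment"]["output_dir"])        # titok_b64_stage1_run1

# With 8 processes, as in the single-machine commands above, the global batch
# would be 8 * 32 = 256 samples per step, ignoring any gradient accumulation.
print(8 * config["training"]["per_gpu_batch_size"])  # 256
```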

Visualizations

[Figures: teaser images]

Citing

If you use our work in your research, please use the following BibTeX entry.

@inproceedings{yu2024an,
  author    = {Qihang Yu and Mark Weber and Xueqing Deng and Xiaohui Shen and Daniel Cremers and Liang-Chieh Chen},
  title     = {An Image is Worth 32 Tokens for Reconstruction and Generation},
  booktitle = {NeurIPS},
  year      = {2024}
}

Acknowledgement

MaskGIT

Taming-Transformers

Open-MUSE

MUSE-Pytorch