
DiM: Diffusion Mamba for Efficient High-Resolution Image Synthesis

The official implementation of our paper DiM: Diffusion Mamba for Efficient High-Resolution Image Synthesis.


Method Overview

[Figure: overview of the DiM method]

Acknowledgements

This code is mainly built on U-ViT and Mamba.

Installing Mamba can take considerable effort. If you run into problems, this issue in the Mamba repository may be very helpful.

Installation

# Create the conda env:
conda env create -f environment.yaml

# To update the existing env `mamba` with the contents of `~/mamba_attn/environment.yaml`:
conda env update --name mamba --file ~/mamba_attn/environment.yaml --prune

# Compile Mamba. causal-conv1d must be installed successfully first.
CAUSAL_CONV1D_FORCE_BUILD=TRUE pip install --user -e .
# If compilation fails, copy the ./build/ directory from another server where it compiled successfully. The --user flag may be necessary.
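After building, a quick way to confirm that the compiled extensions actually load is to import them from the active environment. This is a minimal sketch that assumes the upstream module names `causal_conv1d` and `mamba_ssm`; if this repository's Mamba build installs under a different package name, adjust the list. The `check_import` helper is hypothetical.

```bash
#!/usr/bin/env bash
# Sketch: verify that the compiled CUDA extensions import from the active env.
# Assumption: module names follow upstream causal-conv1d and Mamba.

# Use `python` from the active conda env if present, else fall back to python3.
PYTHON="${PYTHON:-$(command -v python || command -v python3)}"

# Hypothetical helper: return 0 if module "$1" can be imported, non-zero otherwise.
check_import() {
  "$PYTHON" -c "import $1" >/dev/null 2>&1
}

for mod in causal_conv1d mamba_ssm; do
  if check_import "$mod"; then
    echo "$mod: OK"
  else
    echo "$mod: import failed (rebuild, or copy ./build/ from a working server)"
  fi
done
```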

Preparation Before Training and Evaluation

Follow the section with the same title in the U-ViT repository.

Checkpoints

| Model | FID | Training iterations | Batch size |
| --- | --- | --- | --- |
| ImageNet 256x256 (Huge/2) | 2.40 | 425K | 768 |
| ImageNet 512x512 (fine-tuned Huge/2) | 3.94 | Fine-tune | 240 |

Note: evaluation uses nnet_ema.pth, not nnet.pth.

nnet.pth contains the trained model weights, while nnet_ema.pth contains an exponential moving average (EMA) of those weights.
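For reference, the standard EMA update applied after each optimizer step is shown below, where theta denotes the trained weights (as in nnet.pth), theta_EMA the averaged weights (as in nnet_ema.pth), and beta a decay rate close to 1. The decay value this code uses is not stated here, so treat beta as a config-dependent assumption.

```latex
\theta_{\text{EMA}} \leftarrow \beta\,\theta_{\text{EMA}} + (1 - \beta)\,\theta, \qquad 0 < \beta < 1
```

Because the EMA weights average over many recent steps, they are typically smoother than the final raw weights, which is why they are the usual choice for sampling and FID evaluation.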

Evaluation

Use eval_ldm_discrete.py to evaluate the model and to generate images with classifier-free guidance (CFG):
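CFG combines a class-conditional and an unconditional noise prediction at every sampling step. A common form, used for example in U-ViT-style samplers, is shown below with guidance scale w and the null label written as an empty set. We assume, without having verified it here, that DiM follows the same convention and reads w from `config.sample.scale`.

```latex
\hat{\epsilon}_\theta(x_t, t, y) = \epsilon_\theta(x_t, t, y) + w\,\bigl(\epsilon_\theta(x_t, t, y) - \epsilon_\theta(x_t, t, \varnothing)\bigr)
```

Under this form, w = 0 reduces to plain conditional sampling, and larger w pushes samples further toward the class condition.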

# ImageNet 256x256 
accelerate launch --multi_gpu --gpu_ids 0,1,2,3,4,5,6,7 --main_process_port 20039 --num_processes 8 --mixed_precision bf16 ./eval_ldm_discrete.py --config=configs/imagenet256_H_DiM.py --nnet_path='workdir/imagenet256_H_DiM/default/ckpts/425000.ckpt/nnet_ema.pth'

# ImageNet 512x512
# The generated 512x512 evaluation images take up ~22 GB.
# If space for temporary files is tight, set `config.sample.path` in `configs/imagenet512_H_DiM_ft.py` to a location with enough room.
accelerate launch --multi_gpu --gpu_ids 0,1,2,3,4,5,6,7 --main_process_port 20039 --num_processes 8 --mixed_precision bf16 ./eval_ldm_discrete.py --config=configs/imagenet512_H_DiM_ft.py --nnet_path='workdir/imagenet512_H_DiM_ft/default/ckpts/64000.ckpt/nnet_ema.pth'

# ImageNet 512x512 with 2x upsampling. The generated images are written to `workdir/imagenet512_H_DiM_ft/test_tmp`, as set in the config.
accelerate launch --multi_gpu --gpu_ids 0,1,2,3,4,5,6,7 --main_process_port 20039 --num_processes 8 --mixed_precision bf16 ./eval_ldm_discrete.py --config=configs/imagenet512_H_DiM_upsample_test.py --nnet_path='workdir/imagenet512_H_DiM_ft/default/ckpts/64000.ckpt/nnet_ema.pth'

# ImageNet 512x512 with 3x upsampling. The generated images are written to `workdir/imagenet512_H_DiM_ft/test_tmp`, as set in the config.
accelerate launch --multi_gpu --gpu_ids 0,1,2,3,4,5,6,7 --main_process_port 20039 --num_processes 8 --mixed_precision bf16 ./eval_ldm_discrete.py --config=configs/imagenet512_H_DiM_upsample_3x_test.py --nnet_path='workdir/imagenet512_H_DiM_ft/default/ckpts/64000.ckpt/nnet_ema.pth'
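Before launching these multi-GPU runs, a few cheap checks can catch common mistakes: `--num_processes` should match the number of ids in `--gpu_ids`, `--nnet_path` should point at an existing nnet_ema.pth, and the sample directory needs roughly 22 GB free for the 512x512 run. This is a sketch only; the helpers `count_gpus`, `has_free_gb`, and `is_ema_ckpt` are hypothetical and not part of the repository, and `SAMPLE_DIR` is a placeholder you would set to match `config.sample.path`.

```bash
#!/usr/bin/env bash
# Sketch of a pre-flight check for the evaluation commands above.
# Hypothetical helper names; nothing here is part of the repository.

# Number of ids in a comma-separated GPU list, e.g. "0,1,2,3" -> 4.
# --num_processes should normally equal this count.
count_gpus() {
  local IFS=','
  local -a ids
  read -r -a ids <<< "$1"
  echo "${#ids[@]}"
}

# Return 0 if directory $1 has at least $2 GiB available.
has_free_gb() {
  local avail_kb
  avail_kb=$(df -Pk "$1" | awk 'NR==2 {print $4}')
  [ -n "$avail_kb" ] && [ "$avail_kb" -ge $(( $2 * 1024 * 1024 )) ]
}

# Return 0 if $1 is an existing file named nnet_ema.pth (the EMA weights).
is_ema_ckpt() {
  [ -f "$1" ] && [ "$(basename "$1")" = "nnet_ema.pth" ]
}

# Example values mirroring the 512x512 command above.
GPU_IDS="0,1,2,3,4,5,6,7"
NNET_PATH="workdir/imagenet512_H_DiM_ft/default/ckpts/64000.ckpt/nnet_ema.pth"
SAMPLE_DIR="."   # placeholder: use the directory config.sample.path points to

echo "--num_processes should be $(count_gpus "$GPU_IDS")"
is_ema_ckpt "$NNET_PATH" || echo "warning: $NNET_PATH is missing or not nnet_ema.pth"
has_free_gb "$SAMPLE_DIR" 22 || echo "warning: less than ~22 GiB free in $SAMPLE_DIR"
```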

Training

# CIFAR 32x32
accelerate launch --multi_gpu --num_processes 8 --mixed_precision fp16 ./train_ldm_discrete.py --config=configs/imagenet256_H_DiM.py

# ImageNet 256x256 
accelerate launch --multi_gpu --num_processes 8 --mixed_precision bf16 ./train_ldm_discrete.py --config=configs/imagenet256_H_DiM.py

# ImageNet 512x512 
# Fine-tuning from the 256x256 checkpoint. Carefully check that the pre-trained weights
# exist at `workdir/imagenet256_H_DiM/default/ckpts/425000.ckpt/nnet_ema.pth`.
# This location is set by `config.nnet.pretrained_path` in the config file.
# If no checkpoint exists there, no pre-trained weights are loaded.
accelerate launch --multi_gpu --num_processes 8 --mixed_precision bf16 ./train_ldm_discrete.py --config=configs/imagenet512_H_DiM_ft.py
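Because a missing checkpoint means fine-tuning proceeds without pre-trained weights, it can help to guard the launch. This is a minimal sketch: `check_pretrained` and the `PRETRAINED` variable are hypothetical, and the default path must be kept in sync with `config.nnet.pretrained_path` by hand.

```bash
#!/usr/bin/env bash
# Sketch: only launch 512x512 fine-tuning if the 256x256 EMA weights exist.
# The default path mirrors the comment above; keep it in sync with
# config.nnet.pretrained_path in configs/imagenet512_H_DiM_ft.py.
PRETRAINED="${PRETRAINED:-workdir/imagenet256_H_DiM/default/ckpts/425000.ckpt/nnet_ema.pth}"

# Hypothetical helper: return 0 if $1 exists and is non-empty, else warn on stderr.
check_pretrained() {
  if [ -s "$1" ]; then
    echo "found pre-trained weights: $1"
  else
    echo "missing or empty $1: fine-tuning would start without pre-trained weights" >&2
    return 1
  fi
}

if check_pretrained "$PRETRAINED"; then
  accelerate launch --multi_gpu --num_processes 8 --mixed_precision bf16 ./train_ldm_discrete.py --config=configs/imagenet512_H_DiM_ft.py
fi
```

Checking for a non-empty file (`-s`) rather than mere existence (`-f`) also catches a checkpoint left truncated by an interrupted copy.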

Citation

@misc{teng2024dim,
      title={DiM: Diffusion Mamba for Efficient High-Resolution Image Synthesis}, 
      author={Yao Teng and Yue Wu and Han Shi and Xuefei Ning and Guohao Dai and Yu Wang and Zhenguo Li and Xihui Liu},
      year={2024},
      eprint={2405.14224},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}