
Compositional Image Decomposition with Diffusion Models (ICML 2024)

We propose Decomp Diffusion, an unsupervised approach that decomposes an image into a set of compositional concepts, each represented by a diffusion model.


This is the official codebase for Compositional Image Decomposition with Diffusion Models.

Compositional Image Decomposition with Diffusion Models
Jocelin Su 1*, Nan Liu 2*, Yanbo Wang 3*, Joshua B. Tenenbaum 1, Yilun Du 1
* Equal Contribution
1 MIT, 2 UIUC, 3 TU Delft


The demo notebook shows how to train a model and perform experiments on decomposition, reconstruction, and recombination of factors on CLEVR, as well as recombination in multi-modal and cross-dataset settings.

  • The codebase is built upon Improved-Diffusion.
  • This codebase provides both training and inference code.

Setup

Run the following to create and activate a conda environment:

conda create -n decomp_diff python=3.8
conda activate decomp_diff

To install this package, clone this repository and then run:

pip install -e .
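
To sanity-check the environment (assuming PyTorch is pulled in as a dependency of the package), you can verify the install and GPU visibility:

python -c "import torch; print(torch.__version__, torch.cuda.is_available())"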

Training

We use a U-Net model architecture. To train a model, specify the model flags and training flags, for example:

MODEL_FLAGS="--emb_dim 64 --enc_channels 128"
TRAIN_FLAGS="--batch_size 16 --dataset clevr --data_dir ../"

For distributed training, we run the following:

DEVICE=$CUDA_VISIBLE_DEVICES
NUM_DEVICES=$(echo $CUDA_VISIBLE_DEVICES | tr ',' '\n' | wc -l)

python -m torch.distributed.run --nproc_per_node=$NUM_DEVICES scripts/image_train.py $MODEL_FLAGS $TRAIN_FLAGS
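
For example, to train on two specific GPUs (the device indices below are illustrative):

export CUDA_VISIBLE_DEVICES=0,1
python -m torch.distributed.run --nproc_per_node=2 scripts/image_train.py $MODEL_FLAGS $TRAIN_FLAGS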

For single-GPU (non-distributed) training, we run:

python scripts/image_train.py $MODEL_FLAGS $TRAIN_FLAGS --use_dist False
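
Since the codebase is built upon Improved-Diffusion, its standard training flags (e.g. --lr, --save_interval, --resume_checkpoint) should carry over, assuming the upstream argument parser is unchanged. For example, to resume an interrupted run (the checkpoint path is illustrative):

python scripts/image_train.py $MODEL_FLAGS $TRAIN_FLAGS --use_dist False --resume_checkpoint logs/model050000.pt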

Inference

To generate images, we run a sampling loop with a trained model, specifying either DDPM or DDIM sampling via --sample_method. We provide pre-trained models for various datasets in the Models section below, including one for CLEVR.

To perform decomposition and reconstruction of an input image, run the following:

MODEL_CHECKPOINT="clevr_model.pt"
MODEL_FLAGS="--emb_dim 64 --enc_channels 128"
python scripts/gen_image_script.py --dataset clevr --ckpt_path $MODEL_CHECKPOINT $MODEL_FLAGS --im_path sample_images/clevr_im_10.png --save_dir gen_clevr_img/ --sample_method ddim
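
Conceptually, decomposition encodes the input image into K latent factors, and reconstruction denoises from pure noise while composing the per-factor noise predictions. The sketch below illustrates this with a deterministic DDIM update; encoder and denoiser are hypothetical stand-ins for the repository's models, not its actual API.

import torch

def ddim_reconstruct(x, encoder, denoiser, alphas_cumprod):
    zs = encoder(x)                # K latent factors z_1..z_K, e.g. (K, emb_dim)
    x_t = torch.randn_like(x)      # start from Gaussian noise
    for t in reversed(range(1, len(alphas_cumprod))):
        a_t, a_prev = alphas_cumprod[t], alphas_cumprod[t - 1]
        # Composed prediction: sum the noise predicted under each factor
        # (the paper's compositional rule). Conditioning on a single z_k
        # instead renders that factor's component image.
        eps = sum(denoiser(x_t, t, z) for z in zs)
        # Deterministic DDIM update (eta = 0).
        x0 = (x_t - (1 - a_t).sqrt() * eps) / a_t.sqrt()
        x_t = a_prev.sqrt() * x0 + (1 - a_prev).sqrt() * eps
    return x_t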

In addition, we can generate results for multiple images in a dataset:

python scripts/gen_image_script.py --gen_images 100 --dataset $DATASET --ckpt_path $MODEL_CHECKPOINT $MODEL_FLAGS --save_dir gen_many_clevr_imgs/

Decomp Diffusion can also compose discovered factors. To combine factors across two images, run:

python scripts/gen_image_script.py --combine_method slice --dataset $DATASET --ckpt_path $MODEL_CHECKPOINT $MODEL_FLAGS --im_path $IM_PATH --im_path2 $IM_PATH2 --save_dir $SAVE_DIR 

See scripts/gen_image_script.py for additional options, such as additive combinations or cross-dataset combinations; a conceptual sketch of both follows.
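
In the same spirit, recombination only changes which set of latents conditions the composed denoiser. A sketch using the same hypothetical encoder as above; note that whether slice splits the factors exactly in half, and whether the additive mode conditions on the union of factors, are assumptions about the script's behavior, not guarantees:

def combine_slice(z1, z2):
    # First half of the factors from image 1, the rest from image 2.
    k = len(z1) // 2
    return list(z1[:k]) + list(z2[k:])

def combine_additive(z1, z2):
    # One reading of additive combination: all factors from both images.
    return list(z1) + list(z2)

# The combined list would replace encoder(x) in the denoising loop above.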


Datasets

See our paper for details on training datasets. Note that Tetris images are 32x32 instead of 64x64.

Dataset             Link
CLEVR               Link
CLEVR Toy           Link
Tetris              Link
CelebA-HQ 128x128   Link
KITTI               Link
Virtual KITTI 2     Link
Falcor3D            Link
Anime               Link
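
If you prepare your own images, they should match the training resolution (64x64 by default, 32x32 for Tetris). A minimal torchvision sketch for loading and resizing a single image, using the sample image bundled with the repository:

from PIL import Image
from torchvision import transforms

tf = transforms.Compose([
    transforms.Resize((64, 64)),   # use (32, 32) for Tetris
    transforms.ToTensor(),
])
img = tf(Image.open("sample_images/clevr_im_10.png").convert("RGB"))
print(img.shape)                   # torch.Size([3, 64, 64])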

Models

See our paper for details on model parameters for each dataset. We provide links to pre-trained models below, as well as their non-default parameter flags. We used --batch_size 32 during training.

Model            Link   Model Flags
CLEVR            Link   --emb_dim 64 --enc_channels 128
CelebA-HQ        Link   --enc_channels 128
Faces            Link   --enc_channels 128
CLEVR Toy        Link   --emb_dim 64 --enc_channels 128
Tetris                  --image_size 32 --num_components 3 --num_res_blocks 1 --enc_channels 64
VKITTI                  --num_channels 64 --enc_channels 64 --emb_dim 256
Combined KITTI          --num_channels 64 --enc_channels 64 --emb_dim 256
Falcor3D                --num_channels 64 --emb_dim 32 --channel_mult 1,2
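
To inspect a downloaded checkpoint before passing it to the sampling script, a generic PyTorch check (not a repository utility; this assumes the file is a plain state dict):

import torch

state = torch.load("clevr_model.pt", map_location="cpu")
print(type(state), len(state))     # expect a dict of parameter tensors
print(sorted(state)[:5])           # a few parameter names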