Decoupled Textual Embeddings for Customized Image Generation (AAAI 2024)

Decoupled Textual Embeddings for Customized Image Generation

We propose DETEX, a customized image generation method that uses multiple tokens to reduce overfitting and the entanglement between the target concept and unrelated information. By selectively using different tokens at inference time, DETEX gives more precise and efficient control over how much of the input image content is preserved in the generated results.

Method Details

Framework of our DETEX. Left: DETEX represents each image with multiple decoupled textual embeddings, i.e., an image-shared subject embedding $v$ and two image-specific, subject-unrelated embeddings (pose $v^p_i$ and background $v^b_i$). Right: To learn the target concept, we initialize the subject embedding $v$ as a learnable vector and adopt two attribute mappers to project the input image into the pose and background embeddings. During training, we jointly finetune the embeddings together with the K, V mapping parameters in the cross-attention layers. A cross-attention loss is further introduced to facilitate the disentanglement.
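To make the "K, V mapping parameters" concrete, here is a minimal sketch of how such parameters could be picked out by name. It assumes the parameter naming of the CompVis stable-diffusion codebase, where `attn2` is the cross-attention module and `to_k` / `to_v` are its key and value projections. The helper names and the example parameter list are hypothetical and are not taken from the DETEX code.

```python
# Sketch: select cross-attention K/V projection weights by parameter name.
# Assumption: names follow the CompVis stable-diffusion convention, where
# "attn2" is cross-attention and "to_k"/"to_v" are its key/value projections.


def is_cross_attn_kv(param_name: str) -> bool:
    """Return True for cross-attention key/value projection parameters."""
    return "attn2.to_k" in param_name or "attn2.to_v" in param_name


def split_trainable(param_names):
    """Split names into (trainable, frozen) lists, preserving input order."""
    trainable = [n for n in param_names if is_cross_attn_kv(n)]
    frozen = [n for n in param_names if not is_cross_attn_kv(n)]
    return trainable, frozen


# Hypothetical parameter names, for illustration only.
PREFIX = "model.diffusion_model.input_blocks.1.1.transformer_blocks.0."
EXAMPLE_NAMES = [
    PREFIX + "attn1.to_k.weight",  # self-attention key: stays frozen
    PREFIX + "attn2.to_q.weight",  # cross-attention query: stays frozen
    PREFIX + "attn2.to_k.weight",  # cross-attention key: finetuned
    PREFIX + "attn2.to_v.weight",  # cross-attention value: finetuned
]

if __name__ == "__main__":
    trainable, frozen = split_trainable(EXAMPLE_NAMES)
    for name in trainable:
        print("trainable:", name)
```

With a real model, the same predicate could be applied to the names returned by `named_parameters()`. How DETEX itself selects these parameters should be checked against the training code.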

Getting Started

Environment Setup

git clone https://github.com/PrototypeNx/DETEX.git
cd DETEX
git clone https://github.com/CompVis/stable-diffusion.git
cd stable-diffusion
conda env create -f environment.yaml
conda activate ldm
pip install clip-retrieval tqdm

Our code was developed on the following commit #21f890f9da3cfbeaba8e2ac3c425ee9e998d5229 of stable-diffusion.

Download the stable-diffusion model checkpoint:

wget https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt

The pretrained CLIP model should be downloaded automatically. If that doesn't work, you can download clip-vit-large-patch14 manually and place it in the appropriate config folder.

Preparing Dataset

We provide some processed example images in data, which contains the original images together with the corresponding processed foreground images and masks mentioned in the paper.

For a custom dataset, prepare the original images of a specific concept in SubjectName, the corresponding masks in SubjectName_mask, and the corresponding foreground images in SubjectName_fg. Note that each mask and foreground image file must have the same file name (xxx0n.png) as its corresponding original image.

We recommend using SAM to obtain the foreground mask and the corresponding foreground image easily.

In addition, you need to prepare a regularization dataset containing images that belong to the same category as the input subject. You can retrieve such images from the web or simply generate them with vanilla SD using a prompt like 'Photo of a <category>'. We recommend preparing at least 200 regularization images per category for better performance. More details about regularization can be found in DreamBooth.

The data structure should be like this:

data
├── SubjectName
│  ├── xxx01.png
│  ├── xxx02.png
│  ├── xxx03.png
│  ├── xxx04.png
├── SubjectName_fg
│  ├── xxx01.png
│  ├── xxx02.png
│  ├── xxx03.png
│  ├── xxx04.png
├── SubjectName_mask
│  ├── xxx01.png
│  ├── xxx02.png
│  ├── xxx03.png
│  ├── xxx04.png
├── Subject_samples
│  ├── 001.png
│  ├── 002.png
│  ├── ....
│  ├── 199.png
│  ├── 200.png
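
Before training, it can help to confirm that a dataset follows the layout above. The following is a minimal sketch using only the standard library. The function name `check_subject_layout` and its arguments are hypothetical, and the check on regularization images assumes they are `.png` files stored directly in the given folder.

```python
from pathlib import Path


def check_subject_layout(data_root, subject, reg_dir=None, min_reg=200):
    """Return a list of problems; an empty list means the layout looks consistent."""
    root = Path(data_root)
    image_dir = root / subject
    if not image_dir.is_dir():
        return [f"missing folder: {image_dir}"]

    problems = []
    names = sorted(p.name for p in image_dir.iterdir() if p.is_file())
    if not names:
        problems.append(f"no images found in {image_dir}")

    # Every original image needs a same-named foreground image and mask.
    for suffix in ("_fg", "_mask"):
        companion = root / f"{subject}{suffix}"
        for name in names:
            if not (companion / name).is_file():
                problems.append(f"missing {companion / name}")

    # Optional: compare the number of regularization images with the
    # recommended minimum (200 in the text above).
    if reg_dir is not None:
        reg_path = root / reg_dir
        count = len(list(reg_path.glob("*.png"))) if reg_path.is_dir() else 0
        if count < min_reg:
            problems.append(
                f"only {count} regularization images in {reg_path} "
                f"(recommended: at least {min_reg})"
            )
    return problems


if __name__ == "__main__":
    # Paths below mirror the example data used in the training command.
    for problem in check_subject_layout("data", "dog7"):
        print(problem)
```

An empty result only means the file names line up. It does not validate the image contents or the mask quality.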

Training

You can run the scripts below to train with the example data.

## run training (on 4 GPUs)
python -u  train.py \
            --base configs/DETEX/finetune.yaml  \
            -t --gpus 0,1,2,3 \
            --resume-from-checkpoint-custom <path-to-pretrained-sd> \
            --caption "<new1> dog with <p> pose in <b> background" \
            --num_imgs 4 \
            --datapath data/dog7 \
            --reg_datapath data/dog_samples/samples \
            --mask_path data/dog7_fg \
            --mask_path2 data/dog7_mask \
            --reg_caption "dog" \
            --modifier_token "<new1>+<p1>+<p2>+<p3>+<p4>+<b1>+<b2>+<b3>+<b4>" \
            --name dog7

The modifier tokens <p1>~<p4> and <b1>~<b4> represent the pose and background of each of the four input images, respectively. Please refer to the paper for more details about the unrelated tokens.

Note that the modifier_token parameter must be arranged in the form <new1>+<p1>+...+<pn>+<b1>+...+<bn>. Do not change the order of <new1>, <p>, and <b>.
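
Since the token order matters, the string can be generated rather than typed by hand. This is a small sketch of the pattern described above. `build_modifier_token` is a hypothetical helper and is not part of the repository.

```python
def build_modifier_token(num_imgs: int, subject_token: str = "<new1>") -> str:
    """Build '<new1>+<p1>+...+<pn>+<b1>+...+<bn>' for n input images."""
    if num_imgs < 1:
        raise ValueError("num_imgs must be at least 1")
    pose = [f"<p{i}>" for i in range(1, num_imgs + 1)]
    background = [f"<b{i}>" for i in range(1, num_imgs + 1)]
    # Order is fixed: subject token, then all pose tokens, then all background tokens.
    return "+".join([subject_token, *pose, *background])


if __name__ == "__main__":
    print(build_modifier_token(4))
    # <new1>+<p1>+<p2>+<p3>+<p4>+<b1>+<b2>+<b3>+<b4>
```

For four images this gives 2 * 4 + 1 = 9 tokens, which coincides with the value passed to --newtoken when saving the updated checkpoint. Whether the two must always match is an assumption here and should be checked against src/get_deltas.py.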

If you don't have a sufficient number of GPUs, we recommend training with a lower learning rate for more iterations.
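
The README does not give concrete values for the lower learning rate or the extra iterations. The sketch below illustrates one common heuristic, linear scaling with the number of GPUs. This heuristic is an assumption and is not taken from the DETEX code, and the numbers in the usage line are hypothetical placeholders.

```python
def rescale_schedule(base_lr, base_steps, base_gpus, available_gpus):
    """Scale the learning rate down and the step count up for fewer GPUs.

    Linear-scaling heuristic (an assumption, not the authors' recipe):
    halving the GPUs halves the learning rate and doubles the steps.
    """
    if base_gpus < 1 or available_gpus < 1:
        raise ValueError("GPU counts must be at least 1")
    ratio = available_gpus / base_gpus
    return base_lr * ratio, round(base_steps / ratio)


if __name__ == "__main__":
    # Hypothetical reference setting: 4 GPUs, lr 1e-5, 1000 steps.
    lr, steps = rescale_schedule(1e-5, 1000, base_gpus=4, available_gpus=2)
    print(lr, steps)
```

Treat the result as a starting point only. The right values depend on the batch size per GPU and on the settings in configs/DETEX/finetune.yaml.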

Save Updated Checkpoint

After training, run the following script to save only the updated weights.

python src/get_deltas.py --path logs/<folder-name>/checkpoints/last.ckpt --newtoken 9

Generation

Run the following script to generate images of the target subject <new1>.

python sample.py --delta_ckpt logs/<folder-name>/checkpoints/delta_epoch_last.ckpt \
                 --ckpt <path-to-pretrained-sd> --scale 6  --n_samples 3 --n_iter 2 --ddim_steps 50 \
                 --prompt "photo of a <new1> dog"

If you use an unrelated token <p> or <b> in the prompt, add a reference image path to the command with --ref so that the unrelated embedding can be obtained through the mapper.

python sample.py --delta_ckpt logs/<folder-name>/checkpoints/delta_epoch_last.ckpt \
                 --ckpt <path-to-pretrained-sd> --scale 6  --n_samples 3 --n_iter 2 --ddim_steps 50 \
                 --ref data/dog7/02.png \
                 --prompt "photo of a <new1> dog running in <b2> background"
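
The reference-image rule can be checked before launching generation. The following sketch flags prompts that contain a pose or background token. The regular expression and the helper name `needs_reference_image` are hypothetical, and the sketch assumes unrelated tokens always look like <p>, <b>, <p1>, <b2>, and so on.

```python
import re

# Matches <p>, <b>, and numbered variants such as <p1> or <b2>.
# The subject token <new1> does not match because it does not start with p or b.
UNRELATED_TOKEN = re.compile(r"<[pb]\d*>")


def needs_reference_image(prompt: str) -> bool:
    """Return True if the prompt uses a pose or background token."""
    return UNRELATED_TOKEN.search(prompt) is not None


if __name__ == "__main__":
    for prompt in (
        "photo of a <new1> dog",
        "photo of a <new1> dog running in <b2> background",
    ):
        flag = "--ref required" if needs_reference_image(prompt) else "no --ref needed"
        print(f"{prompt!r}: {flag}")
```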

The generated images are saved in logs/<folder-name>.

Citation

@article{cai2023DETEX,
    title={Decoupled Textual Embeddings for Customized Image Generation}, 
    author={Yufei Cai and Yuxiang Wei and Zhilong Ji and Jinfeng Bai and Hu Han and Wangmeng Zuo},
    journal={arXiv preprint arXiv:2312.11826},
    year={2023}
}