We propose DETEX, a customized image generation method that uses multiple tokens to alleviate overfitting and the entanglement between the target concept and unrelated information. By selectively using different tokens at inference, DETEX enables more precise and efficient control over which input-image content is preserved in the generated results.
Figure: Framework of DETEX. Left: DETEX represents each image with multiple decoupled textual embeddings, …
Clone the repository and set up the environment:

```bash
git clone https://github.com/PrototypeNx/DETEX.git
cd DETEX
git clone https://github.com/CompVis/stable-diffusion.git
cd stable-diffusion
conda env create -f environment.yaml
conda activate ldm
pip install clip-retrieval tqdm
```
Our code was developed on commit `21f890f9da3cfbeaba8e2ac3c425ee9e998d5229` of stable-diffusion. Download the Stable Diffusion v1.4 model checkpoint:

```bash
wget https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt
```
The pretrained CLIP model is downloaded automatically. If the download fails, download clip-vit-large-patch14 manually and place it in the appropriate config folder.
We provide some processed example images in `data`, which contains the original images together with the corresponding processed foreground images and masks mentioned in the paper.
For a custom dataset, prepare the original images of a specific concept in `SubjectName`, the corresponding masks in `SubjectName_mask`, and the corresponding foreground images in `SubjectName_fg`. Note that each mask and foreground image must have the same file name (`xxx0n.png`) as its corresponding original image.
We recommend using SAM (Segment Anything) to obtain the foreground masks and the corresponding foreground images.
In addition, you need to prepare a regularization dataset containing images of the same category as the input subject. You can retrieve such images from the web or simply generate them with vanilla SD using a prompt like 'Photo of a <category>'. We recommend preparing at least 200 regularization images per category for better performance. More details about regularization can be found in DreamBooth.
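Before training, it can help to confirm that the regularization folder actually reaches the recommended 200 images. The sketch below is a hypothetical helper, not part of the DETEX repository; the set of file suffixes treated as images is an assumption, and `reg_prompt` only reproduces the 'Photo of a <category>' prompt pattern mentioned above.

```python
from pathlib import Path

# File suffixes counted as images (an assumption; adjust to your data).
IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg"}


def reg_prompt(category):
    """Prompt for generating regularization images with vanilla SD, e.g. 'Photo of a dog'."""
    return f"Photo of a {category}"


def count_reg_images(reg_dir, minimum=200):
    """Return (number of image files in reg_dir, whether it meets the recommended minimum)."""
    n = sum(
        1
        for p in Path(reg_dir).iterdir()
        if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
    )
    return n, n >= minimum
```

For the example data, `count_reg_images("data/dog_samples/samples")` would report how many regularization images are present and whether the 200-image recommendation is met.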
The data should be organized as follows:

```
data
├── SubjectName
│ ├── xxx01.png
│ ├── xxx02.png
│ ├── xxx03.png
│ ├── xxx04.png
├── SubjectName_fg
│ ├── xxx01.png
│ ├── xxx02.png
│ ├── xxx03.png
│ ├── xxx04.png
├── SubjectName_mask
│ ├── xxx01.png
│ ├── xxx02.png
│ ├── xxx03.png
│ ├── xxx04.png
├── Subject_samples
│ ├── 001.png
│ ├── 002.png
│ ├── ....
│ ├── 199.png
│ ├── 200.png
```
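Because every mask and foreground image must share its file name with an original image, a quick consistency check can catch missing files early. This is a minimal sketch with a hypothetical `check_subject_dirs` helper (not part of the repository); it only compares file names across `SubjectName`, `SubjectName_fg`, and `SubjectName_mask` and does not inspect image contents.

```python
from pathlib import Path


def check_subject_dirs(data_root, subject):
    """Map each of <subject>_fg / <subject>_mask to the original file names it is missing.

    Returns an empty dict when every original image has a same-named
    foreground image and mask. Only file names are compared.
    """
    root = Path(data_root)
    originals = {p.name for p in (root / subject).iterdir() if p.is_file()}
    problems = {}
    for suffix in ("_fg", "_mask"):
        folder = root / f"{subject}{suffix}"
        # A missing folder counts as missing every file.
        present = (
            {p.name for p in folder.iterdir() if p.is_file()}
            if folder.is_dir()
            else set()
        )
        missing = sorted(originals - present)
        if missing:
            problems[folder.name] = missing
    return problems
```

For the example data, `check_subject_dirs("data", "dog7")` should return `{}` if the naming convention is satisfied.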
Run the following script to train on the example data:

```bash
## run training (on 4 GPUs)
python -u train.py \
    --base configs/DETEX/finetune.yaml \
    -t --gpus 0,1,2,3 \
    --resume-from-checkpoint-custom <path-to-pretrained-sd> \
    --caption "<new1> dog with <p> pose in <b> background" \
    --num_imgs 4 \
    --datapath data/dog7 \
    --reg_datapath data/dog_samples/samples \
    --mask_path data/dog7_fg \
    --mask_path2 data/dog7_mask \
    --reg_caption "dog" \
    --modifier_token "<new1>+<p1>+<p2>+<p3>+<p4>+<b1>+<b2>+<b3>+<b4>" \
    --name dog7
```
The modifier tokens `<p1>` to `<p4>` and `<b1>` to `<b4>` represent the pose and background of the 4 input images, respectively. Please refer to the paper for more details about these unrelated tokens.
Note that the `--modifier_token` argument must be arranged in the form `<new1>+<p1>+...+<pn>+<b1>+...+<bn>`. Do not change the order of `<new1>`, `<p>`, and `<b>`.
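Since the order of `--modifier_token` is fixed, the string can be generated from the number of input images. The sketch below uses hypothetical helper names; `num_new_tokens` assumes that `--newtoken` in `src/get_deltas.py` counts all of these tokens (one subject token plus one pose and one background token per image), which is consistent with the example values of 4 images and `--newtoken 9`.

```python
def build_modifier_token(num_imgs):
    """Build '<new1>+<p1>+...+<pn>+<b1>+...+<bn>' in the order required by --modifier_token."""
    tokens = ["<new1>"]
    tokens += [f"<p{i}>" for i in range(1, num_imgs + 1)]  # one pose token per image
    tokens += [f"<b{i}>" for i in range(1, num_imgs + 1)]  # one background token per image
    return "+".join(tokens)


def num_new_tokens(num_imgs):
    """Assumed token count: 1 subject token + num_imgs pose tokens + num_imgs background tokens."""
    return 1 + 2 * num_imgs
```

With `num_imgs=4`, `build_modifier_token(4)` reproduces the string used in the training command above.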
If you don't have a sufficient number of GPUs, we recommend training with a lower learning rate for more iterations.
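One common way to pick these values, offered here only as a heuristic and not as the authors' recipe, is the linear scaling rule: with a fixed per-GPU batch size, using fewer GPUs shrinks the effective batch, so the learning rate is scaled down and the step count scaled up to keep the total number of images seen constant. The function name and all numbers below are hypothetical placeholders. If your training script already scales the learning rate by the number of GPUs, do not apply this scaling a second time.

```python
def adjust_for_fewer_gpus(base_lr, base_steps, base_gpus, gpus):
    """Hypothetical linear-scaling heuristic for training on fewer GPUs.

    Assumes a fixed per-GPU batch size, so the effective batch size is
    proportional to the number of GPUs. The learning rate shrinks by
    gpus / base_gpus and the step count grows by base_gpus / gpus.
    """
    ratio = base_gpus / gpus
    return base_lr / ratio, round(base_steps * ratio)
```

For example, `adjust_for_fewer_gpus(4e-5, 500, 4, 1)` returns a quarter of the learning rate and four times the steps.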
After training, run the following script to save only the updated weights:

```bash
python src/get_deltas.py --path logs/<folder-name>/checkpoints/last.ckpt --newtoken 9
```
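To illustrate the general idea of a delta checkpoint, the sketch below keeps only the parameters that are new or that changed during fine-tuning. It is a generic illustration, not what `src/get_deltas.py` does (the real script may select parameters by name instead of comparing values); plain dicts of float lists stand in for PyTorch state dicts, and all names are hypothetical.

```python
def extract_delta(pretrained, finetuned, atol=0.0):
    """Keep entries of `finetuned` that are absent from, or differ from, `pretrained`.

    Both arguments map parameter names to flat sequences of floats.
    A parameter whose length changed (e.g. an embedding table that gained
    rows for new tokens) is always kept.
    """
    delta = {}
    for name, values in finetuned.items():
        base = pretrained.get(name)
        if (
            base is None
            or len(base) != len(values)
            or any(abs(a - b) > atol for a, b in zip(base, values))
        ):
            delta[name] = list(values)
    return delta
```

Storing only this subset keeps the saved file small, since the frozen pretrained weights are loaded separately from `--ckpt` at inference.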
Run the following script to generate images of the target subject `<new1>`:

```bash
python sample.py --delta_ckpt logs/<folder-name>/checkpoints/delta_epoch_last.ckpt \
    --ckpt <path-to-pretrained-sd> --scale 6 --n_samples 3 --n_iter 2 --ddim_steps 50 \
    --prompt "photo of a <new1> dog"
```
If the prompt contains an unrelated token (`<p>` or `<b>`), add a reference image path with `--ref` so that the mapper can compute the corresponding unrelated embedding:

```bash
python sample.py --delta_ckpt logs/<folder-name>/checkpoints/delta_epoch_last.ckpt \
    --ckpt <path-to-pretrained-sd> --scale 6 --n_samples 3 --n_iter 2 --ddim_steps 50 \
    --ref data/dog7/02.png \
    --prompt "photo of a <new1> dog running in <b2> background"
```
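When scripting many generations, a small wrapper can enforce the rule that a reference image is required whenever the prompt uses a pose or background token. The sketch below is a hypothetical helper (`build_sample_cmd` is not part of the repository); it passes only the `sample.py` flags shown in the commands above and assumes unrelated tokens are numbered as in `--modifier_token` (e.g. `<b2>`).

```python
import re

# Matches numbered pose/background tokens such as <p2> or <b2>, but not <new1>.
UNRELATED_TOKEN = re.compile(r"<[pb]\d+>")


def build_sample_cmd(prompt, delta_ckpt, ckpt, ref=None,
                     scale=6, n_samples=3, n_iter=2, ddim_steps=50):
    """Return an argv list for sample.py; raise if an unrelated token lacks a reference image."""
    if UNRELATED_TOKEN.search(prompt) and ref is None:
        raise ValueError("prompt uses an unrelated token (<p*>/<b*>); pass a reference image via ref")
    cmd = [
        "python", "sample.py",
        "--delta_ckpt", delta_ckpt,
        "--ckpt", ckpt,
        "--scale", str(scale),
        "--n_samples", str(n_samples),
        "--n_iter", str(n_iter),
        "--ddim_steps", str(ddim_steps),
    ]
    if ref is not None:
        cmd += ["--ref", ref]
    cmd += ["--prompt", prompt]
    return cmd
```

The returned list can be run from the repository root with `subprocess.run(cmd, check=True)`; passing a list rather than a shell string avoids quoting problems with the `<...>` tokens in the prompt.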
The generated images are saved in `logs/<folder-name>`.
```bibtex
@article{cai2023DETEX,
  title={Decoupled Textual Embeddings for Customized Image Generation},
  author={Yufei Cai and Yuxiang Wei and Zhilong Ji and Jinfeng Bai and Hu Han and Wangmeng Zuo},
  journal={arXiv preprint arXiv:2312.11826},
  year={2023}
}
```