
Code for Shifted Diffusion for Text-to-image Generation (CVPR 2023)

Primary language: Python. License: Creative Commons Zero v1.0 Universal (CC0-1.0).

Shifted Diffusion for Text-to-image Generation

[Figure: examples]


Shifted Diffusion is a new diffusion model designed to better generate image embeddings from text.

[Figure: framework overview]

(The "Decoder" can be either a diffusion-based or a GAN-based model; it can also be conditioned on both the image embedding and the text.)

With Shifted Diffusion, you can

  • improve your text-to-image generation model by introducing an extra image embedding input (see Section 5.1 of the paper);
  • train or fine-tune a text-to-image generation model on an image-only dataset (the so-called language-free setting).
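In the language-free setting, the decoder learns to reconstruct each image from that image's own CLIP image embedding, so no captions are required; at inference time, Shifted Diffusion supplies an image embedding generated from text instead. The sketch below only illustrates how such (embedding, image) training pairs could be assembled. `embed_image` is a hypothetical stand-in for a CLIP image encoder (not this repository's API), the toy "images" are plain lists, and L2-normalizing the embedding is an assumption (a common choice for CLIP features).

```python
import math
from typing import Callable, List, Sequence, Tuple

Vector = List[float]


def l2_normalize(v: Sequence[float]) -> Vector:
    """Scale a vector to unit length (assumed preprocessing for CLIP embeddings)."""
    norm = math.sqrt(sum(x * x for x in v))
    if norm == 0.0:
        raise ValueError("cannot normalize a zero vector")
    return [x / norm for x in v]


def build_language_free_pairs(
    images: Sequence[Vector],
    embed_image: Callable[[Vector], Vector],
) -> List[Tuple[Vector, Vector]]:
    """Pair every image with its own image embedding: (condition, target).

    The decoder is trained to map the condition back to the target image;
    no text is involved at training time.
    """
    return [(l2_normalize(embed_image(img)), img) for img in images]


def toy_embed_image(img: Vector) -> Vector:
    # Hypothetical encoder; a real setup would call a CLIP image encoder here.
    return [sum(img), max(img) - min(img)]


pairs = build_language_free_pairs([[1.0, 2.0, 3.0], [0.0, 4.0, 0.0]], toy_embed_image)
for cond, target in pairs:
    print(cond, target)
```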

Below we provide examples of using our Shifted Diffusion.

Don't forget to create a new conda environment in advance.

Get started

Install the dependencies, download the CLIP checkpoints, and configure Accelerate

pip install -r ./requirements.txt
pip install git+https://github.com/openai/CLIP.git
cd ./diffusers
pip install -e .
cd ..
wget "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"
wget "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"
wget "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"
accelerate config
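The second-to-last path segment of each CLIP URL above is the SHA-256 digest of the checkpoint file (the `clip` package verifies its own downloads against that segment). The stdlib-only sketch below shows one way to confirm that the `wget` downloads are intact; `verify_checkpoint` and the other helpers are hypothetical, not part of this repository.

```python
import hashlib
from pathlib import Path


def expected_sha256(url: str) -> str:
    """Return the digest embedded in a CLIP download URL (second-to-last segment)."""
    return url.split("/")[-2]


def file_sha256(path: Path, chunk_size: int = 1 << 20) -> str:
    """Hash a file in chunks so large checkpoints need not fit in memory."""
    digest = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_checkpoint(url: str, directory: Path = Path(".")) -> bool:
    """Check that the file wget saved from `url` exists and matches its digest."""
    path = directory / url.split("/")[-1]
    return path.is_file() and file_sha256(path) == expected_sha256(url)


if __name__ == "__main__":
    urls = [
        "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
        "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
        "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
    ]
    for url in urls:
        status = "OK" if verify_checkpoint(url) else "missing or corrupted"
        print(url.split("/")[-1], status)
```

Run it from the directory where the three `.pt` files were downloaded.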

To train a Shifted Diffusion model, run the following (choose hyper-parameters to suit your hardware)

accelerate launch --mixed_precision="fp16" train.py

We provide our pre-trained Shifted Diffusion models here.

Shifted Diffusion + Stable Diffusion

We provide a simple example which combines our pre-trained Shifted Diffusion with Stable Diffusion 2.

Specifically, we add a projection layer that maps the input image embedding to 4 word embeddings. Feel free to try more complex architectures.
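A minimal sketch of such a projection, assuming a single linear layer whose flat output is split into 4 token vectors. The dimensions here (a 4-d image embedding, 3-d word embeddings) are toy values; in practice the input width is that of the CLIP image embedding in use, and the output token width must match the text-embedding width of Stable Diffusion 2 (1024 for its OpenCLIP text encoder). `ImageToTokens` is a hypothetical name, not the class used in `finetune.py`, and a real model would use a trainable layer such as `torch.nn.Linear` rather than plain lists.

```python
import random
from typing import List, Sequence


class ImageToTokens:
    """Linear map from an image embedding to `num_tokens` pseudo word embeddings."""

    def __init__(self, embed_dim: int, token_dim: int, num_tokens: int = 4, seed: int = 0):
        rng = random.Random(seed)
        self.embed_dim = embed_dim
        self.token_dim = token_dim
        self.num_tokens = num_tokens
        out_dim = num_tokens * token_dim
        # Weight of shape (out_dim, embed_dim) with small random init; zero bias.
        self.weight = [[rng.gauss(0.0, 0.02) for _ in range(embed_dim)] for _ in range(out_dim)]
        self.bias = [0.0] * out_dim

    def __call__(self, embedding: Sequence[float]) -> List[List[float]]:
        if len(embedding) != self.embed_dim:
            raise ValueError("embedding has the wrong dimension")
        # y = W x + b: one flat vector of length num_tokens * token_dim.
        flat = [sum(w * x for w, x in zip(row, embedding)) + b
                for row, b in zip(self.weight, self.bias)]
        # Reshape into num_tokens vectors, each the size of one word embedding.
        return [flat[i * self.token_dim:(i + 1) * self.token_dim]
                for i in range(self.num_tokens)]


proj = ImageToTokens(embed_dim=4, token_dim=3)
tokens = proj([0.1, -0.2, 0.3, 0.4])
print(len(tokens), len(tokens[0]))  # 4 3
```

How these vectors are consumed downstream is determined by `finetune.py` and is not reproduced here.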

With the example below, you can first fine-tune a Stable Diffusion model on an image-only dataset (language-free setting), and then

  • directly input an image to perform image-to-image generation;
  • directly plug in our pre-trained Shifted Diffusion model to perform text-to-image generation.
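The two options above differ only in where the conditioning image embedding comes from: the CLIP image encoder (image input) or the Shifted Diffusion model (text input); the fine-tuned decoder is the same in both cases. The sketch below illustrates that switch. `clip_image_encoder`, `shifted_diffusion_prior`, and `decoder` are hypothetical callables standing in for the real models, not functions provided by this repository.

```python
from typing import Callable, List, Optional

Vector = List[float]


def generate(
    decoder: Callable[[Vector], str],
    clip_image_encoder: Callable[[object], Vector],
    shifted_diffusion_prior: Callable[[str], Vector],
    image: Optional[object] = None,
    text: Optional[str] = None,
) -> str:
    """Generate from exactly one of an input image or a text prompt."""
    if (image is None) == (text is None):
        raise ValueError("pass exactly one of `image` or `text`")
    if image is not None:
        # Image-to-image: condition on the CLIP embedding of the input image.
        embedding = clip_image_encoder(image)
    else:
        # Text-to-image: Shifted Diffusion generates an image embedding from text.
        embedding = shifted_diffusion_prior(text)
    # The same fine-tuned decoder is used in both cases.
    return decoder(embedding)


# Toy stand-ins so the control flow can be run without any models.
def toy_decoder(embedding: Vector) -> str:
    return f"image from {embedding}"


def toy_encoder(image: object) -> Vector:
    return [1.0, 0.0]


def toy_prior(text: str) -> Vector:
    return [0.0, 1.0]


print(generate(toy_decoder, toy_encoder, toy_prior, text="a yellow and blue train"))
# image from [0.0, 1.0]
```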

Fine-tune a Stable Diffusion model

Prepare an image-only dataset (MS-COCO, for example)

wget http://images.cocodataset.org/zips/train2014.zip
unzip train2014.zip
python process_img.py --src=./train2014 --size=512 --dest=./train2014
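The exact behavior of `process_img.py --size=512` is not documented here. As a worked example of one common way to bring images to 512 × 512 (resize the shorter side to 512, then center-crop, which is also what a `--resolution=512 --center_crop` style option usually means), the hypothetical helper below computes the intermediate size and the crop box. Libraries may round these values differently by a pixel.

```python
from typing import Tuple

Size = Tuple[int, int]
Box = Tuple[int, int, int, int]


def resize_and_center_crop_box(width: int, height: int, target: int = 512) -> Tuple[Size, Box]:
    """Resize so the shorter side equals `target`, then center-crop a target x target square.

    Returns the resized (width, height) and the crop box (left, top, right, bottom)
    in resized-image coordinates, the box convention used by PIL's Image.crop.
    """
    if width <= 0 or height <= 0:
        raise ValueError("image dimensions must be positive")
    short_side, long_side = min(width, height), max(width, height)
    # Scale the long side proportionally; truncate to an integer pixel count.
    new_long = int(target * long_side / short_side)
    new_w, new_h = (target, new_long) if width <= height else (new_long, target)
    # Center the crop window; at most one of these offsets is non-zero.
    left = int(round((new_w - target) / 2.0))
    top = int(round((new_h - target) / 2.0))
    return (new_w, new_h), (left, top, left + target, top + target)


print(resize_and_center_crop_box(640, 480))  # ((682, 512), (85, 0, 597, 512))
```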

Run

accelerate launch --mixed_precision="fp16" finetune.py \
  --pretrained_model_name_or_path="stabilityai/stable-diffusion-2-base" \
  --train_data_dir=./train2014/ \
  --use_ema \
  --resolution=512 --center_crop --random_flip \
  --train_batch_size=8 \
  --gradient_accumulation_steps=1 \
  --gradient_checkpointing \
  --max_train_steps=30000 \
  --checkpointing_steps=5000 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --lr_scheduler="constant" --lr_warmup_steps=0 \
  --output_dir="./finetuned_coco"

(We did not tune the hyper-parameters; they follow the examples here.)
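The `--use_ema` flag keeps an exponential moving average (EMA) of the model weights alongside the weights being optimized. The sketch below shows the standard EMA update on plain lists of floats; the decay values are illustrative assumptions, and library implementations (such as the `EMAModel` helper in diffusers) may also adjust the decay over the course of training, which is omitted here.

```python
from typing import List, Sequence


def ema_update(ema_params: List[float], params: Sequence[float], decay: float = 0.999) -> None:
    """In-place EMA step: ema <- decay * ema + (1 - decay) * param."""
    if not 0.0 <= decay <= 1.0:
        raise ValueError("decay must be in [0, 1]")
    if len(ema_params) != len(params):
        raise ValueError("parameter lists must have the same length")
    for i, p in enumerate(params):
        ema_params[i] = decay * ema_params[i] + (1.0 - decay) * p


ema = [0.0, 0.0]
for step in range(3):
    current = [1.0, 2.0]  # stand-in for the weights after an optimizer step
    ema_update(ema, current, decay=0.5)
print(ema)  # [0.875, 1.75]
```

At the end of training, the EMA weights (rather than the raw weights) are typically the ones saved and used for sampling, since they tend to be smoother.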

Here are some slightly fine-tuned Stable Diffusion 2 models; we used a total batch size of 8 * 8 * 1 = 64 (per-device batch size × number of GPUs × gradient accumulation steps).

(The models are "slightly fine-tuned": we fine-tuned them for only 10k to 30k steps, for demonstration purposes. More fine-tuning steps with better-tuned hyper-parameters should lead to better results.)
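The total batch size of 64 is the product of the per-device batch size (`--train_batch_size=8`), the number of GPUs (8), and `--gradient_accumulation_steps` (1). The helper below reproduces that arithmetic and estimates how many passes over the data a run makes, using 82,783 images as the size of the MS-COCO train2014 split; the function names are illustrative.

```python
def effective_batch_size(per_device: int, num_processes: int, grad_accum_steps: int) -> int:
    """Samples consumed per optimizer step across all devices."""
    return per_device * num_processes * grad_accum_steps


def epochs_covered(max_train_steps: int, batch_size: int, dataset_size: int) -> float:
    """Approximate number of passes over the dataset during training."""
    return max_train_steps * batch_size / dataset_size


COCO_TRAIN2014_IMAGES = 82_783

bs = effective_batch_size(per_device=8, num_processes=8, grad_accum_steps=1)
print(bs)  # 64
for steps in (10_000, 30_000):
    print(steps, round(epochs_covered(steps, bs, COCO_TRAIN2014_IMAGES), 1))
```

For a run on train2014, 10k to 30k steps therefore correspond to roughly 7.7 to 23.2 epochs. On fewer GPUs, raising the accumulation steps keeps the total unchanged: with 2 GPUs, `--gradient_accumulation_steps=4` gives 8 × 2 × 4 = 64.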

Test fine-tuned Stable Diffusion model

Generate images from CLIP image embeddings

Run

python test.py

Examples of input/generated images on different datasets:

[Figures: input images and the corresponding generated images for the pelican, train, and face examples]


Generate images from text with Shifted Diffusion

Run

python sft_test.py

Below we provide a comparison.

[Figure: yellow-and-blue train comparison]

The figure shows a ground-truth image-text pair from the MS-COCO dataset.

Although Stable Diffusion 2 can perform zero-shot generation, its outputs may not meet our requirements in terms of style and other attributes.

With our language-free fine-tuning and pre-trained Shifted Diffusion model, we can generate images that match the style of the target dataset.

This approach can easily be applied to other domains and datasets, since fine-tuning requires no image-text pairs.

Below we compare Shifted Diffusion with a baseline diffusion model, both combined with the fine-tuned Stable Diffusion 2 model. We report the FID score and the CLIP similarity (averaged over CLIP ViT-B/16, ViT-B/32, and RN101) between the generated images and either the input text or the ground-truth target images.
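The averaged CLIP similarity can be sketched as follows: embed the generated image and the reference (input text or ground-truth image) with each CLIP backbone, compute the cosine similarity within that backbone's embedding space, and average the three scalar scores. The vectors below are toy stand-ins for real CLIP outputs, and `average_clip_similarity` is a hypothetical helper; the repository's evaluation code may differ in details such as score scaling.

```python
import math
from typing import Dict, Sequence, Tuple


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    if len(a) != len(b):
        raise ValueError("embeddings must have the same dimension")
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    if norm == 0.0:
        raise ValueError("cannot compare zero vectors")
    return dot / norm


def average_clip_similarity(pairs: Dict[str, Tuple[Sequence[float], Sequence[float]]]) -> float:
    """Mean cosine similarity over backbones.

    `pairs` maps a backbone name to (generated-image embedding, reference embedding).
    Each backbone has its own embedding space, so similarities are computed per
    backbone and only the scalar scores are averaged.
    """
    if not pairs:
        raise ValueError("need at least one backbone")
    scores = [cosine_similarity(gen, ref) for gen, ref in pairs.values()]
    return sum(scores) / len(scores)


toy = {
    "ViT-B/16": ([1.0, 0.0], [1.0, 0.0]),  # same direction  -> 1.0
    "ViT-B/32": ([1.0, 0.0], [0.0, 1.0]),  # orthogonal      -> 0.0
    "RN101": ([1.0, 1.0], [1.0, 0.0]),     # 45 degrees apart -> about 0.707
}
print(round(average_clip_similarity(toy), 4))
```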

[Figures: FID and CLIP similarity (image-image); FID and CLIP similarity (image-text)]

Shifted Diffusion + Lafite

The decoder can also be a GAN-based model, e.g., Lafite.

As in the example above, you need to construct an image-only dataset and then train a mapping from image embeddings to images.

After training the GAN, use the pre-trained Shifted Diffusion model directly at inference time to perform text-to-image generation.

Citation

@article{zhou2022shifted,
  title={Shifted Diffusion for Text-to-image Generation},
  author={Zhou, Yufan and Liu, Bingchen and Zhu, Yizhe and Yang, Xiao and Chen, Changyou and Xu, Jinhui},
  journal={arXiv preprint arXiv:2211.15388},
  year={2022}
}