
CLIP Diffusion Art

Fine-tune diffusion models on custom datasets and sample with text-conditioning using CLIP guidance and SwinIR for super resolution.

Primary language: Python. License: Apache License 2.0 (Apache-2.0).

📌 Dataset with public domain artworks created for this project:

 Artworks in Public Domain

📌 Link to an interactive run in a notebook:

 Stunning Art with CLIP Guided Diffusion+SwinIR

📌 Wandb logging is integrated for training and sampling.

Generated Samples

"vibrant watercolor painting of a flower, artstation HQ"

"beautiful matte painting of dystopian city, Behance HD"

"vibrant watercolor painting of a flower, artstation HQ"

"artstation HQ, photorealistic depiction of an alien city"

For more generated artworks, visit this report

Super-resolution Results


Developed using techniques and architectures borrowed from original work by the authors below:

Huge thanks to all their great work! I highly recommend checking out these repos.


git clone https://github.com/sreevishnu-damodaran/clip-diffusion-art.git -q
cd clip-diffusion-art
pip install -e . -q
git clone https://github.com/JingyunLiang/SwinIR.git -q
git clone https://github.com/crowsonkb/guided-diffusion -q
pip install -e guided-diffusion -q
git clone https://github.com/openai/CLIP -q
pip install -e ./CLIP -q
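After installation, a quick way to confirm the packages are importable is a small check like the one below. This is a minimal sketch, not part of the repo; the import names `clip`, `guided_diffusion` and `clip_diffusion_art` are assumptions based on the cloned repositories (SwinIR is only cloned, not pip-installed, so it is not checked).

```python
import importlib.util


def check_packages(names):
    """Map each top-level module name to True if Python can locate it."""
    # find_spec locates a module without importing (executing) it.
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    # Assumed import names: "clip" (openai/CLIP), "guided_diffusion"
    # (crowsonkb/guided-diffusion) and "clip_diffusion_art" (this repo).
    status = check_packages(["clip", "guided_diffusion", "clip_diffusion_art"])
    for name, found in status.items():
        print(f"{name}: {'ok' if found else 'missing'}")
```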


Public Domain Artworks dataset used in this repo:


Additional details: datasets/README.md

Training & Fine-tuning

Choose the hyperparameters for training. These are reasonable defaults for fine-tuning on a custom dataset with a 16 GB GPU on Colab or Kaggle:

MODEL_FLAGS="--image_size 256 --num_channels 128 --num_res_blocks 2 --num_heads 1 --attention_resolutions 16"
DIFFUSION_FLAGS="--diffusion_steps 1000 --noise_schedule linear --learn_sigma True --rescale_learned_sigmas True --rescale_timesteps True --use_scale_shift_norm False"
TRAIN_FLAGS="--lr 5e-6 --save_interval 500 --batch_size 16 --use_fp16 True --wandb_project diffusion-art-train --use_checkpoint True --resume_checkpoint pretrained_models/lsun_uncond_100M_1200K_bs128.pt"
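The `--diffusion_steps 1000 --noise_schedule linear` pair determines how much noise is added at each forward step. The sketch below reproduces the linear beta schedule as defined in OpenAI's improved-diffusion code (which guided-diffusion inherits), written in plain Python instead of NumPy; it assumes train.py passes these flags straight through to that code. `alphas_cumprod` shows how little of the original image signal remains by the last step.

```python
import math


def linear_beta_schedule(num_diffusion_timesteps):
    """Linear beta schedule as in improved-diffusion, rescaled relative to 1000 steps."""
    if num_diffusion_timesteps < 2:
        raise ValueError("need at least 2 diffusion timesteps")
    scale = 1000 / num_diffusion_timesteps
    beta_start = scale * 0.0001
    beta_end = scale * 0.02
    n = num_diffusion_timesteps
    # Same values as np.linspace(beta_start, beta_end, n).
    return [beta_start + (beta_end - beta_start) * i / (n - 1) for i in range(n)]


def alphas_cumprod(betas):
    """alpha_bar_t = prod over s <= t of (1 - beta_s): signal variance kept at step t."""
    out, prod = [], 1.0
    for beta in betas:
        prod *= 1.0 - beta
        out.append(prod)
    return out


if __name__ == "__main__":
    betas = linear_beta_schedule(1000)
    alpha_bar = alphas_cumprod(betas)
    print(f"beta_1 = {betas[0]:.6f}, beta_T = {betas[-1]:.6f}")
    print(f"sqrt(alpha_bar_T) = {math.sqrt(alpha_bar[-1]):.6f}")
```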

Once the hyperparameters are set, run the training job as follows:

python clip_diffusion_art/train.py --data_dir path/to/images $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS
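`--data_dir` should point at a folder of images, which may contain subfolders. Assuming train.py relies on improved-diffusion's image loader, files are collected recursively by extension (jpg, jpeg, png, gif). The hypothetical helper below follows that rule so you can count what will actually be picked up before launching a long job.

```python
import os
import sys

# Extensions accepted by improved-diffusion's loader (assumed to apply here).
IMAGE_EXTS = {"jpg", "jpeg", "png", "gif"}


def list_image_files(data_dir):
    """Recursively collect image paths under data_dir, in sorted order."""
    results = []
    for entry in sorted(os.listdir(data_dir)):
        full_path = os.path.join(data_dir, entry)
        ext = entry.rsplit(".", 1)[-1].lower()
        if "." in entry and ext in IMAGE_EXTS:
            results.append(full_path)
        elif os.path.isdir(full_path):
            results.extend(list_image_files(full_path))
    return results


if __name__ == "__main__":
    if len(sys.argv) != 2:
        print("usage: python list_images.py path/to/images")
    else:
        files = list_image_files(sys.argv[1])
        print(f"{len(files)} training images found under {sys.argv[1]}")
```

The script name `list_images.py` is only illustrative.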

Refer to OpenAI's improved-diffusion repository for more details on choosing hyperparameters and selecting other pre-trained weights.

Download SR pre-trained weights

wget -P pretrained_models https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth

Passing the --sr_model_path flag to sample.py applies super-resolution to each image after sampling.

Sample Images with CLIP Guidance

python clip_diffusion_art/sample.py \
"beautiful matte painting of dystopian city, Behance HD" \
--checkpoint 256x256_clip_diffusion_art.pt \
--model_config "clip_diffusion_art/configs/256x256_clip_diffusion_art.yaml" \
--sampling "ddim50" \
--cutn 60 \
--cutn_batches 4 \
--sr_model_path pretrained_models/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth \
--large_sr \
--output_dir "outputs"
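The `--sampling` value selects a timestep respacing: `ddim50` runs DDIM on 50 of the 1000 training timesteps, while a bare number such as `50` keeps that many roughly evenly spaced steps, presumably for the non-DDIM sampler. The function below is a plain-Python sketch modeled on `space_timesteps` from improved-diffusion/guided-diffusion (`respace.py`); that sample.py forwards `--sampling` to it unchanged is an assumption.

```python
def space_timesteps(num_timesteps, section_counts):
    """Pick which of the original timesteps to keep for a respaced sampler.

    "ddimN"        -> exactly N steps with an integer stride, starting at 0.
    "N" or "N1,N2" -> split the schedule into equal sections and take N_i
                      roughly evenly spaced steps from section i.
    """
    if section_counts.startswith("ddim"):
        desired_count = int(section_counts[len("ddim"):])
        for stride in range(1, num_timesteps):
            if len(range(0, num_timesteps, stride)) == desired_count:
                return set(range(0, num_timesteps, stride))
        raise ValueError(
            f"cannot get exactly {desired_count} steps from {num_timesteps} "
            "with an integer stride"
        )
    counts = [int(x) for x in section_counts.split(",")]
    size_per, extra = divmod(num_timesteps, len(counts))
    start_idx, all_steps = 0, []
    for i, count in enumerate(counts):
        # Earlier sections absorb the remainder, one extra step each.
        size = size_per + (1 if i < extra else 0)
        if size < count:
            raise ValueError(f"cannot divide section of {size} steps into {count}")
        frac_stride = 1 if count <= 1 else (size - 1) / (count - 1)
        cur_idx = 0.0
        for _ in range(count):
            all_steps.append(start_idx + round(cur_idx))
            cur_idx += frac_stride
        start_idx += size
    return set(all_steps)


if __name__ == "__main__":
    steps = sorted(space_timesteps(1000, "ddim50"))
    print(len(steps), steps[:3], steps[-1])
```

In this sketch, a `ddimN` count that no integer stride can hit exactly raises `ValueError` rather than silently choosing a nearby count.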


--images - image prompts (default=None)
--checkpoint - diffusion model checkpoint to use for sampling
--model_config - diffusion model config yaml
--wandb_project - enable wandb logging and use this project name
--wandb_name - optional run name to use for wandb logging
--wandb_entity - optional entity to use for wandb logging
--num_samples - number of samples to generate (default=1)
--batch_size - batch size for the diffusion model (default=1)
--sampling - timestep respacing sampling methods to use (default="ddim50", choices=[25, 50, 100, 150, 250, 500, 1000, ddim25, ddim50, ddim100, ddim150, ddim250, ddim500, ddim1000])
--diffusion_steps - number of diffusion timesteps (default=1000)
--skip_timesteps - diffusion timesteps to skip (default=5)
--clip_denoised - enable to clamp the model's predicted denoised image to [-1, 1] at each step (default=False)
--randomize_class_disable - disables changing the ImageNet class randomly in each iteration (default=False)
--eta - the amount of noise to add during sampling (default=0)
--clip_model - CLIP pre-trained model to use (default="ViT-B/16", choices=["RN50","RN101","RN50x4","RN50x16","RN50x64","ViT-B/32","ViT-B/16","ViT-L/14"])
--skip_augs - enable to skip torchvision augmentations (default=False)
--cutn - the number of random crops to use (default=16)
--cutn_batches - number of batches of cutn random crops evaluated per step (default=4)
--init_image - init image to use while sampling (default=None)
--loss_fn - loss function to use for CLIP guidance (default="spherical", choices=["spherical", "cos_spherical"])
--clip_guidance_scale - CLIP guidance scale (default=5000)
--tv_scale - controls smoothing in samples (default=100)
--range_scale - controls the range of RGB values in samples (default=150)
--saturation_scale - controls the saturation in samples (default=0)
--init_scale - controls the adherence to the init image (default=1000)
--scale_multiplier - scales clip_guidance_scale, tv_scale, and range_scale (default=50)
--disable_grad_clamp - disable gradient clamping (default=False)
--sr_model_path - SwinIR super-resolution model checkpoint (default=None)
--large_sr - enable to use large SwinIR super-resolution model (default=False)
--output_dir - output images directory (default="output_dir")
--seed - the random seed (default=47)
--device - the device to use
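To make `--clip_guidance_scale`, `--tv_scale` and `--range_scale` more concrete, here is a plain-Python sketch of the three loss terms as they appear in Katherine Crowson's (crowsonkb) CLIP-guided diffusion notebooks, which this style of sampler builds on: a spherical distance between CLIP embeddings of the random crops and the prompt, a total-variation penalty for smoothness, and a penalty on pixel values outside [-1, 1]. Whether sample.py uses exactly these formulas, and how `--scale_multiplier`, `--saturation_scale` and the `cos_spherical` option enter, is not covered; in the real sampler the terms act on tensors and their gradient steers each denoising step, whereas this sketch only computes the scalar.

```python
import math


def _normalize(v):
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def spherical_dist_loss(x, y):
    """theta**2 / 2, where theta is the angle between the two embeddings."""
    x, y = _normalize(x), _normalize(y)
    chord = math.sqrt(sum((a - b) ** 2 for a, b in zip(x, y)))
    # For unit vectors chord = 2 * sin(theta / 2); clamp guards float overshoot.
    return 2 * math.asin(min(1.0, chord / 2)) ** 2


def tv_loss(image):
    """Mean squared difference to the right and lower neighbours (one channel)."""
    h, w = len(image), len(image[0])
    total = 0.0
    for i in range(h):
        for j in range(w):
            # Replicate padding: a neighbour past the edge equals the pixel itself.
            right = image[i][j + 1] if j + 1 < w else image[i][j]
            down = image[i + 1][j] if i + 1 < h else image[i][j]
            total += (right - image[i][j]) ** 2 + (down - image[i][j]) ** 2
    return total / (h * w)


def range_loss(image):
    """Mean squared amount by which pixels fall outside [-1, 1]."""
    values = [p for row in image for p in row]
    return sum((p - max(-1.0, min(1.0, p))) ** 2 for p in values) / len(values)


def guidance_loss(crop_embeds, text_embed, image,
                  clip_guidance_scale=5000, tv_scale=100, range_scale=150):
    """Weighted sum of the three terms; defaults copied from the option list above."""
    clip_term = sum(spherical_dist_loss(e, text_embed) for e in crop_embeds) / len(crop_embeds)
    return (clip_guidance_scale * clip_term
            + tv_scale * tv_loss(image)
            + range_scale * range_loss(image))
```

Averaging the CLIP term over crops is why more `--cutn` crops give a smoother (but slower) guidance signal.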

Apply Super-resolution

Use the following command to run super-resolution on other images, or for other restoration tasks (grayscale/color image denoising, JPEG compression artifact reduction):

python swinir.py <path-to-images-dir> --task "real_sr"

data_dir - directory with images

--task - image restoration task (default='real_sr', choices=['real_sr', 'color_dn', 'gray_dn', 'jpeg_car'])
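For batch jobs, for example running several restoration tasks over one folder, the CLI above can be wrapped from Python. This is a hypothetical helper, not part of the repo: it only assembles the documented `<path-to-images-dir> --task` arguments, validates the task against the listed choices, and assumes `swinir.py` is reachable at the given path (pass `script=` if it lives elsewhere in your checkout).

```python
import shlex
import subprocess
import sys

# Task names taken from the --task choices listed above.
SWINIR_TASKS = ("real_sr", "color_dn", "gray_dn", "jpeg_car")


def build_swinir_command(images_dir, task="real_sr", script="swinir.py"):
    """Return the argv list for one swinir.py run, validating the task first."""
    if task not in SWINIR_TASKS:
        raise ValueError(f"unknown task {task!r}; expected one of {SWINIR_TASKS}")
    return [sys.executable, script, images_dir, "--task", task]


def run_swinir(images_dir, task="real_sr", script="swinir.py"):
    """Run swinir.py as a subprocess; raises CalledProcessError on failure."""
    subprocess.run(build_swinir_command(images_dir, task, script), check=True)


if __name__ == "__main__":
    # Print the command instead of running it; call run_swinir(...) to execute.
    print(shlex.join(build_swinir_command("outputs", task="real_sr")))
```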