Optimizing CLIP Models for Image Retrieval with Maintained Joint-Embedding Alignment (ArXiv)
Konstantin Schall, Kai Uwe Barthel, Nico Hezel, Klaus Jung
Visual Computing Group HTW Berlin
Contrastive Language and Image Pairing (CLIP), a transformative method in multimedia retrieval, typically trains two neural networks concurrently to generate joint embeddings for text and image pairs. However, when applied directly, these models often struggle to differentiate between visually distinct images that have similar captions, resulting in suboptimal performance for image-based similarity searches. This paper addresses the challenge of optimizing CLIP models for various image-based similarity search scenarios, while maintaining their effectiveness in text-based search tasks such as text-to-image retrieval and zero-shot classification. We propose and evaluate two novel methods aimed at refining the retrieval capabilities of CLIP without compromising the alignment between text and image embeddings. The first method involves a sequential fine-tuning process: initially optimizing the image encoder for more precise image retrieval and subsequently realigning the text encoder to these optimized image embeddings. The second approach integrates pseudo-captions during the retrieval-optimization phase to foster direct alignment within the embedding space. Through comprehensive experiments, we demonstrate that these methods enhance CLIP's performance on various benchmarks, including image retrieval, k-NN classification, and zero-shot text-based classification, while maintaining robustness in text-to-image retrieval. Our optimized models permit maintaining a single embedding per image, significantly simplifying the infrastructure needed for large-scale multi-modal similarity search systems.
We propose two methods that significantly improve pre-trained CLIP models for image-to-image retrieval while preserving the joint-embedding alignment and the performance on text-based tasks.
The second method, Multi-Caption-Image-Pairing (MCIP), leads to the best results across all models:
open_clip Name | open_clip pretrained | Optimized Checkpoint |
---|---|---|
ViT-L-14-336 | openai | checkpoint |
ViT-SO400M-14-SigLIP-384 | webli | checkpoint |
If you want to try out the models, simply install open_clip, download one of the checkpoints above, create the corresponding open_clip model instance, and load our weights. That's it!
import torch
import open_clip

# Create the base open_clip model and load the fine-tuned (MCIP) weights.
model, _, transform = open_clip.create_model_and_transforms("ViT-SO400M-14-SigLIP-384", pretrained="webli")
checkpoint_path = '/path/to/checkpoint.pth'
mcip_state_dict = torch.load(checkpoint_path)
model.load_state_dict(mcip_state_dict, strict=True)
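Once loaded, the model can be used like any other open_clip model, for example to extract a normalized image embedding for similarity search. A minimal sketch (the image path is a placeholder; `model` and `transform` come from the snippet above):
import torch
from PIL import Image

model.eval()
# "example.jpg" is a placeholder; use any image you want to embed.
image = transform(Image.open("example.jpg")).unsqueeze(0)
with torch.no_grad():
    image_embedding = model.encode_image(image)
    image_embedding = torch.nn.functional.normalize(image_embedding, dim=-1)
# The dot product of two such normalized embeddings is the image-to-image retrieval score.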
This repository now contains code for each of the fine-tuning methods mentioned in the paper!
Create a new Python environment and install the required packages:
pip install -r requirements.txt
We used a combination of five publicly available training sets for the general-purpose retrieval and MCIP fine-tuning:
- ImageNet21k (Classes from ImageNet1k were excluded). Download instructions
- Google Landmarks v2 Download instructions
- Alibaba Products Download instructions
- iNat 2021 Download instructions
- VGG Face2 Download instructions
However, you can use any combination of data you like. Preprocess your training dataset and store it in two numpy files: the first contains the paths of all images you want to use, and the second contains the unique category identifier of each image, in the same order as the first file. The identifiers can be integers or strings, but must not overlap between the different dataset parts.
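The following sketch shows one way to build these two files, assuming a hypothetical layout of one sub-folder per category; adapt it to however your data is organized:
import os
import numpy as np

# Hypothetical layout: /data/my_dataset/<category>/<image>.jpg
dataset_root = "/data/my_dataset"
paths, categories = [], []
for category in sorted(os.listdir(dataset_root)):
    category_dir = os.path.join(dataset_root, category)
    for fname in sorted(os.listdir(category_dir)):
        paths.append(os.path.join(category_dir, fname))
        # Prefix with the dataset name so identifiers stay unique across dataset parts.
        categories.append(f"my_dataset_{category}")

# Both arrays share the same order: categories[i] belongs to paths[i].
np.save("GPR_Full_paths.npy", np.array(paths))
np.save("GPR_Full_cats.npy", np.array(categories))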
Adjust the config file of your experiment and set the parameters:
data:
image_paths: /mnt/data/features/GPR_Full_paths.npy
image_categories: /mnt/data/features/GPR_Full_cats.npy
This set is used to track the progress of the general-purpose retrieval fine-tuning. Download the images from this link, extract the zip file, and adjust the config parameter:
data:
eval_base_path: /path/to/evaluation_base_folder/
This set is currently used to generate the MCIP pseudo-captions and for the realignment fine-tuning, but any collection of text-image pairs should work as well. Follow the img2dataset example to obtain a sharded webdataset version of this dataset.
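For reference, downloading CC12M with img2dataset's Python API could look roughly like this (a sketch based on img2dataset's own CC12M example; it assumes a TSV file with "url" and "caption" columns, and all paths are placeholders):
from img2dataset import download

download(
    url_list="cc12m.tsv",              # TSV with "url" and "caption" columns
    input_format="tsv",
    url_col="url",
    caption_col="caption",
    output_format="webdataset",        # produces the sharded .tar files used below
    output_folder="/path/to/CC12M/cc12m",
    image_size=384,
    processes_count=16,
    thread_count=64,
)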
This step will only train the image-encoder with the pre-processed data. The text-encoder will have to be realigned in an additional step. Run the following command to start fine-tuning:
python gpr-ft/main.py --mode train --cfg gpr-ft/config/GPR_SigLIP400_ArcFaceText_AdamW_384.yaml --gpu_id 0
The model weights will be saved in the logs folder under the name corresponding to the config file.
First, we need to create the pseudo-captions.
Set the parameter:
data:
train_image_embeddings_file_path: ""
Extract the image embeddings of the GPR fine-tuning dataset:
python gpr-ft/main.py --mode extract_image_features --cfg gpr-ft/config/SigLIP400_MCIP_AdamW_384.yaml --gpu_id 0
Extract the text values and embeddings of the realignment dataset:
export CUDA_VISIBLE_DEVICES=0,1
torchrun --nproc_per_node 2 -m realignment.main \
--train-data '/path/to/CC12M/cc12m/{00000..01242}.tar' \
--train-num-samples 10968539 \
--dataset-type webdataset \
--batch-size 16384 \
--precision amp \
--workers 16 \
--model ViT-SO400M-14-SigLIP-384 \
--pretrained webli \
--extract-text-features-only \
--store_features_path features/
The text value and embedding files will be saved to the specified folder.
Set the paths of the newly generated files and the desired similarity threshold in the config file:
data:
train_image_embeddings_file_path: "features/GPR_Train_OClipL_imagefeatures.npy"
caption_value_files: "features/12CCM_OClipL_text_0.npy;features/12CCM_OClipL_text_1.npy"
caption_embeddings_files: "features/12CCM_OClipL_textfeatures_0.pth;features/12CCM_OClipL_textfeatures_1.pth"
caption_output_dir: captions
MCIP:
sim_th: 0.27
Generate pseudo-captions:
python gpr-ft/main.py --mode create_pseudo_captions --cfg gpr-ft/config/GPR_SigLIP400_TextArcFaceText_AdamW_384.yaml --gpu_id 0
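Conceptually, this step assigns to each training image all captions from the realignment set whose text embedding is sufficiently similar to the image embedding. The following is only an illustration of that idea, not the repository's exact implementation; it assumes the feature files each hold a single array/tensor of embeddings plus the raw caption strings:
import numpy as np
import torch
import torch.nn.functional as F

# Load and L2-normalize the image and caption embeddings (file formats are assumptions).
image_embs = F.normalize(
    torch.from_numpy(np.load("features/GPR_Train_OClipL_imagefeatures.npy")).float(), dim=-1)
caption_embs = F.normalize(torch.load("features/12CCM_OClipL_textfeatures_0.pth").float(), dim=-1)
captions = np.load("features/12CCM_OClipL_text_0.npy", allow_pickle=True)

sim_th = 0.27  # matches MCIP.sim_th in the config
pseudo_captions = []
for img_emb in image_embs:
    sims = caption_embs @ img_emb                    # cosine similarities to all candidate captions
    idx = torch.nonzero(sims > sim_th).squeeze(-1)
    pseudo_captions.append([captions[i] for i in idx.tolist()])
# In practice this would be batched and run on the GPU.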
Run the MCIP fine-tuning script:
python gpr-ft/main.py --mode train_with_text --cfg gpr-ft/config/SigLIP400_MCIP_AdamW_384.yaml --gpu_id 0
Both methods benefit from a realignment fine-tuning; if you fine-tuned your model without MCIP, realignment is necessary to restore the multi-modal capabilities.
Since the improved image encoder will not be changed any further, its image embeddings can be extracted once before the realignment fine-tuning begins:
export CUDA_VISIBLE_DEVICES=0,1
torchrun --nproc_per_node 2 -m realignment.main \
--train-data '/path/to/images/CC12M/cc12m/{00000..01242}.tar' \
--train-num-samples 10968539 \
--dataset-type webdataset \
--batch-size 512 \
--precision amp \
--workers 16 \
--model ViT-SO400M-14-SigLIP-384 \
--GPR-model-weights '../logs/gpr-ft/exp_name/GPR1200.pth' \
--pretrained webli \
--force-image-size 384 \
--image-mean 0.5 0.5 0.5 \
--image-std 0.5 0.5 0.5 \
--extract-features-only
Start the realignment fine-tuning with the extracted embeddings:
export CUDA_VISIBLE_DEVICES=0,1
torchrun --nproc_per_node 2 -m realignment.main \
--train-data '/mnt/bigdata/images/CC12M/cc12m/{00000..01242}.tar' \
--train-num-samples 10968539 \
--dataset-type webdataset \
--batch-size 16384 \
--precision amp \
--workers 16 \
--imagenet-val /mnt/data/images/ImageNet1k_2012/val \
--model ViT-SO400M-14-SigLIP-384 \
--GPR-model-weights '../logs/gpr-ft/exp_name/GPR1200.pth' \
--pretrained webli \
--grad-checkpointing \
--lock-image \
--lock-image-unlocked-groups 0 \
--lock-image-freeze-bn-stats \
--force-image-size 384 \
--dataset-resampled \
--lr 0.00001 \
--lock-text \
--lock-text-unlocked-layers 16 \
--delete-previous-checkpoint \
--train-with-features-only \
--log-every-n-steps 5 \
--image-embedding-files /mnt/bigdata/features/12CCM_S400_imagefeatures_0.pth /mnt/bigdata/features/12CCM_S400_imagefeatures_1.pth \
--image-key-files /mnt/bigdata/features/12CCM_S400_keys_0.npy /mnt/bigdata/features/12CCM_S400_keys_1.npy \
--image-embedding-dim 1152
Your model now has improved image-to-image nearest-neighbor search capabilities and performs equally well or better in textual zero-shot classification and text-to-image retrieval!
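As a quick sanity check of the restored joint embedding space, you can compare image and text embeddings directly. A minimal sketch, assuming the realigned checkpoint is stored in the same state-dict format as the released ones; the checkpoint path, image path, and prompts are placeholders:
import torch
from PIL import Image
import open_clip

model_name = "ViT-SO400M-14-SigLIP-384"
model, _, transform = open_clip.create_model_and_transforms(model_name, pretrained="webli")
model.load_state_dict(torch.load("/path/to/realigned_checkpoint.pth"), strict=True)
tokenizer = open_clip.get_tokenizer(model_name)
model.eval()

image = transform(Image.open("example.jpg")).unsqueeze(0)
texts = tokenizer(["a photo of a dog", "a photo of a cat"])

with torch.no_grad():
    image_features = torch.nn.functional.normalize(model.encode_image(image), dim=-1)
    text_features = torch.nn.functional.normalize(model.encode_text(texts), dim=-1)
    similarity = image_features @ text_features.T  # higher value = better text match
print(similarity)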
The realignment fine-tuning code base is heavily based on open_clip. Thanks to all contributors of this awesome library!