

Finetune CLIP

Fine-tune OpenAI's CLIP model for classification on ImageNette, a subset of the ImageNet dataset, and benchmark its performance against state-of-the-art (SOTA) classification models such as ResNets.


CLIP is used for tasks ranging from semantic image search to zero-shot image labeling. It also plays a crucial role in the architecture of Stable Diffusion and is integral to the emerging field of large multimodal models (LMMs). This repo uses CLIP for classification tasks.



I used FastAI's Imagenette dataset for my experiments. It is available at https://github.com/fastai/imagenette

Prepare the dataset by renaming each class directory to its class name. Your dataset should look like this:

_ imagenette2
|__ cassette player
|__ chain_saw
|__ church
|__ english springer 
|__ French_horn
|__ garbage_truck
|__ gas_pump
|__ golf_ball
|__ parachute
|__ tench
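
The renaming step can be scripted. The sketch below assumes the downloaded class folders still carry their WordNet IDs (for example n01440764 for tench) and renames them to the names shown in the tree above. The function name rename_class_dirs and the WNID_TO_NAME table are illustrative and not part of this repo.

```python
from pathlib import Path

# WordNet IDs of the ten Imagenette classes, mapped to the directory
# names used in the layout above (illustrative table, not from the repo).
WNID_TO_NAME = {
    "n01440764": "tench",
    "n02102040": "english springer",
    "n02979186": "cassette player",
    "n03000684": "chain_saw",
    "n03028079": "church",
    "n03394916": "French_horn",
    "n03417042": "garbage_truck",
    "n03425413": "gas_pump",
    "n03445777": "golf_ball",
    "n03888257": "parachute",
}


def rename_class_dirs(root):
    """Rename WordNet-ID folders directly under `root` to class names.

    Returns the list of class names that were actually renamed.
    Folders that are missing, or whose target name already exists,
    are left untouched.
    """
    root = Path(root)
    renamed = []
    for wnid, name in WNID_TO_NAME.items():
        src = root / wnid
        dst = root / name
        if src.is_dir() and not dst.exists():
            src.rename(dst)
            renamed.append(name)
    return renamed
```

Call it as rename_class_dirs("<path to your dataset>"). If your copy keeps the train and val splits in separate folders, run it once per split folder.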

Train baseline model

  • Available baseline classification models - resnet18, resnet34, resnet50, resnet101, resnet152, densenet121, densenet169, densenet201, densenet161, efficientnetb0, googlenet, mobilenet, mobilenetv2, vgg11, vgg13, vgg16, vgg19
  • Supported Datasets - ImageNette, CIFAR-10
python train_baseline.py --dataset imagenette --model resnet18 --dataset_path <path to your dataset>

Fine-tune CLIP model

  • Available CLIP models - RN50, RN101, RN50x4, RN50x16, RN50x64, ViT-B/32, ViT-B/16, ViT-L/14, ViT-L/14@336px
  • Supported Datasets - ImageNette
python clip_finetune.py --model ViT-B/32 --dataset_path <path to your dataset>
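
How clip_finetune.py builds its classifier isn't shown here. As a rough sketch of how CLIP-style classification generally works, the image embedding is compared with the text embedding of one prompt per class, and a softmax over the scaled cosine similarities gives class probabilities. The embeddings below are toy vectors standing in for real encoder outputs. The prompt template, the function name clip_style_probs, and the logit_scale value of 100 are assumptions for illustration.

```python
import math


def normalize(vec):
    """Scale a vector to unit length so dot products are cosine similarities."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def clip_style_probs(image_emb, text_embs, logit_scale=100.0):
    """Softmax over scaled cosine similarities between one image and N prompts."""
    img = normalize(image_emb)
    logits = []
    for text_emb in text_embs:
        txt = normalize(text_emb)
        cosine = sum(a * b for a, b in zip(img, txt))
        logits.append(logit_scale * cosine)
    # Subtract the max logit before exponentiating for numerical stability.
    top = max(logits)
    exps = [math.exp(l - top) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


class_names = ["church", "tench", "parachute"]
# One text prompt per class (assumed template).
prompts = [f"a photo of a {name}" for name in class_names]

# Toy embeddings: in practice these come from CLIP's text and image encoders.
text_embs = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
image_emb = [0.1, 0.9, 0.2]

probs = clip_style_probs(image_emb, text_embs)
predicted = class_names[probs.index(max(probs))]
print(predicted)
```

Fine-tuning then amounts to updating the encoder weights so that the correct class prompt receives the highest probability for each training image.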

Tensorboard logs

tensorboard --logdir=runs

Evaluation accuracy for ResNet18 and CLIP ViT-B/32 on the ImageNette dataset

(This is not an entirely fair comparison between the two models, though.)


ResNet18 - 65.42%

CLIP ViT-B/32 - 99.59%

Export ONNX

python scripts/export_onnx.py --input_pytorch_model best.pt --output_onnx_model best.onnx

ONNX export currently works only for baseline models such as ResNet and DenseNet. Export for the CLIP model is still in progress.


Quantize model

python scripts/quantize.py --input_pytorch_model best.pt --output_quantized_model best_quantized.pt
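
What scripts/quantize.py does internally isn't shown here. As a rough illustration of the general idea behind weight quantization, the sketch below maps float weights to 8-bit integers with a single symmetric scale and then maps them back. The function names quantize_int8 and dequantize are hypothetical and are not taken from the script.

```python
def quantize_int8(weights):
    """Symmetric per-tensor quantization of floats to the range [-127, 127]."""
    max_abs = max((abs(w) for w in weights), default=0.0)
    # Fall back to a scale of 1.0 when all weights are zero.
    scale = max_abs / 127 if max_abs > 0 else 1.0
    quantized = [max(-127, min(127, round(w / scale))) for w in weights]
    return quantized, scale


def dequantize(quantized, scale):
    """Map the integers back to approximate float values."""
    return [v * scale for v in quantized]


weights = [0.42, -1.27, 0.003, 0.9]
q, scale = quantize_int8(weights)
restored = dequantize(q, scale)

# The round-trip error per weight is bounded by half a quantization step.
max_error = max(abs(a - b) for a, b in zip(weights, restored))
print(q, scale, max_error)
```

Storing the integers plus one scale takes roughly a quarter of the space of 32-bit floats, at the cost of the small rounding error measured above.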



