CoOp

Prompt Learning for Vision-Language Models

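This repo contains the codebase for a series of research projects on adapting vision-language models like CLIP to downstream datasets via prompt learning:

  • Conditional Prompt Learning for Vision-Language Models (CoCoOp), CVPR'22
  • Learning to Prompt for Vision-Language Models (CoOp), arXiv:2109.01134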

Updates

  • 10.06.2022: Our latest work, Neural Prompt Search, has been released on arXiv. It offers a new perspective on fine-tuning large vision models like ViT, so check it out if you're interested in parameter-efficient fine-tuning or transfer learning. The code is also publicly available.

  • 08.06.2022: If you're looking for the code that draws the few-shot performance curves (like the ones shown in the CoOp paper), see draw_curves.py.

  • 09.04.2022: The pre-trained weights of CoOp on ImageNet have been released (see Models and Results).

  • 11.03.2022: The code of our CVPR'22 paper, "Conditional Prompt Learning for Vision-Language Models," is released.

  • 15.10.2021: We found that the best_val model and the last_step model achieve similar performance, so we now set TEST.FINAL_MODEL = "last_step" for all datasets to save training time. We originally used best_val because the (tiny) validation set was designed for the linear probe approach, which requires extensive hyperparameter tuning; using the best_val model for CoOp as well made the comparison fair, since both approaches then had access to the validation set.

  • 09.10.2021: Important changes have been made to Dassl's transforms.py. Please pull the latest commits from https://github.com/KaiyangZhou/Dassl.pytorch and this repo to make sure the code works properly. In particular, 1) center_crop is now a default transform in testing (applied after resizing the smaller edge to a certain size, which keeps the image aspect ratio), and 2) for training, Resize(cfg.INPUT.SIZE) is deactivated when random_crop or random_resized_crop is used. Please read this issue on how these changes might affect the performance.

  • 18.09.2021: We have fixed an error in Dassl that could cause a training data loader to have zero length (so no training is performed) when the dataset size is smaller than the batch size (due to drop_last=True). Please pull the latest commit for Dassl (>= 8eecc3c). This error led to lower results for CoOp in EuroSAT's 1- and 2-shot settings (all other results are correct). We will update the paper on arXiv to fix this error.
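
The two Dassl-related notes above (zero-length loaders under drop_last=True, and center cropping after a smaller-edge resize) come down to simple arithmetic. The sketch below is not Dassl's or torchvision's code; it is a minimal pure-Python illustration. The function names, the batch size of 32, the 640x480 input and the 224 crop size are assumptions made for the example, and the rounding may differ slightly from the libraries' own.

```python
import math


def num_batches(dataset_size: int, batch_size: int, drop_last: bool) -> int:
    """Number of batches one pass over the data yields (hypothetical helper)."""
    if drop_last:
        # The incomplete final batch is discarded.
        return dataset_size // batch_size
    return math.ceil(dataset_size / batch_size)


def resize_then_center_crop(width: int, height: int, size: int):
    """Resize the smaller edge to `size` (keeping the aspect ratio), then
    return the resized dimensions and the centered size x size crop box."""
    short = min(width, height)
    new_w = size * width // short
    new_h = size * height // short
    left = (new_w - size) // 2
    top = (new_h - size) // 2
    return (new_w, new_h), (left, top, left + size, top + size)


# EuroSAT has 10 classes, so k-shot training uses 10 * k images.
# ASSUMPTION: a training batch size of 32, chosen for illustration.
for shots in (1, 2, 4):
    print(shots, num_batches(10 * shots, 32, drop_last=True))
# prints "1 0", "2 0", "4 1": the 1- and 2-shot loaders are empty

# ASSUMPTION: a 640x480 test image and a 224x224 crop.
print(resize_then_center_crop(640, 480, 224))
# prints ((298, 224), (37, 0, 261, 224))
```

With these assumed numbers, only the 1- and 2-shot settings fall below the batch size, which is consistent with the settings reported as affected.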

How to Install

This code is built on top of the toolbox Dassl.pytorch, so you need to install the dassl environment first. Follow the instructions in the Dassl.pytorch repository (https://github.com/KaiyangZhou/Dassl.pytorch) to install dassl as well as PyTorch. After that, with the dassl environment activated, run pip install -r requirements.txt under CoOp/ to install a few more packages required by CLIP. Then you are ready to go.

Follow DATASETS.md to install the datasets.

How to Run

Detailed instructions for running the code and reproducing the results are provided separately for each paper (CoOp and CoCoOp).

Models and Results

  • The pre-trained weights of CoOp (both M=16 and M=4) on ImageNet, based on RN50, RN101, ViT-B/16 and ViT-B/32, can be downloaded together via this link. The weights can be used to reproduce the results in Table 1 of the CoOp paper (i.e., the results on ImageNet and its four variants with domain shift). To load the weights and run the evaluation code, you need to specify both --model-dir and --load-epoch (see this script for an example).
  • The raw numerical results can be found at this Google Drive link.
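
As a rough illustration of the two evaluation flags mentioned above, the sketch below parses --model-dir and --load-epoch and rejects a call that gives a model directory without an epoch. It is not the repo's actual entry point: the function name parse_eval_args, the defaults, the help strings and the example directory are assumptions made for this example.

```python
import argparse


def parse_eval_args(argv):
    """Parse the two evaluation flags (hypothetical helper, not the repo's parser)."""
    parser = argparse.ArgumentParser(description="Evaluation-only sketch")
    parser.add_argument("--model-dir", type=str, default="",
                        help="directory containing the downloaded weights")
    parser.add_argument("--load-epoch", type=int, default=None,
                        help="epoch of the checkpoint to load")
    args = parser.parse_args(argv)
    # ASSUMPTION: a model directory is only useful together with an epoch,
    # so fail early instead of silently skipping the weights.
    if args.model_dir and args.load_epoch is None:
        parser.error("--load-epoch is required when --model-dir is given")
    return args


# Hypothetical directory name and epoch, for illustration only.
args = parse_eval_args(["--model-dir", "weights/rn50_m16", "--load-epoch", "50"])
print(args.model_dir, args.load_epoch)
```

argparse maps the dashed flag names to the attributes model_dir and load_epoch, which is why the code reads args.model_dir rather than a dashed name.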

Citation

If you use this code in your research, please cite the following papers:

@inproceedings{zhou2022cocoop,
    title={Conditional Prompt Learning for Vision-Language Models},
    author={Zhou, Kaiyang and Yang, Jingkang and Loy, Chen Change and Liu, Ziwei},
    booktitle={CVPR},
    year={2022}
}

@article{zhou2021coop,
    title={Learning to Prompt for Vision-Language Models},
    author={Zhou, Kaiyang and Yang, Jingkang and Loy, Chen Change and Liu, Ziwei},
    journal={arXiv preprint arXiv:2109.01134},
    year={2021}
}