[CVPR 2023] Learning Visual Representations via Language-Guided Sampling

Primary language: Python. License: MIT.

Learning Visual Representations via Language-Guided Sampling
Mohamed El Banani, Karan Desai, and Justin Johnson

If you have any questions, please feel free to email me at mbanani@umich.edu.

Environment Setup

We recommend using Anaconda or Miniconda. To set up the environment, follow the instructions below.

conda create -n lgssl python=3.8 --yes
conda activate lgssl
conda install pytorch=1.12.1 torchvision cudatoolkit=11.3 -c pytorch --yes

python -m pip install -r requirements.txt
python setup.py develop

Training Datasets

We train our models on RedCaps and ConceptualCaptions (CC3M and CC12M). Note that all three datasets can decay over time as images are removed from the web, so you may end up with a different number of instances. Please refer to the original papers for dataset download instructions. In our case, the datasets had the following sizes:

Dataset         Size (instances)
RedCaps-2020    3,273,223
RedCaps         12,010,494
CC3M            2,913,035
CC12M           10,958,691

We assume all training datasets are stored in data/datasets, which is the default data_root in the base dataset class. Each dataset should follow the layout below: it is subdivided into several directories, and each directory contains a set of instances, where each instance consists of an image file and a JSON caption file.

data/datasets/<dataset_name>
    |- directory_0
        |- <instance_0>.jpg     <- image for instance 0
        |- <instance_0>.json    <- caption for instance 0
        |- <instance_1>.jpg
        |- <instance_1>.json
        ...
        |- <instance_n>.jpg
        |- <instance_n>.json
    |- directory_1
    |- directory_2
    ...
    |- directory_m
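Because the datasets can decay, some instances may end up with an image but no caption, or the reverse. The following is a minimal sketch, not part of the repository, for checking a dataset directory against the layout above. The function name `find_unpaired_files` is hypothetical, and it assumes images use the `.jpg` extension shown in the layout.

```python
from pathlib import Path


def find_unpaired_files(dataset_dir):
    """Return "<directory>/<instance>" for every image without a caption or vice versa.

    Hypothetical helper. Assumes the layout
    <dataset_dir>/<directory>/<instance>.jpg and <dataset_dir>/<directory>/<instance>.json.
    """
    unpaired = []
    for subdir in sorted(p for p in Path(dataset_dir).iterdir() if p.is_dir()):
        images = {p.stem for p in subdir.glob("*.jpg")}
        captions = {p.stem for p in subdir.glob("*.json")}
        # Symmetric difference: stems that appear with only one of the two files.
        for stem in sorted(images ^ captions):
            unpaired.append(f"{subdir.name}/{stem}")
    return unpaired
```

Calling it on data/datasets/<dataset_name> returns an empty list when every instance has both files.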

For RedCaps, the directory names are encoded as <subreddit>_<year>_<id>, e.g., crochet_2017_000001, where each directory contains at most 10000 instances. We rely on this naming convention in some of the experiments: the RedCaps-2020 experiments and the sampling scope experiments.
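The directory name alone is enough to group or filter RedCaps data. Below is a hypothetical helper sketch (not the repository's code) that parses `<subreddit>_<year>_<id>` and filters directories by year or subreddit. For illustration, it assumes that a RedCaps-2020 style subset can be selected by keeping directories whose year is 2020. It splits from the right because subreddit names can themselves contain underscores.

```python
from typing import NamedTuple


class RedCapsDir(NamedTuple):
    subreddit: str
    year: int
    index: int


def parse_redcaps_dir(name: str) -> RedCapsDir:
    """Parse a directory name of the form <subreddit>_<year>_<id>."""
    # rsplit keeps underscores inside the subreddit name intact.
    subreddit, year, index = name.rsplit("_", 2)
    return RedCapsDir(subreddit, int(year), int(index))


def select_dirs(names, year=None, subreddit=None):
    """Keep only directory names that match an optional year and/or subreddit."""
    selected = []
    for name in names:
        info = parse_redcaps_dir(name)
        if year is not None and info.year != year:
            continue
        if subreddit is not None and info.subreddit != subreddit:
            continue
        selected.append(name)
    return selected
```

For example, `parse_redcaps_dir("crochet_2017_000001")` gives `RedCapsDir(subreddit="crochet", year=2017, index=1)`.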

Generating dataset dictionaries

We create dataset-specific dictionaries that contain the information for each dataset (e.g., image paths and captions), which makes sampling easy in subsequent steps. To generate a dataset dictionary, run the following commands, where <dataset_name> is the name of the dataset directory in data/datasets.

cd preprocess
python make_imagecaption_dict.py <dataset_name> 
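To make the idea of a dataset dictionary concrete, here is a minimal sketch of a mapping from each instance to its image path and caption. It is not the output format of make_imagecaption_dict.py, whose exact format is defined in that script. The function name, the dictionary keys, and the assumption that each caption JSON stores its text under a `caption` key are all hypothetical.

```python
import json
from pathlib import Path


def build_imagecaption_dict(dataset_dir, caption_key="caption"):
    """Map "<directory>/<instance>" to {"image": path, "caption": text}.

    caption_key is an assumed JSON field name; adjust it to your caption files.
    """
    entries = {}
    for caption_file in sorted(Path(dataset_dir).glob("*/*.json")):
        image_file = caption_file.with_suffix(".jpg")
        if not image_file.exists():
            continue  # skip captions whose image is missing (e.g., decayed)
        with open(caption_file) as f:
            caption = json.load(f)[caption_key]
        key = f"{caption_file.parent.name}/{caption_file.stem}"
        entries[key] = {"image": str(image_file), "caption": caption}
    return entries
```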

Sampling nearest neighbor pairs

Once we have the dataset dictionaries, we can sample nearest neighbor pairs. We provide code for sampling using either language or visual embeddings, as well as sampling within dataset subsets for the experiments reported in the supplementary material. The commands below show language sampling based on SBERT, visual sampling based on an ImageNet-pretrained model, and language sampling within each subreddit.

python sample_language_nn.py <dataset_name> all-mpnet-base-v2                       # Language - MPNet (SBERT)
python sample_language_nn_subsets.py <dataset_name> all-mpnet-base-v2 subreddit     # Language Subset - MPNet (SBERT) on subreddits

python sample_visual_nn.py <dataset_name> vit_b_32 IMAGENET1K_V1                    # Visual - ImageNet-supervised ViT-B/32
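Conceptually, nearest neighbor sampling pairs each instance with the most similar other instance in an embedding space. The sketch below illustrates this with brute-force cosine similarity in NumPy, plus an optional per-instance group label that restricts the search, in the spirit of within-subreddit sampling. It is an illustration under assumptions, not the implementation in sample_language_nn.py: the function name is hypothetical, and a dense N x N similarity matrix is only practical for small N, far below the millions of instances in these datasets. Language embeddings could be produced with, e.g., the sentence-transformers call `SentenceTransformer("all-mpnet-base-v2").encode(captions)`.

```python
import numpy as np


def sample_nearest_neighbors(embeddings, groups=None):
    """Return, for each row, the index of its most cosine-similar other row.

    If `groups` is given (one label per row, e.g., a subreddit), neighbors are
    searched only within the same group. Rows without a valid neighbor get -1.
    """
    x = np.asarray(embeddings, dtype=np.float64)
    x = x / np.linalg.norm(x, axis=1, keepdims=True)  # unit norm -> dot = cosine
    sim = x @ x.T
    np.fill_diagonal(sim, -np.inf)  # an instance cannot be its own neighbor
    if groups is not None:
        groups = np.asarray(groups)
        sim[groups[:, None] != groups[None, :]] = -np.inf  # mask other groups
    nn = sim.argmax(axis=1)
    nn[np.isneginf(sim.max(axis=1))] = -1  # singleton group: no candidate left
    return nn
```

Each returned pair (i, nn[i]) would then serve as a positive pair for contrastive training.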

Evaluation Datasets

We use TensorFlow Datasets for our evaluations. It provides all of our evaluation datasets except FGVC Aircraft. On the first run, our evaluation code automatically downloads and extracts all the datasets into data/evaluation_datasets, so the first evaluation run will be much slower than subsequent ones.

Note 1: We encountered a bug with SUN397 where one image could not be decoded correctly. This is a known bug that has not yet been fixed in the stable release. To fix it, make the two changes outlined in this commit.

Note 2: TensorFlow Datasets requires you to download RESISC45 manually. Please follow the instructions provided here.

Training models

We use Hydra configs for our training experiments. All configs can be found here. To run an experiment, you can either define a new experiment config that overrides the default configs, or override individual config values directly on the command line. We provide a few sample training commands for clarity:

python train.py +experiment=ours                        # LG SimCLR
python train.py +experiment=vis_baseline                # SimCLR
python train.py +experiment=vis_baseline model=simsiam  # SimSiam

Evaluation

We use two primary evaluations: linear probe using L-BFGS and few-shot evaluation. The configs for those evaluations can be found here.

Linear Probe: we train a single linear layer with logistic regression and sweep over regularization weight values. We provide an implementation of logistic regression using PyTorch's L-BFGS optimizer; alternatively, you can use scikit-learn's implementation by setting the use_sklearn flag in the evaluation configs. For datasets without a standard validation split, we randomly split the training set while maintaining the class distribution.
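The class-preserving split described above can be sketched as follows. This is a minimal illustration, not the repository's code: the 20% validation fraction, the seed, and the function name are assumptions. scikit-learn offers the same behavior through `train_test_split(..., stratify=labels)`.

```python
import random
from collections import defaultdict


def stratified_split(labels, valid_fraction=0.2, seed=0):
    """Split sample indices into train/valid while keeping per-class proportions.

    Every class with at least two samples contributes at least one sample to
    each side; singleton classes go to train.
    """
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    rng = random.Random(seed)
    train, valid = [], []
    for label in sorted(by_class):
        indices = by_class[label]
        rng.shuffle(indices)
        n_valid = round(len(indices) * valid_fraction)
        if len(indices) >= 2:
            n_valid = min(max(n_valid, 1), len(indices) - 1)
        valid.extend(indices[:n_valid])
        train.extend(indices[n_valid:])
    return sorted(train), sorted(valid)
```

Fixing the seed keeps the split identical across the regularization sweep, so every weight value is compared on the same validation set.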

Few-Shot Evaluation: we also evaluate our frozen features on 5-way, 5-shot classification. The evaluation code can be found here. We sample the support (training) examples from the train/valid splits and the query examples from the test set.
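An N-way K-shot episode of this kind can be sampled as in the sketch below: support examples come from the train/valid pool and queries from the test set. The function name and the default of 15 queries per class are illustrative assumptions, not values taken from the repository's evaluation code.

```python
import random
from collections import defaultdict


def sample_episode(train_labels, test_labels, n_way=5, k_shot=5, n_query=15, seed=None):
    """Sample one N-way K-shot episode.

    Returns (support, query): dicts mapping each sampled class to indices into
    train_labels and test_labels, respectively.
    """
    rng = random.Random(seed)
    support_pool, query_pool = defaultdict(list), defaultdict(list)
    for idx, label in enumerate(train_labels):
        support_pool[label].append(idx)
    for idx, label in enumerate(test_labels):
        query_pool[label].append(idx)

    # Only classes with enough examples in both pools can be sampled.
    eligible = sorted(
        c for c in support_pool
        if len(support_pool[c]) >= k_shot and len(query_pool[c]) >= n_query
    )
    classes = rng.sample(eligible, n_way)
    support = {c: rng.sample(support_pool[c], k_shot) for c in classes}
    query = {c: rng.sample(query_pool[c], n_query) for c in classes}
    return support, query
```

`random.sample` raises a ValueError if fewer than `n_way` classes are eligible, which surfaces datasets that are too small for the requested episode size.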

You can use the following commands to evaluate checkpoints or baselines. For example, to evaluate our model or the pretrained SimCLR checkpoint on all datasets, run:

python evaluate.py model.name=lgssl_checkpoints model.checkpoint=lgsimclr dataset.name=all
python evaluate.py model.name=simclr dataset.name=all

Pre-trained Checkpoints

You can find all our pretrained checkpoints here; download them to data/checkpoints. Alternatively, you can load a checkpoint through torch.hub, which uses the repository's hubconf.py, as shown in the snippet below:

import torch
model = torch.hub.load("mbanani/lgssl", "lgsimclr")

For a list of released models, see hubconf.py.

Citation

If you find this code useful, please consider citing:

@inproceedings{elbanani2022languageguided,
  title={{Learning Visual Representations via Language-Guided Sampling}},
  author={El Banani, Mohamed and Desai, Karan and Johnson, Justin},
  booktitle={CVPR},
  year={2023},
}

Acknowledgments

We thank Richard Higgins, Ashkan Kazemi, and Santiago Castro for many helpful discussions. We also thank David Fouhey, Ziyang Chen, Chenhao Zheng, Fahad Kamran, and Dandan Shan for their feedback on early drafts. This project was funded under the Ford-UM Alliance partnership. We thank Alireza Rahimpour, Devesh Upadhyay, and Ali Hassani from Ford Research for their support and discussion.