Are Vision Transformers Robust to Spurious Correlations?

This codebase provides a PyTorch implementation for the paper: Are Vision Transformers Robust to Spurious Correlations?

Abstract

Deep neural networks may be susceptible to learning spurious correlations that hold on average but not in atypical test samples. With the recent emergence of vision transformer (ViT) models, it remains underexplored how spurious correlations are manifested in such architectures. In this paper, we systematically investigate the robustness of vision transformers to spurious correlations on three challenging benchmark datasets and compare their performance with popular CNNs. Our study reveals that when pre-trained on a sufficiently large dataset, ViT models are more robust to spurious correlations than CNNs. Key to their success is the ability to generalize better from the examples where spurious correlations do not hold. Further, we perform extensive ablations and experiments to understand the role of the self-attention mechanism in providing robustness under spuriously correlated environments. We hope that our work will inspire future research on further understanding the robustness of ViT models.

Required Packages

Our experiments are conducted on Ubuntu Linux 20.04 with Python 3.9 and PyTorch 1.6. The following packages are also required:

  • SciPy
  • NumPy
  • scikit-learn
  • pandas
  • tqdm
  • Pillow
  • timm
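To confirm that the environment is ready before launching any experiment, a small check like the one below may help. It is a minimal sketch, not part of the repository; the helper name missing_modules is hypothetical. Note that several packages are imported under a different name than they are installed with (scikit-learn as sklearn, Pillow as PIL).

```python
import importlib.util

# Import names for the packages listed above, plus PyTorch itself.
# scikit-learn is imported as "sklearn" and Pillow as "PIL".
REQUIRED_MODULES = ["torch", "scipy", "numpy", "sklearn", "pandas", "tqdm", "PIL", "timm"]


def missing_modules(names):
    """Return the subset of `names` that cannot be found on the import path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(REQUIRED_MODULES)
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All required packages are installed.")
```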

Pre-trained Checkpoints

In our experiments, for ViT models we use the pre-trained checkpoints provided with the timm library. Pre-trained checkpoints for BiT models can be downloaded following the instructions in the official repo. Please download and place the BiT checkpoints in the bit_pretrained_models folder.
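The sketch below illustrates both checkpoint sources under stated assumptions. The mapping from this README's --model_type names to timm model names is our guess for 384x384 inputs and may not match the repository's own mapping. The BiT URL pattern follows the one documented in the big_transfer repository, and the file name {model_type}.npz inside bit_pretrained_models is hypothetical; check the training code for the name it actually expects. timm.create_model is the real timm entry point; the other helpers are illustrative.

```python
import os
import urllib.request

# Assumed mapping from --model_type to timm model names (384x384 inputs).
TIMM_NAMES = {
    "ViT-B_16": "vit_base_patch16_384",
    "ViT-S_16": "vit_small_patch16_384",
    "ViT-Ti_16": "vit_tiny_patch16_384",
}

# Assumed URL pattern for BiT-M weights, as documented by big_transfer.
BIT_URL = "https://storage.googleapis.com/bit_models/{}.npz"
BIT_DIR = "bit_pretrained_models"


def timm_model_name(model_type):
    """Translate a README model type (e.g. ViT-S_16) into a timm model name."""
    return TIMM_NAMES[model_type]


def bit_checkpoint_path(model_type, root=BIT_DIR):
    """Hypothetical local path for a BiT checkpoint such as BiT-M-R50x1."""
    return os.path.join(root, model_type + ".npz")


def download_bit_checkpoint(model_type, root=BIT_DIR):
    """Download a BiT checkpoint into `root` unless it is already there."""
    path = bit_checkpoint_path(model_type, root)
    if not os.path.exists(path):
        os.makedirs(root, exist_ok=True)
        urllib.request.urlretrieve(BIT_URL.format(model_type), path)
    return path


def create_vit(model_type, num_classes=2):
    """Build a pre-trained ViT with a fresh classification head via timm."""
    import timm  # imported lazily so the helpers above work without timm

    return timm.create_model(timm_model_name(model_type), pretrained=True,
                             num_classes=num_classes)
```

For example, download_bit_checkpoint("BiT-M-R50x1") fetches the weights once, and create_vit("ViT-S_16") downloads the timm weights on first use.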

Datasets

In our study, we use the following challenging benchmarks:

  • WaterBirds: Similar to the construction in Group_DRO, this dataset is constructed by cropping out birds from photos in the Caltech-UCSD Birds-200-2011 (CUB) dataset (Wah et al., 2011) and transferring them onto backgrounds from the Places dataset (Zhou et al., 2017).
  • CelebA: Large-scale CelebFaces Attributes Dataset. The data we use for this task is listed in datasets/celebA/celebA_split.csv. After downloading the dataset, place the images in datasets/celebA/img_align_celeba/. datasets/celebA_dataset.py provides the dataloaders for CelebA and the corresponding OOD datasets.
  • ColorMNIST: A colour-biased version of the original MNIST dataset.
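These benchmarks share a common structure: each example has a class label and a spurious attribute (for example, the background in Waterbirds or the colour in ColorMNIST), and the (label, attribute) combinations define groups. A minimal sketch of that grouping is shown below; the indexing group = 2 * y + a follows the Group_DRO convention for one binary attribute and is an assumption here, not necessarily what this codebase uses.

```python
from collections import Counter

# Assumed Group_DRO-style indexing: 2 classes x 1 binary spurious attribute
# gives 4 groups, with group = 2 * y + a.
N_ATTR_VALUES = 2


def group_index(label, attribute):
    """Map a (label, spurious attribute) pair to a group id in {0, 1, 2, 3}."""
    return label * N_ATTR_VALUES + attribute


def group_counts(labels, attributes):
    """Count samples per group; small groups are where the correlation breaks."""
    return Counter(group_index(y, a) for y, a in zip(labels, attributes))
```

In Waterbirds terms, with y = 1 for waterbird and a = 1 for a water background, groups 0 and 3 are the majority (spuriously aligned) groups and groups 1 and 2 are the minority groups.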

Quick Start

To run the experiments, first download the pre-trained model checkpoints and datasets and place them in the specified folders, as described in the Pre-trained Checkpoints and Datasets sections. We provide the following commands and general descriptions of the related files.

WaterBirds

  • datasets/waterbirds_dataset.py: provides the dataloader for the Waterbirds dataset. The code expects the following files/folders in the [root_dir]/datasets directory:
      ◦ waterbird_complete95_forest2water2/

You can download a tarball of this dataset from here. The Waterbirds dataset can also be accessed through the WILDS package, which will automatically download the dataset.
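Once the folder is in place, inspecting its metadata is a quick way to see the spurious correlation. The sketch below assumes the metadata.csv schema released with Group_DRO (columns y, split and place, with y: 0 = landbird, 1 = waterbird; place: 0 = land, 1 = water; split: 0/1/2 = train/val/test); verify these column meanings against your copy of the data. The helper names are hypothetical.

```python
import csv
import os
from collections import Counter

# Assumed metadata.csv schema (Group_DRO release): y is 0 = landbird,
# 1 = waterbird; place is 0 = land, 1 = water; split is 0/1/2 = train/val/test.
SPLITS = {0: "train", 1: "val", 2: "test"}


def count_groups(rows):
    """Count (y, place) pairs per split from an iterable of metadata rows."""
    counts = {name: Counter() for name in SPLITS.values()}
    for row in rows:
        split = SPLITS[int(row["split"])]
        counts[split][(int(row["y"]), int(row["place"]))] += 1
    return counts


def waterbirds_group_counts(dataset_dir):
    """Read <dataset_dir>/metadata.csv and return per-split group counts."""
    with open(os.path.join(dataset_dir, "metadata.csv"), newline="") as f:
        return count_groups(csv.DictReader(f))
```

For example, waterbirds_group_counts("datasets/waterbird_complete95_forest2water2") should show that (y, place) pairs with matching bird and background type dominate the training split.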

To train a ViT model (ViT-B_16) on the Waterbirds dataset, run the following command:

python train.py --name waterbirds_exp --model_arch ViT --model_type ViT-B_16 --dataset waterbirds --warmup_steps 500 --num_steps 700 --learning_rate 0.03 --batch_split 1 --img_size 384

To train a ViT model (ViT-S_16) on the Waterbirds dataset, run the following command:

python train.py --name waterbirds_exp --model_arch ViT --model_type ViT-S_16 --dataset waterbirds --warmup_steps 100 --num_steps 700 --learning_rate 0.03 --batch_split 1 --img_size 384

Similarly, a sample command to train a BiT model on the Waterbirds dataset:

python train.py --name waterbirds_exp --model_arch BiT --model_type BiT-M-R50x1 --dataset waterbirds --learning_rate 0.003 --batch_split 1 --img_size 384

Notes for some of the arguments:

  • --name: Name to identify the checkpoints. Users are welcome to use other names for convenience.
  • --model_arch: Model architecture to be used for training. Specify ViT for Vision Transformers or BiT for Big Transfer models.
  • --model_type: Model variant to be used for training. See the table below.
  • --warmup_steps: Number of warmup steps used when training ViT models. This is set to 500 for all ViT models.
  • --num_steps: Total number of global steps used when training ViT models. This is set to 1000 for ViT-S_16 and ViT-Ti_16, and to 2000 for ViT-B_16.
  • --batch_split: The default batch size is 512. If GPU memory is insufficient, increase batch_split so that each batch is processed in smaller chunks.

Model          model_arch   model_type     #params
ViT-B/16       ViT          ViT-B_16       86.1 M
ViT-S/16       ViT          ViT-S_16       21.8 M
ViT-Ti/16      ViT          ViT-Ti_16      5.6 M
BiT-M-R50x3    BiT          BiT-M-R50x3    211 M
BiT-M-R101x1   BiT          BiT-M-R101x1   42.5 M
BiT-M-R50x1    BiT          BiT-M-R50x1    23.5 M
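The --warmup_steps, --num_steps and --batch_split arguments interact with the learning-rate schedule and the effective batch size. The sketch below assumes the default schedule of ViT-pytorch (from which parts of this codebase are adapted): a linear warmup to --learning_rate over --warmup_steps, then cosine decay to zero at --num_steps. It also assumes that batch_split divides each batch of 512 into equal chunks whose gradients are accumulated before one optimizer step, as in big_transfer. Check train.py for the exact behaviour; the function names are hypothetical.

```python
import math


def lr_at_step(step, base_lr=0.03, warmup_steps=500, num_steps=1000):
    """Assumed warmup-cosine schedule: linear ramp, then cosine decay to 0."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, num_steps - warmup_steps)
    return base_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


def micro_batch_size(batch_size=512, batch_split=1):
    """Examples per forward/backward pass when gradients are accumulated
    over `batch_split` chunks before each optimizer step (assumed semantics)."""
    if batch_size % batch_split != 0:
        raise ValueError("batch_size must be divisible by batch_split")
    return batch_size // batch_split
```

Under these assumptions, --batch_split 4 keeps the effective batch at 512 while each pass holds only 128 images in GPU memory, at the cost of four passes per optimizer step.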

To compute accuracy metrics, including worst-group accuracy, for a ViT model (ViT-S_16) on the train and test data, run the following command:

python evaluate.py --name waterbirds_exp --model_arch ViT --model_type ViT-S_16 --dataset waterbirds --batch_size 64 --img_size 384 --checkpoint_dir model_checkpoint

Notes for some of the arguments:

  • --checkpoint_dir: Model checkpoint fine-tuned on Waterbirds to use for inference. If not provided, the script automatically searches for the model checkpoint in the output/[name]/[model_arch]/[model_type] directory.
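Worst-group accuracy is the lowest accuracy among the (label, spurious attribute) groups, so it exposes failures on the minority groups that average accuracy hides. A minimal sketch of the metric follows; it illustrates the definition and is not the code in evaluate.py, and the function names are hypothetical.

```python
from collections import defaultdict


def group_accuracies(preds, labels, groups):
    """Per-group accuracy: fraction of correct predictions within each group."""
    correct, total = defaultdict(int), defaultdict(int)
    for p, y, g in zip(preds, labels, groups):
        total[g] += 1
        correct[g] += int(p == y)
    return {g: correct[g] / total[g] for g in total}


def worst_group_accuracy(preds, labels, groups):
    """Minimum accuracy over all groups present in the evaluation data."""
    return min(group_accuracies(preds, labels, groups).values())


def average_accuracy(preds, labels):
    """Plain accuracy over all samples, ignoring group membership."""
    return sum(int(p == y) for p, y in zip(preds, labels)) / len(labels)
```

For instance, a model that is correct on three of four samples has 75% average accuracy, yet its worst-group accuracy is 0% if the one error is the only member of its group.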

To compute the consistency measure, first download the evaluation dataset from here and place the images in the [root_dir]/datasets/waterbird_bg directory. Then, for a ViT model (ViT-S_16), run the following command:

python waterbirds_consistency.py --name waterbirds_exp --model_arch ViT --model_type ViT-S_16 --checkpoint_dir model_checkpoint --batch_size 32

Spurious OOD evaluation

To generate the OOD dataset, run datasets/generate_placebg.py, which subsamples background images of specific types as the OOD data. Running python generate_placebg.py stores the generated OOD dataset in datasets/ood_datasets/placesbg/. Note: before generating the OOD dataset, download the CUB and Places datasets and update their paths in the script.
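The general idea of subsampling background images by category can be sketched as follows. This is an illustration under assumptions, not the logic of generate_placebg.py: the source layout (one sub-folder per category), the category names, the per-category count and the output naming are all hypothetical parameters.

```python
import os
import random
import shutil

IMAGE_EXTS = (".jpg", ".jpeg", ".png")


def choose_subset(files_by_category, n_per_category, seed=0):
    """Pick up to n_per_category file names per category, reproducibly."""
    rng = random.Random(seed)
    chosen = {}
    for category in sorted(files_by_category):
        files = sorted(files_by_category[category])
        chosen[category] = sorted(rng.sample(files, min(n_per_category, len(files))))
    return chosen


def build_ood_folder(src_root, categories, out_dir, n_per_category, seed=0):
    """Copy the chosen background images into a single flat OOD folder."""
    files_by_category = {
        c: [f for f in os.listdir(os.path.join(src_root, c)) if f.lower().endswith(IMAGE_EXTS)]
        for c in categories
    }
    os.makedirs(out_dir, exist_ok=True)
    for category, names in choose_subset(files_by_category, n_per_category, seed).items():
        prefix = category.replace("/", "_")  # categories may be nested paths
        for name in names:
            shutil.copyfile(os.path.join(src_root, category, name),
                            os.path.join(out_dir, f"{prefix}_{name}"))
```

Fixing the random seed keeps the OOD set identical across runs, so evaluation numbers stay comparable between models.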

To obtain spurious OOD evaluation results for a ViT model (ViT-S_16), run the following command:

python ood_eval.py --name waterbirds_exp --model_arch ViT --model_type ViT-S_16 --id_dataset waterbirds --batch_size 64 --img_size 384 --checkpoint_dir model_checkpoint
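Spurious_OOD, from which parts of this codebase are adapted, reports OOD detection with FPR at 95% TPR (FPR95) and AUROC. Assuming ood_eval.py reports similar metrics, the sketch below shows how both are computed from per-sample scores where a higher score means "more in-distribution" (for example, a maximum softmax probability; the choice of score is an assumption here). It is an illustrative implementation, not the script's code.

```python
def fpr_at_95_tpr(id_scores, ood_scores):
    """FPR on OOD data at the threshold that keeps at least 95% of ID data.

    Convention assumed here: a higher score means "more in-distribution".
    """
    sorted_id = sorted(id_scores)
    k = len(sorted_id) * 5 // 100  # number of ID samples we may reject
    threshold = sorted_id[k]       # accept every score >= threshold
    return sum(s >= threshold for s in ood_scores) / len(ood_scores)


def auroc(id_scores, ood_scores):
    """Probability that a random ID sample outscores a random OOD sample
    (ties count as half). Pairwise O(n * m), fine for a sketch."""
    wins = 0.0
    for a in id_scores:
        for b in ood_scores:
            wins += 1.0 if a > b else (0.5 if a == b else 0.0)
    return wins / (len(id_scores) * len(ood_scores))
```

Lower FPR95 and higher AUROC indicate that the model separates in-distribution inputs from spurious OOD inputs more reliably.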

References

Some parts of the codebase are adapted from GDRO, Spurious_OOD, big_transfer and ViT-pytorch.

BibTeX citation

@misc{ghosal2022vision,
      title={Are Vision Transformers Robust to Spurious Correlations?}, 
      author={Soumya Suvra Ghosal and Yifei Ming and Yixuan Li},
      year={2022},
      eprint={2203.09125},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}