
Code and datasets for the paper "A Whac-A-Mole Dilemma: Shortcuts Come in Multiples Where Mitigating One Amplifies Others" (CVPR 2023)


A Whac-A-Mole Dilemma: Shortcuts Come in Multiples Where Mitigating One Amplifies Others (CVPR 2023)

Zhiheng Li, Ivan Evtimov, Albert Gordo, Caner Hazirbas, Tal Hassner, Cristian Canton Ferrer, Chenliang Xu, Mark Ibrahim

[paper]


TL;DR: Our benchmark results on UrbanCars and ImageNet reveal the overlooked Whac-A-Mole dilemma in shortcut mitigation: mitigating one shortcut amplifies reliance on other shortcuts.

ImageNet-W

We discover a new watermark shortcut in ImageNet and create the ImageNet-W test set to study (1) state-of-the-art vision models' reliance on the watermark shortcut and (2) the reliance on multiple shortcuts in ImageNet when ImageNet-W is used together with other out-of-distribution variants of ImageNet (e.g., ImageNet-R).

Use ImageNet-W

  • Install the ImageNet-W package:
pip install imagenet-w
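  • Use the AddWatermark transform on ImageNet:
from torchvision import transforms
from torchvision.datasets import ImageNet
from imagenet_w import AddWatermark

root = "/path/to/imagenet"  # placeholder: your ImageNet root directory
resize_size = 256
crop_size = 224

# standard ImageNet mean/std normalization, applied after the watermark is added
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

test_transform = transforms.Compose(
    [
        transforms.Resize(resize_size),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        AddWatermark(crop_size),  # insert AddWatermark before normalize
        normalize,
    ]
)

imagenet_w = ImageNet(root, split="val", transform=test_transform)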
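The IN-W Gap column in the results table can be computed from two sets of predictions on the same validation images, one without and one with the watermark. The sketch below assumes the gap is the top-1 accuracy on the watermarked images minus the top-1 accuracy on the clean images, in percentage points. The helper names top1_accuracy and in_w_gap are hypothetical; this is not the repo's eval_shortcuts code.

```python
def top1_accuracy(preds, labels):
    """Top-1 accuracy in percent, given predicted and ground-truth class indices."""
    if len(preds) != len(labels) or not labels:
        raise ValueError("preds and labels must be non-empty and of equal length")
    return 100.0 * sum(p == y for p, y in zip(preds, labels)) / len(labels)


def in_w_gap(preds_clean, preds_watermarked, labels):
    """Assumed IN-W gap: accuracy with AddWatermark minus accuracy without it.

    A large negative value means many predictions change once a watermark
    is overlaid, which suggests reliance on the watermark shortcut.
    """
    return top1_accuracy(preds_watermarked, labels) - top1_accuracy(preds_clean, labels)


if __name__ == "__main__":
    labels = [0, 1, 2, 3]
    print(in_w_gap([0, 1, 2, 3], [0, 1, 9, 9], labels))  # -50.0
```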

Requirements

pip install -r requirements.txt

UrbanCars Experiments

We construct UrbanCars, a new dataset with multiple shortcuts (i.e., background and co-occurring object), which facilitates the study of multi-shortcut learning in a controlled setting.
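To make the multi-shortcut setup concrete, the sketch below scores a classifier on UrbanCars-style data by splitting test images according to whether the background and the co-occurring object agree with the car label, then reporting each split's accuracy relative to the split where both agree. It assumes a hypothetical binary encoding (0 = urban, 1 = country) for all three attributes; split_name and shortcut_gaps are illustrative helpers, not the repo's evaluation code.

```python
from collections import defaultdict


def split_name(car, bg, coobj):
    """Name the test split by which shortcut attributes disagree with the car label."""
    bg_flip, co_flip = bg != car, coobj != car
    if bg_flip and co_flip:
        return "bg+coobj"
    if bg_flip:
        return "bg"
    if co_flip:
        return "coobj"
    return "id"  # both shortcuts agree with the label (in-distribution)


def shortcut_gaps(records):
    """records: iterable of (car, bg, coobj, pred) tuples with binary attributes.

    Returns per-split accuracy (%) and each split's accuracy minus the "id" accuracy.
    """
    hits = defaultdict(list)
    for car, bg, coobj, pred in records:
        hits[split_name(car, bg, coobj)].append(pred == car)
    acc = {name: 100.0 * sum(v) / len(v) for name, v in hits.items()}
    gaps = {
        f"{name}_gap": acc[name] - acc["id"]
        for name in ("bg", "coobj", "bg+coobj")
        if name in acc and "id" in acc
    }
    return acc, gaps
```

A model that ignores both shortcuts would show gaps near zero; a negative gap on one split points to reliance on the corresponding shortcut.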

Generate UrbanCars Dataset

bash scripts/prepare_dataset_models/create_urbancars.sh

Train Shortcut Mitigation Methods on UrbanCars

Use the shell scripts in scripts/train_urbancars to run each method, e.g.:

bash scripts/train_urbancars/$METHOD.sh

where $METHOD should be replaced by one of the method names listed in scripts/train_urbancars.
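To run every UrbanCars method in one go, a small driver can enumerate the scripts. The sketch below assumes each method corresponds to one $METHOD.sh file directly inside scripts/train_urbancars and that it is run from the repository root; method_names, list_methods, train_command, and run_all are hypothetical helpers. By default it only prints the commands (dry run).

```python
import subprocess
from pathlib import Path

SCRIPT_DIR = Path("scripts/train_urbancars")


def method_names(filenames):
    """Method names are the stems of the *.sh files, e.g. 'x.sh' -> 'x'."""
    return sorted(Path(f).stem for f in filenames if f.endswith(".sh"))


def list_methods(script_dir=SCRIPT_DIR):
    """Scan script_dir for training scripts (assumes one script per method)."""
    return method_names(p.name for p in Path(script_dir).iterdir() if p.is_file())


def train_command(method, script_dir=SCRIPT_DIR):
    """Equivalent of: bash scripts/train_urbancars/$METHOD.sh"""
    return ["bash", str(Path(script_dir) / f"{method}.sh")]


def run_all(script_dir=SCRIPT_DIR, dry_run=True):
    for method in list_methods(script_dir):
        cmd = train_command(method, script_dir)
        print(" ".join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # raises CalledProcessError if a script fails
```

Calling run_all(dry_run=False) executes the scripts sequentially; on a cluster you would more likely submit each command as a separate job.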


ImageNet Experiments

Prepare ImageNet and its out-of-distribution variants

See prepare_ImageNet.md

Prepare models for evaluating shortcut reliance

See prepare_checkpoints_for_eval.md

Evaluate state-of-the-art vision models' watermark shortcut reliance

PYTHONPATH=.:$PYTHONPATH python eval_shortcuts/eval_watermark_shortcut.py

Evaluate reliance on multiple shortcuts

PYTHONPATH=.:$PYTHONPATH python eval_shortcuts/eval_multiple_shortcuts.py

Training

We use last-layer retraining for the ImageNet experiments.
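Last-layer retraining keeps the pretrained feature extractor frozen and trains only the final linear classifier on its features. The toy, pure-Python sketch below illustrates that idea with softmax regression trained by SGD on precomputed feature vectors. It is an illustrative stand-in under these assumptions (hypothetical function names, tiny hand-made features), not the repo's imagenet_trainers code.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def logits_for(W, b, x):
    """Linear head: one logit per class, W[c] . x + b[c]."""
    return [sum(w * xi for w, xi in zip(row, x)) + bc for row, bc in zip(W, b)]


def train_last_layer(features, labels, num_classes, lr=0.1, epochs=50, seed=0):
    """Fit only a linear classifier on frozen features with SGD on cross-entropy.

    `features` stand in for embeddings precomputed by a frozen backbone;
    the backbone itself is never updated here.
    """
    rng = random.Random(seed)
    dim = len(features[0])
    W = [[0.0] * dim for _ in range(num_classes)]
    b = [0.0] * num_classes
    order = list(range(len(features)))
    for _ in range(epochs):
        rng.shuffle(order)
        for i in order:
            x, y = features[i], labels[i]
            probs = softmax(logits_for(W, b, x))
            for c in range(num_classes):
                grad = probs[c] - (1.0 if c == y else 0.0)  # dCE/dlogit_c
                for j in range(dim):
                    W[c][j] -= lr * grad * x[j]
                b[c] -= lr * grad
    return W, b


def predict(W, b, x):
    """Index of the largest logit."""
    logits = logits_for(W, b, x)
    return max(range(len(logits)), key=logits.__getitem__)
```

Because only W and b are learned, retraining is cheap enough to repeat for many methods and learning rates on top of the same frozen backbone.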

PYTHONPATH=.:$PYTHONPATH python imagenet_trainers/launcher.py --method ${METHOD} --amp --feature_extractor resnet50_erm --lr ${LR} [--wandb] [--slurm_partition ${SLURM_PARTITION}] [--slurm_job_name ${METHOD}_imagenet]
  • ${METHOD} is the method name (run PYTHONPATH=.:$PYTHONPATH python imagenet_trainers/launcher.py --help to list the options for --method).

  • [Optional] Use --slurm_partition ${SLURM_PARTITION} and --slurm_job_name ${METHOD}_imagenet when training on a Slurm cluster; jobs are submitted with submitit.

  • [Optional] Add --wandb to log with wandb.

  • --feature_extractor resnet50_erm uses a ResNet-50 trained with ERM as the feature extractor. For our proposed Last Layer Ensemble (LLE) method, we also use vit-b_mae-ft, vit-l_mae-ft, and vit-h_mae-ft (ViT architectures with fine-tuned MAE), as well as vit-b_swag-ft (ViT-B with fine-tuned SWAG).

  • ${LR} is the learning rate. We tune learning rates based on IN-1k top-1 accuracy. The tuned learning rates and the corresponding checkpoints are listed in the following table:

| Method | Architecture | IN-1k | IN-W Gap | Carton Gap | SIN Gap | IN-R Gap | IN-9 Gap | LR | Download |
|---|---|---|---|---|---|---|---|---|---|
| ERM | ResNet-50 | 76.39 | -25.40 | +30 | -69.43 | -56.22 | -5.19 | 1e-3 | model |
| Mixup | ResNet-50 | 76.17 | -24.87 | +34 | -68.18 | -55.79 | -5.60 | 1e-4 | model |
| CutMix | ResNet-50 | 75.90 | -25.78 | +32 | -69.31 | -56.36 | -5.65 | 1e-4 | model |
| Cutout | ResNet-50 | 76.40 | -25.11 | +32 | -69.39 | -55.93 | -5.35 | 1e-3 | model |
| AugMix | ResNet-50 | 76.23 | -23.41 | +38 | -68.51 | -54.91 | -5.85 | 1e-4 | model |
| SD | ResNet-50 | 76.39 | -26.03 | +30 | -69.42 | -56.36 | -5.33 | 1e-3 | model |
| WTM Aug | ResNet-50 | 76.32 | -5.78 | +14 | -69.31 | -56.22 | -5.34 | 1e-3 | model |
| TXT Aug | ResNet-50 | 75.94 | -25.93 | +36 | -63.99 | -53.24 | -5.66 | 1e-4 | model |
| BG Aug | ResNet-50 | 76.03 | -25.01 | +36 | -68.41 | -54.51 | -4.67 | 1e-4 | model |
| LfF | ResNet-50 | 76.35 | -26.19 | +36 | -69.34 | -56.02 | -5.61 | 1e-4 | model |
| JTT | ResNet-50 | 76.33 | -26.40 | +32 | -69.48 | -56.30 | -5.55 | 1e-2 | model |
| EIIL | ResNet-50 | 71.51 | -33.17 | +24 | -65.93 | -61.09 | -6.27 | 1e-4 | model |
| DebiAN | ResNet-50 | 76.33 | -26.40 | +36 | -69.37 | -56.29 | -5.53 | 1e-4 | model |
| LLE (ours) | ResNet-50 | 76.25 | -6.18 | +10 | -61.00 | -54.89 | -3.82 | 1e-3 | model |
| MAE + LLE (ours) | ViT-B | 83.68 | -2.48 | +6 | -58.78 | -44.96 | -3.70 | 1e-3 | model |
| MAE + LLE (ours) | ViT-L | 85.84 | -1.74 | +12 | -56.32 | -34.64 | -2.77 | 1e-3 | model |
| MAE + LLE (ours) | ViT-H | 86.84 | -1.11 | +28 | -55.69 | -30.95 | -2.35 | 1e-3 | model |
| SWAG + LLE (ours) | ViT-B | 85.37 | -2.50 | +8 | -60.92 | -28.37 | -3.19 | 1e-4 | model |
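The Whac-A-Mole pattern can be read directly off the table by comparing each single-shortcut augmentation against ERM. The sketch below copies four ResNet-50 rows from the table and lists the metrics that get worse relative to ERM. It assumes that for the IN-W, SIN, IN-R, and IN-9 gaps a value closer to zero is better, and that for the Carton Gap a smaller value is better; changes_vs_erm and amplified are hypothetical names.

```python
# Gap values copied from the table above (ResNet-50 rows).
GAPS = {
    "ERM":     {"IN-W": -25.40, "Carton": 30, "SIN": -69.43, "IN-R": -56.22, "IN-9": -5.19},
    "WTM Aug": {"IN-W": -5.78,  "Carton": 14, "SIN": -69.31, "IN-R": -56.22, "IN-9": -5.34},
    "TXT Aug": {"IN-W": -25.93, "Carton": 36, "SIN": -63.99, "IN-R": -53.24, "IN-9": -5.66},
    "BG Aug":  {"IN-W": -25.01, "Carton": 36, "SIN": -68.41, "IN-R": -54.51, "IN-9": -4.67},
}

# Assumption: higher (closer to zero) is better for accuracy gaps,
# lower is better for the Carton Gap.
HIGHER_IS_BETTER = {"IN-W": True, "Carton": False, "SIN": True, "IN-R": True, "IN-9": True}


def changes_vs_erm(method):
    """Improvement over ERM per metric (positive = better), rounded to 2 decimals."""
    base, row = GAPS["ERM"], GAPS[method]
    out = {}
    for metric, higher_better in HIGHER_IS_BETTER.items():
        delta = row[metric] - base[metric]
        out[metric] = round(delta if higher_better else -delta, 2)
    return out


def amplified(method):
    """Metrics where the method is worse than ERM, i.e. a shortcut got amplified."""
    return [m for m, d in changes_vs_erm(method).items() if d < 0]


if __name__ == "__main__":
    for method in ("WTM Aug", "TXT Aug", "BG Aug"):
        print(method, amplified(method))
```

Under these assumptions, TXT Aug narrows the SIN and IN-R gaps but widens the IN-W, Carton, and IN-9 gaps, while WTM Aug closes most of the IN-W gap and the IN-9 gap grows slightly.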

Our proposed Last Layer Ensemble (LLE) method also uses edge detection for data augmentation (Edge Aug). Details on generating the edge-detection data for ImageNet, along with the corresponding checkpoints, are in Edge_Aug.md.
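Edge Aug augments training data with edge-detection outputs. As a rough illustration of what an edge map is, the sketch below computes a Sobel gradient magnitude for a small grayscale image in pure Python. The Sobel operator is a stand-in chosen for simplicity, and sobel_edges is a hypothetical helper; the edge detector actually used is described in Edge_Aug.md.

```python
import math

# 3x3 Sobel kernels for horizontal (x) and vertical (y) intensity changes.
KX = [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]
KY = [[-1, -2, -1], [0, 0, 0], [1, 2, 1]]


def sobel_edges(img):
    """Sobel gradient magnitude of a 2-D grayscale image given as a list of rows.

    Kernels are applied as cross-correlation (no flip); this only changes the
    sign of gx/gy, not the magnitude. Border pixels are left at 0.
    """
    h, w = len(img), len(img[0])
    out = [[0.0] * w for _ in range(h)]
    for i in range(1, h - 1):
        for j in range(1, w - 1):
            gx = gy = 0.0
            for di in range(3):
                for dj in range(3):
                    p = img[i + di - 1][j + dj - 1]
                    gx += KX[di][dj] * p
                    gy += KY[di][dj] * p
            out[i][j] = math.hypot(gx, gy)
    return out


if __name__ == "__main__":
    step = [[0, 0, 1, 1, 1] for _ in range(5)]  # vertical edge between columns 1 and 2
    for row in sobel_edges(step):
        print(row)
```

Flat regions map to 0 and the vertical step lights up in the two columns next to it, which is the kind of shape-only signal an edge-based augmentation keeps while discarding texture and color.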

Evaluation

To evaluate the trained models, download the checkpoint from the table above and use its file path as ${PATH_TO_CHECKPOINT}:

PYTHONPATH=.:$PYTHONPATH python imagenet_trainers/launcher.py --method ${METHOD} --amp --feature_extractor resnet50_erm [--wandb] [--slurm_partition ${SLURM_PARTITION}] [--slurm_job_name ${METHOD}_imagenet] --evaluate --resume ${PATH_TO_CHECKPOINT}
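The training and evaluation commands differ only in a few flags, so a small wrapper can assemble either one. The sketch below builds the argument list from the flags shown in this README and prepends . to PYTHONPATH, mirroring PYTHONPATH=.:$PYTHONPATH. launcher_command and launcher_env are hypothetical helpers, and method names such as "erm" are placeholders; the accepted values are listed by launcher.py --help.

```python
import os


def launcher_command(method, feature_extractor="resnet50_erm", lr=None, wandb=False,
                     slurm_partition=None, resume=None):
    """Argument list for imagenet_trainers/launcher.py using only the README's flags."""
    cmd = ["python", "imagenet_trainers/launcher.py", "--method", method, "--amp",
           "--feature_extractor", feature_extractor]
    if lr is not None:
        cmd += ["--lr", str(lr)]
    if wandb:
        cmd.append("--wandb")
    if slurm_partition:
        cmd += ["--slurm_partition", slurm_partition,
                "--slurm_job_name", f"{method}_imagenet"]
    if resume:
        # evaluation mode: load a downloaded checkpoint instead of training
        cmd += ["--evaluate", "--resume", resume]
    return cmd


def launcher_env(base=None):
    """Copy of the environment with "." prepended to PYTHONPATH."""
    env = dict(os.environ if base is None else base)
    existing = env.get("PYTHONPATH")
    env["PYTHONPATH"] = "." + os.pathsep + existing if existing else "."
    return env
```

For example, subprocess.run(launcher_command("erm", resume="/path/to/checkpoint.pth"), env=launcher_env(), check=True) would run evaluation from the repository root.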

Citation

If you use the UrbanCars or ImageNet-W dataset, or compare with our proposed Last Layer Ensemble (LLE) method, please cite our paper:

@InProceedings{Li_2023_CVPR_Whac_A_Mole,
    author    = {Li, Zhiheng and Evtimov, Ivan and Gordo, Albert and Hazirbas, Caner and Hassner, Tal and Ferrer, Cristian Canton and Xu, Chenliang and Ibrahim, Mark},
    title     = {A Whac-a-Mole Dilemma: Shortcuts Come in Multiples Where Mitigating One Amplifies Others},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2023},
    pages     = {20071-20082}
}

License

See LICENSE for details.

Attribution

The Whack-A-Mole icon was created by Flat Icons - Flaticon.