
Primary language: Python. License: MIT.

Attention Maps: A Solution to AI Medical Imaging Interpretability?

This repository contains the implementation of the project "Are attention maps a solution to the interpretability crisis in AI-assisted medical imaging?". In this project, we investigated the reliability of the attention maps of the Vision Transformer (ViT) (Dosovitskiy et al., 2020) as an interpretability method in the context of medical image diagnosis. This repository provides Python code to reproduce the interpretability results for attention maps and for two other interpretability methods: Grad-CAM (Selvaraju et al., 2017) and the method of Chefer et al. (2021).
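To make the comparison concrete, here is a minimal pure-Python sketch of one common way to read a raw attention map off a ViT, next to a Grad-CAM map computed on its patch tokens. It is an illustration, not the repository's implementation: it assumes the attention weights of one layer (shape heads × tokens × tokens, with the [CLS] token first) and the patch-token activations and gradients of one layer have already been extracted, for example with forward/backward hooks. The function names `cls_attention_map` and `grad_cam_tokens` are hypothetical, and which layers this project actually uses is not stated here.

```python
import math


def cls_attention_map(attn):
    """Raw attention map from one ViT layer (hypothetical helper).

    attn: nested list of shape [heads][tokens][tokens], token 0 = [CLS].
    Returns a square grid (list of rows) holding the head-averaged attention
    that the [CLS] query pays to each patch token, normalized to sum to 1.
    """
    heads, n_tokens = len(attn), len(attn[0])
    # Row 0 is the [CLS] query; skip column 0 ([CLS] attending to itself).
    cls_row = [sum(attn[h][0][j] for h in range(heads)) / heads
               for j in range(1, n_tokens)]
    total = sum(cls_row)
    cls_row = [v / total for v in cls_row]
    side = math.isqrt(len(cls_row))  # e.g. 196 patch tokens -> 14 x 14 grid
    if side * side != len(cls_row):
        raise ValueError("patch tokens do not form a square grid")
    return [cls_row[r * side:(r + 1) * side] for r in range(side)]


def grad_cam_tokens(activations, gradients):
    """Grad-CAM over ViT patch tokens for one layer and one target class.

    activations, gradients: [patches][channels], with the [CLS] token removed.
    Channel weights are the gradients averaged over patches; the map is the
    ReLU of the weighted sum of activation channels at each patch.
    """
    n_patches, n_channels = len(activations), len(activations[0])
    weights = [sum(gradients[p][k] for p in range(n_patches)) / n_patches
               for k in range(n_channels)]
    cam = [max(0.0, sum(weights[k] * activations[p][k] for k in range(n_channels)))
           for p in range(n_patches)]
    side = math.isqrt(n_patches)
    if side * side != n_patches:
        raise ValueError("patch tokens do not form a square grid")
    return [cam[r * side:(r + 1) * side] for r in range(side)]
```

Averaging the [CLS] row over heads is only one aggregation choice; the attention grid sums to 1, whereas the Grad-CAM grid is unnormalized, so both are usually rescaled before being overlaid on the image.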


Requirements

We recommend using a Conda virtual environment.
You can create it and install the required dependencies with:

conda env create -f requirements.yaml

Code Usage & Reproduction of the Results

Run main.py to generate the interpretability maps and the final evaluation results for all datasets and interpretability methods.

The structure of this repository is as follows:

  • Code: the src/ folder contains the Python files needed to generate interpretability maps (including attention maps) and to evaluate the different interpretability methods.

  • Data: the data/ folder contains only placeholder directories, without any image data. Follow the instructions to download the medical imaging datasets we used.

  • Checkpoints: the weights/ folder is also a placeholder directory. Put the .pth files of our selected trained models and the MAE and DINO pre-trained ViT weights in this directory.

Training

To evaluate the interpretability methods, ViT-B/16 models are trained to reach state-of-the-art benchmark performance on each medical imaging dataset. Four initialization methods are used: random (i.e., training from scratch), supervised, DINO, and MAE. All except the random initialization start from weights pretrained on ImageNet.
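As a worked example of what ViT-B/16 implies for the attention maps, the snippet below computes the token layout for a square input. It assumes a 224 × 224 input resolution (the usual ImageNet setting; the resolution used in this project is not stated here) and the ViT-Base configuration from Dosovitskiy et al. (2020): 12 layers, 12 attention heads, hidden size 768. `vit_token_layout` is a hypothetical helper.

```python
# ViT-Base configuration from Dosovitskiy et al. (2020).
VIT_B = {"layers": 12, "heads": 12, "hidden": 768}


def vit_token_layout(image_size=224, patch_size=16):
    """Return (patch-grid side, token count including [CLS]) for a square input."""
    if image_size % patch_size:
        raise ValueError("image_size must be divisible by patch_size")
    side = image_size // patch_size  # patches per row and per column
    n_tokens = side * side + 1       # all patch tokens plus one [CLS] token
    return side, n_tokens


side, n_tokens = vit_token_layout()
# Per-layer attention tensor for one image: heads x query tokens x key tokens.
attn_shape = (VIT_B["heads"], n_tokens, n_tokens)
print(side, n_tokens, attn_shape)  # 14 197 (12, 197, 197)
```

Under this assumption each layer yields a 12 × 197 × 197 attention tensor, and a per-patch map is a 14 × 14 grid that is typically upsampled to the input resolution for display.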

Hyperparameter Settings

The table below lists the hyperparameters used to train each model in this project. Saved Epoch is the epoch at which the model was selected, i.e., where early stopping based on validation loss occurred. You can use these settings to reproduce our training results.

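| Dataset | Model init. | Optimizer | Learning Rate | Momentum | Weight Decay | Scheduler | Total Epochs | Saved Epoch | Batch Size |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| CP-Child | Random | Adam | 0.00001 | 0.9 | 0.0001 | CosineAnnealing | 20 | 18 | 64 |
| CP-Child | Sup | Adam | 0.00001 | 0.9 | 0.0001 | CosineAnnealing | 20 | 5 | 64 |
| CP-Child | DINO | Adam | 0.00001 | 0.9 | 0.0001 | CosineAnnealing | 20 | 12 | 64 |
| CP-Child | MAE | Adam | 0.00001 | 0.9 | 0.0001 | CosineAnnealing | 20 | 11 | 64 |
| DUKE | Random | SGD | 0.1 | 0.3 | 0 | CosineAnnealing | 200 | 79 | 64 |
| DUKE | Sup | SGD | 0.03 | 0.1 | 0 | CosineAnnealing | 20 | 2 | 64 |
| DUKE | DINO | SGD | 0.005 | 0.1 | 0 | CosineAnnealing | 15 | 5 | 64 |
| DUKE | MAE | SGD | 0.03 | 0.1 | 0 | CosineAnnealing | 15 | 2 | 64 |
| Kvasir | Random | SGD | 0.08 | 0.1 | 0 | CosineAnnealing | 25 | 24 | 32 |
| Kvasir | Sup | SGD | 0.0018 | 0.1 | 0 | CosineAnnealing | 25 | 8 | 32 |
| Kvasir | DINO | SGD | 0.0012 | 0.1 | 0.0001 | CosineAnnealing | 25 | 20 | 32 |
| Kvasir | MAE | SGD | 0.02 | 0.1 | 0.0001 | CosineAnnealing | 25 | 25 | 32 |
| MURA | Random | SGD | 0.001 | 0.9 | 0.0001 | CosineAnnealing | 200 | 180 | 64 |
| MURA | Sup | SGD | 0.001 | 0.9 | 0.0001 | CosineAnnealing | 100 | 9 | 64 |
| MURA | DINO | SGD | 0.001 | 0.9 | 0.0001 | CosineAnnealing | 150 | 43 | 64 |
| MURA | MAE | SGD | 0.001 | 0.9 | 0.0001 | CosineAnnealing | 100 | 35 | 64 |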
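Every model in the hyperparameter table uses a cosine-annealing scheduler, and Saved Epoch is chosen by validation loss. The sketch below illustrates both under stated assumptions: PyTorch-style `CosineAnnealingLR` semantics with `eta_min = 0` and `T_max` equal to Total Epochs (neither is stated here), and 1-based epoch numbering (consistent with Kvasir/MAE saving epoch 25 of 25). `select_saved_epoch` and its `patience` argument are hypothetical; the patience used for early stopping is not given in the table.

```python
import math


def cosine_annealing_lr(base_lr, epoch, t_max, eta_min=0.0):
    """Learning rate after `epoch` scheduler steps of cosine annealing (no restarts).

    Closed form given in the PyTorch CosineAnnealingLR documentation:
    eta_min + (base_lr - eta_min) * (1 + cos(pi * epoch / t_max)) / 2
    """
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2


def select_saved_epoch(val_losses, patience=None):
    """Return the 1-based epoch with the lowest validation loss.

    With `patience` set (hypothetical value), scanning stops once the loss has
    failed to improve for `patience` consecutive epochs, mimicking early stopping.
    """
    best_epoch, best_loss, since_best = None, math.inf, 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            best_epoch, best_loss, since_best = epoch, loss, 0
        else:
            since_best += 1
            if patience is not None and since_best >= patience:
                break
    return best_epoch


# DUKE / Random row: base LR 0.1 annealed over 200 total epochs, saved at epoch 79.
for epoch in (0, 79, 200):
    print(epoch, cosine_annealing_lr(0.1, epoch, t_max=200))
```

The learning rate starts at the table value, halves at the midpoint of training, and decays to `eta_min` at Total Epochs; the saved checkpoint is whichever epoch minimized validation loss within that run.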

Checkpoints

Alternatively, you can reproduce our interpretability results without training: download the checkpoints below, put them in the weights/ folder, and run main.py to load them.
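Before running main.py, it can help to confirm that the downloaded files actually landed in weights/. A minimal standard-library sketch, assuming the checkpoints sit directly in weights/ with a .pth extension; `find_checkpoints` is a hypothetical helper, and since the file names main.py expects are not documented here, it only lists what is present.

```python
from pathlib import Path


def find_checkpoints(weights_dir="weights"):
    """Return the sorted names of .pth files placed directly in `weights_dir`."""
    root = Path(weights_dir)
    if not root.is_dir():
        raise FileNotFoundError(f"{root} not found; run this from the repository root")
    # Non-recursive: only files directly inside weights/ are listed.
    return sorted(p.name for p in root.glob("*.pth"))
```

For example, `print(find_checkpoints())` from the repository root prints the checkpoint names found; an empty list means no .pth file has been copied into weights/ yet.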

| Dataset / Method | Random | Supervised | DINO | MAE |
| --- | --- | --- | --- | --- |
| CP-Child | Download | Download | Download | Download |
| DUKE | Download | Download | Download | Download |
| Kvasir | Download | Download | Download | Download |
| MURA | Download | Download | Download | Download |

References

Chefer, H., Gur, S., & Wolf, L. (2021). Transformer Interpretability Beyond Attention Visualization. Proceedings of the IEEE Computer Society Conference on Computer Vision and Pattern Recognition, 782–791. https://doi.org/10.1109/CVPR46437.2021.00084

Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., Dehghani, M., Minderer, M., Heigold, G., Gelly, S., Uszkoreit, J., & Houlsby, N. (2020). An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. arXiv. https://doi.org/10.48550/arXiv.2010.11929

Selvaraju, R. R., Cogswell, M., Das, A., Vedantam, R., Parikh, D., & Batra, D. (2017). Grad-CAM: Visual Explanations from Deep Networks via Gradient-Based Localization. Proceedings of the IEEE International Conference on Computer Vision, 618–626. https://doi.org/10.1109/ICCV.2017.74