
Transfer learning pretrained vision transformers for breast histopathology


Visualizing Transformers for Breast Histopathology

This repository contains code for Visualizing Transformers for Breast Histopathology. This work was completed as part of CPSC 482: Current Topics in Applied Machine Learning.

Abstract

Transfer learning is a common way of achieving high performance on downstream tasks with limited data. Meanwhile, the success of vision transformers has opened up a wider range of image model options than was previously available. In this report, we explore transfer learning for breast histopathology using state-of-the-art vision transformer models: ViT, BEiT, and CaiT. We focus on presenting model predictions and behavior in human-interpretable forms, so that a pathologist could use this information to aid their diagnosis. Through experiments, we show how attention maps and latent representations can be used to interpret model behavior.
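One widely used way to collapse a vision transformer's per-layer attention weights into a single patch-level map is attention rollout. The NumPy sketch below illustrates that technique; it is not necessarily how this repository produces its attention maps. It assumes ViT-style attentions for one image, shaped (layers, heads, tokens, tokens), with a CLS token at index 0 followed by a square grid of patch tokens. If the models come from Hugging Face transformers (an assumption), calling them with output_attentions=True returns one such (batch, heads, tokens, tokens) tensor per layer, which can be stacked into this shape. The name attention_rollout is hypothetical.

```python
import numpy as np


def attention_rollout(attentions: np.ndarray) -> np.ndarray:
    """Attention rollout over ViT-style attention weights (hypothetical helper).

    attentions: shape (layers, heads, tokens, tokens); each row is a softmax
    distribution. Token 0 is assumed to be CLS, tokens 1..N a square patch grid.
    Returns a (grid, grid) map of how strongly CLS attends to each patch,
    scaled so the maximum is 1.
    """
    num_layers, _, num_tokens, _ = attentions.shape
    identity = np.eye(num_tokens)
    rollout = identity.copy()
    for layer in range(num_layers):
        # Average over heads, then add the identity to model the residual connection.
        attn = attentions[layer].mean(axis=0) + identity
        # Re-normalize so each row is a probability distribution again.
        attn = attn / attn.sum(axis=-1, keepdims=True)
        # Compose with the attention flow accumulated from earlier layers.
        rollout = attn @ rollout
    # Row 0 is the CLS token; drop its self-entry and keep the patch columns.
    cls_to_patches = rollout[0, 1:]
    grid = int(round(np.sqrt(cls_to_patches.size)))
    if grid * grid != cls_to_patches.size:
        raise ValueError("patch tokens do not form a square grid")
    heatmap = cls_to_patches.reshape(grid, grid)
    peak = heatmap.max()
    # Guard against an all-zero map (e.g. pure self-attention) before scaling.
    return heatmap / peak if peak > 0 else heatmap
```

For a ViT-Base model on 224x224 inputs (197 tokens: CLS plus a 14x14 patch grid), the result is a 14x14 map that can be upsampled and overlaid on the input patch.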

Quickstart

  1. Clone the repository.
$ git clone https://github.com/jaketae/vit-breast-cancer.git
  2. Create a Python virtual environment, activate it, and install the package requirements.
$ cd vit-breast-cancer
$ python -m venv venv
$ source venv/bin/activate
$ pip install -r requirements.txt
  3. To train a model, run python train.py; to evaluate one, run python evaluate.py with the appropriate flags. For instance,
$ CUDA_VISIBLE_DEVICES=1 python evaluate.py --device cuda --checkpoint checkpoints/vit_freeze

Dataset

We used the Breast Histopathology Images dataset. You can either download the dataset directly from the website or use the Kaggle API command-line tool (installed with pip install kaggle) to download it from the terminal. For detailed instructions on setting up and using the Kaggle API, refer to its documentation.

$ kaggle datasets download paultimothymooney/breast-histopathology-images

Create a subfolder within the repository directory, such as raw, then unzip the dataset into it:

$ unzip breast-histopathology-images.zip -d raw
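The folder layout of the extracted archive is not described here. Assuming the Kaggle dataset stores 50x50 PNG patches named like 10253_idx5_x1351_y1101_class0.png, with the patient ID, patch coordinates, and class label (0 = IDC-negative, 1 = IDC-positive) encoded in the filename, a minimal standard-library sketch for indexing the extracted files could look like this. Patch and index_patches are hypothetical names, not part of this repository. Because the archive may contain a second copy of the patient folders, the sketch keeps only the first file seen with each name.

```python
import re
from pathlib import Path
from typing import List, NamedTuple

# Assumed naming scheme, e.g. "10253_idx5_x1351_y1101_class0.png".
PATCH_PATTERN = re.compile(
    r"^(?P<patient>\d+)_idx\d+_x(?P<x>\d+)_y(?P<y>\d+)_class(?P<label>[01])\.png$"
)


class Patch(NamedTuple):
    path: Path
    patient: str
    x: int
    y: int
    label: int  # 1 = IDC-positive, 0 = IDC-negative


def index_patches(root: str) -> List[Patch]:
    """Recursively collect patches under `root`, parsing metadata from filenames."""
    seen = set()
    patches = []
    for path in sorted(Path(root).rglob("*.png")):
        match = PATCH_PATTERN.match(path.name)
        # Skip files outside the naming scheme and duplicate copies of a patch.
        if match is None or path.name in seen:
            continue
        seen.add(path.name)
        patches.append(
            Patch(
                path=path,
                patient=match["patient"],
                x=int(match["x"]),
                y=int(match["y"]),
                label=int(match["label"]),
            )
        )
    return patches
```

For example, index_patches("raw") run after the unzip step above returns one Patch per image. Grouping the result by patient makes it possible to split training and validation data at the patient level, so that patches from the same slide never appear on both sides of the split.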

Training

To train a model, run train.py. The full list of supported arguments is shown below.

$ python train.py -h
usage: train.py [-h] [--name NAME] [--device DEVICE] [--log_path LOG_PATH] [--data_path DATA_PATH] [--save_path SAVE_PATH]
                [--model MODEL] [--freeze FREEZE] [--epochs EPOCHS] [--lr LR] [--classifier_lr CLASSIFIER_LR] [--split SPLIT]
                [--threshold THRESHOLD] [--batch_size BATCH_SIZE] [--num_workers NUM_WORKERS]

optional arguments:
  -h, --help            show this help message and exit
  --name NAME
  --device DEVICE
  --log_path LOG_PATH
  --data_path DATA_PATH
  --save_path SAVE_PATH
  --model MODEL
  --freeze FREEZE
  --epochs EPOCHS
  --lr LR
  --classifier_lr CLASSIFIER_LR
  --split SPLIT
  --threshold THRESHOLD
  --batch_size BATCH_SIZE
  --num_workers NUM_WORKERS
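The help text does not say what --threshold controls. In a binary classifier such as this one (IDC-positive vs. IDC-negative patches), a threshold flag conventionally sets the probability cutoff at or above which a patch is labeled positive. The sketch below is a hypothetical illustration of that convention, not train.py's code; it assumes the model emits one logit per patch, and sigmoid and apply_threshold are made-up names.

```python
import math
from typing import List


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def apply_threshold(logits: List[float], threshold: float = 0.5) -> List[int]:
    """Label a patch 1 (IDC-positive) when P(positive) >= threshold, else 0."""
    if not 0.0 <= threshold <= 1.0:
        raise ValueError("threshold must be a probability in [0, 1]")
    return [int(sigmoid(z) >= threshold) for z in logits]


# apply_threshold([-2.0, 0.0, 1.5])                 -> [0, 1, 1]
# apply_threshold([-2.0, 0.0, 1.5], threshold=0.9)  -> [0, 0, 0]
```

Lowering the threshold flags more patches as positive, trading precision for recall; in a screening setting where missed tumor regions are costly, that trade can be worth making.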

Default configurations are specified in config.py.

Running this command creates a subfolder, named after the name field specified in the configuration file, under both checkpoints and logs. The checkpoints subfolder holds the model weights, and the logs subfolder holds TensorBoard logs for inspecting training.
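The contents of config.py are not reproduced here. As a hypothetical sketch of the pattern described above, with defaults kept in a config module, command-line flags overriding them, and the run name determining the checkpoint and log folders, the snippet below uses argparse from the standard library. The DEFAULTS values and the helper make_run_dirs are placeholder assumptions, not the repository's actual configuration.

```python
import argparse
from pathlib import Path
from typing import List, Optional, Tuple

# Placeholder defaults standing in for config.py (hypothetical values).
DEFAULTS = {
    "name": "vit_freeze",
    "model": "vit",
    "epochs": 5,
    "lr": 1e-5,
    "classifier_lr": 1e-3,
    "batch_size": 32,
}


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Build a parser whose defaults come from DEFAULTS; CLI flags override them."""
    parser = argparse.ArgumentParser(description="Hypothetical train.py parser")
    for key, value in DEFAULTS.items():
        # Reuse each default's type so "--epochs 3" parses as int, "--lr 1e-4" as float.
        parser.add_argument(f"--{key}", type=type(value), default=value)
    return parser.parse_args(argv)


def make_run_dirs(name: str, root: str = ".") -> Tuple[Path, Path]:
    """Create checkpoints/<name> (weights) and logs/<name> (TensorBoard logs)."""
    ckpt_dir = Path(root) / "checkpoints" / name
    log_dir = Path(root) / "logs" / name
    for directory in (ckpt_dir, log_dir):
        directory.mkdir(parents=True, exist_ok=True)
    return ckpt_dir, log_dir
```

For example, parse_args(["--name", "vit_run", "--epochs", "3"]) keeps lr at its default while overriding name and epochs. One caveat with this pattern: if a default is a boolean (as --freeze might be), type=bool treats any non-empty string, including "False", as True, so such flags need a dedicated string-to-bool converter.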

Evaluation

To evaluate a model checkpoint, run evaluate.py. The full list of supported arguments is shown below.

$ python evaluate.py -h
usage: evaluate.py [-h] [--device DEVICE] [--checkpoint CHECKPOINT]

optional arguments:
  -h, --help            show this help message and exit
  --device DEVICE
  --checkpoint CHECKPOINT

For example, assuming a vit_freeze checkpoint and a CUDA-enabled local machine, run

$ CUDA_VISIBLE_DEVICES=1 python evaluate.py --device cuda --checkpoint checkpoints/vit_freeze

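If CUDA is not available, set --device cpu instead. Note that CUDA_VISIBLE_DEVICES=1 exposes only the GPU with index 1; change the index, or omit the variable, to match your machine.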
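This section does not list which metrics evaluate.py reports. For a binary IDC-positive/negative task, the usual summary is a confusion matrix and the metrics derived from it. The plain-Python sketch below computes these standard quantities from hard labels; it illustrates the definitions, not evaluate.py's output format, and binary_metrics is a hypothetical name.

```python
from typing import Dict, Sequence


def binary_metrics(y_true: Sequence[int], y_pred: Sequence[int]) -> Dict[str, float]:
    """Confusion-matrix counts and derived metrics; 1 (IDC-positive) is the positive class."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)

    def safe_div(num: float, den: float) -> float:
        # Treat 0/0 as 0.0 so a missing class does not raise.
        return num / den if den else 0.0

    precision = safe_div(tp, tp + fp)
    recall = safe_div(tp, tp + fn)  # also called sensitivity
    return {
        "tp": tp,
        "tn": tn,
        "fp": fp,
        "fn": fn,
        "accuracy": safe_div(tp + tn, len(pairs)),
        "precision": precision,
        "recall": recall,
        "specificity": safe_div(tn, tn + fp),
        "f1": safe_div(2 * precision * recall, precision + recall),
    }
```

For instance, binary_metrics([1, 1, 0, 0, 1], [1, 0, 0, 1, 1]) gives 2 true positives, 1 true negative, 1 false positive, and 1 false negative, so accuracy is 0.6 and precision, recall, and F1 are all 2/3. Because the dataset contains more IDC-negative than IDC-positive patches, accuracy alone can look high for a model that rarely predicts positive, which is why recall and F1 are worth reporting alongside it.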


License

Released under the MIT License.