Repository for the paper "Language models are good pathologists: using attention-based sequence reduction and text-pretrained transformers for efficient WSI classification"
Required packages:
- The models' implementation only requires `pytorch`. We use language models from the `transformers` library.
- The packages required for training are `pytorch-lightning`, `hydra-zen`, and `torchmetrics`.

`conda env create -f env.yaml` will create a conda environment `lmagp-env` with the required dependencies.
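As a quick sanity check (not part of the repository's own instructions), you can verify from the activated environment that the main dependencies are importable:

```python
# Optional sanity check: confirm the main dependencies import correctly
# and print their versions.
import torch
import transformers
import pytorch_lightning
import hydra_zen
import torchmetrics

for pkg in (torch, transformers, pytorch_lightning, hydra_zen, torchmetrics):
    print(pkg.__name__, getattr(pkg, "__version__", "unknown"))
```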
The methods proposed in the paper are available in `src/models/adapted_transformers.py`. A frozen pre-trained RoBERTa-base with a SeqShort layer can be instantiated in the following way:
```python
import torch
from transformers import AutoModelForSequenceClassification
from src.models.adapted_transformers import AdaptedModel, MHASequenceShortenerWithLN, freeze_model

lm_classifier = freeze_model(AutoModelForSequenceClassification.from_pretrained('roberta-base', num_labels=2))  # freezes the encoder parameters except for the layer norm layers
seq_shortener = MHASequenceShortenerWithLN(target_len=256, embed_dim=768, kdim=1280, vdim=1280, num_heads=4, batch_first=True)  # kdim, vdim: hidden dim of EfficientNetV2-L
adapted_lm = AdaptedModel(model=lm_classifier, seq_shortener=seq_shortener, embed_dim=768)

x = torch.rand([1, 5000, 1280])  # [batch_size, num_tiles, feature_extractor_hidden_dim]
y = adapted_lm(x)  # y['attentions'][0] will contain the attention matrix of the SeqShort layer
```
The models should be easy to plug into an existing Multiple Instance Learning (MIL) pipeline, as illustrated in the sketch below.
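Continuing from the snippet above, a minimal training step could look like the following sketch. It assumes the adapted model returns a Hugging Face-style output whose `logits` entry holds the slide-level predictions; the optimizer, learning rate, and the dummy bag/label are illustrative placeholders, not part of the repository.

```python
import torch
import torch.nn.functional as F

# Sketch of a MIL-style training step. Accessing out['logits'] is an assumption
# based on the Hugging Face classification output, not a documented guarantee
# of AdaptedModel.
optimizer = torch.optim.AdamW(
    [p for p in adapted_lm.parameters() if p.requires_grad], lr=1e-4
)

def training_step(bag_features, label):
    # bag_features: [1, num_tiles, feature_extractor_hidden_dim]
    # label: [1], slide-level class index
    out = adapted_lm(bag_features)
    loss = F.cross_entropy(out['logits'], label)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

loss = training_step(torch.rand([1, 5000, 1280]), torch.tensor([1]))
```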
We require the WSIs to be preprocessed beforehand. Our implementation relies on each sample being a dict (saved as a `.pickle` file) with a `features` key, whose value is a `torch.Tensor` of shape `[num_tiles, feature_extractor_hidden_dim]`. For example, for a WSI comprising 100 tiles whose features were obtained with EfficientNetV2-L, the `features` field should be a tensor of shape `[100, 1280]`.
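For instance, a preprocessed sample could be written to disk as follows; the file name is hypothetical, and any extra keys (e.g. a label) are up to your own pipeline:

```python
import pickle
import torch

# Illustrative example of the expected on-disk format: a dict with a
# 'features' tensor of shape [num_tiles, feature_extractor_hidden_dim].
num_tiles = 100
sample = {'features': torch.rand([num_tiles, 1280])}  # e.g. EfficientNetV2-L features

with open('some_slide.pickle', 'wb') as f:  # hypothetical file name
    pickle.dump(sample, f)
```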
The `csvs` directory contains the TCGA-BRCA splits that were used in our study.
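If you need a starting point for loading these splits, the sketch below reads a split csv and returns the pickled feature tensors. It additionally uses `pandas`, and the column names `slide_path` and `label` are hypothetical, so adapt them to the actual layout of the csv files.

```python
import pickle
import pandas as pd
import torch
from torch.utils.data import Dataset

class WSIFeatureDataset(Dataset):
    """Sketch of a dataset reading pickled tile-feature dicts listed in a split csv.

    The column names 'slide_path' and 'label' are hypothetical; adapt them to
    the columns actually present in the split files.
    """

    def __init__(self, csv_path):
        self.df = pd.read_csv(csv_path)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        with open(row['slide_path'], 'rb') as f:
            sample = pickle.load(f)
        features = sample['features']  # [num_tiles, feature_extractor_hidden_dim]
        label = torch.tensor(row['label'], dtype=torch.long)
        return features, label
```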
We use Hydra and PyTorch Lightning for training, and every hyperparameter is configurable from `.yaml` config files. We provide an example config; you only need to specify the correct path to the root of the 10-fold csv files and the fold number on which to train, validate, and test.
```bash
# conda env create -f env.yaml
conda activate lmagp-env
python3 train_classifier.py -cp configs -cn seq-short-roberta-base.yaml ++csvs_root=/path/to/the/csvs ++fold=0
```