BEiT-based Model Pre-training on Histopathological Images

Official repo for BEPH, which is based on BEiTv2:

Note that the BEiT implementation we use comes from mmselfsup (https://github.com/open-mmlab/mmselfsup).


Key Features

This is the repo for the paper "A foundation model for generalizable cancer diagnosis and survival prediction from histopathological images", led by Zhaochang Yang and Ting Wei:

  • BEPH is pre-trained on 11 million histopathological images from TCGA with self-supervised learning
  • BEPH has been validated in multiple cancer detection and survival prediction tasks
  • BEPH can be efficiently adapted to customised tasks

Install environment

Install mmselfsup

conda create -n BEPH python=3.9 -y
conda activate BEPH
conda install pytorch torchvision -c pytorch
git clone https://github.com/Zhcyoung/BEPH_new.git
pip install -U openmim
mim install mmengine
mim install 'mmcv>=2.0.0'
cd mmclassification && mim install -e .
cd .. && cd mmselfsup && mim install -e .

Extract the backbone weights for use in downstream tasks, or download the weights directly:

import torch

ck = torch.load("./BEPH_weight.pth", map_location=torch.device('cpu'))
outPath = "./BEPH_backbone.pth"
output_dict = dict(state_dict=dict(), author='Yzc')
has_backbone = False
for key, value in ck['state_dict'].items():
    if key.startswith('backbone'):
        output_dict['state_dict'][key] = value
        has_backbone = True

if not has_backbone:
    raise Exception('Cannot find a backbone module in the checkpoint.')
torch.save(output_dict, outPath)

Downloading + Preprocessing + Organizing TCGA Data

We downloaded diagnostic whole-slide images (WSIs) for 32 cancer types using the GDC Data Transfer Tool. From each WSI we locally sampled image regions of size 1024×224×224 (that is, approximately 1024 images of 224×224), requiring each sampled region to have a tissue proportion greater than 75%. The sampled regions were then cropped into 224×224 tiles at 40X magnification, again maintaining a tissue proportion of 75%.
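The 75% tissue criterion can be illustrated with a minimal sketch. This README does not state how tissue pixels were identified, so the near-white brightness rule (`white_threshold=220`) and the helper names `tissue_fraction` and `keep_tile` are assumptions made purely for illustration. A tile is represented as rows of RGB tuples so that the example has no dependencies.

```python
# Illustrative only: the tissue detector used for BEPH is not specified here.
# Assumption: a pixel is background when all three RGB channels are near white.

def tissue_fraction(tile, white_threshold=220):
    """Return the fraction of pixels that are not near-white background."""
    total = 0
    tissue = 0
    for row in tile:
        for r, g, b in row:
            total += 1
            # At least one channel below the threshold means the pixel is stained.
            if min(r, g, b) < white_threshold:
                tissue += 1
    return tissue / total if total else 0.0


def keep_tile(tile, min_fraction=0.75):
    """Keep a tile only if its tissue fraction exceeds min_fraction."""
    return tissue_fraction(tile) > min_fraction


# Toy 10x10 tile: 8 rows of stained tissue, 2 rows of white background.
toy_tile = [[(180, 90, 160)] * 10 for _ in range(8)]
toy_tile += [[(255, 255, 255)] * 10 for _ in range(2)]

print(tissue_fraction(toy_tile))  # 0.8
print(keep_tile(toy_tile))        # True
```

In a real pipeline the same check would be applied to pixel arrays read from the slide; the threshold would need tuning for the staining and scanner in use.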


For pre-training, each cancer type is organized as its own folder in TCGA_ROOT_DIR, which additionally contains the following subfolders:

Then generate a pre-train.txt file listing the image filenames:
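One way to build such a list is sketched here. It assumes that tiles are stored as image files under one folder per cancer type inside TCGA_ROOT_DIR and that one relative path per line is the desired format; the function name `write_pretrain_list` and the set of file extensions are hypothetical, not part of this repo.

```python
from pathlib import Path

# Assumption: tiles are saved with one of these extensions.
IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg"}


def write_pretrain_list(root_dir, out_file="pre-train.txt"):
    """Write one tile path (relative to root_dir) per line; return the count."""
    root = Path(root_dir)
    paths = sorted(
        p.relative_to(root).as_posix()
        for p in root.rglob("*")
        if p.is_file() and p.suffix.lower() in IMAGE_EXTENSIONS
    )
    Path(out_file).write_text("".join(f"{p}\n" for p in paths))
    return len(paths)
```

Calling `write_pretrain_list("TCGA_ROOT_DIR")` would write lines such as `BRCA/tile_0001.png` (a made-up name). Whether the paths should be relative or absolute depends on how the data root is set in the pre-training config.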

Then modify the pre-training config file, beitv2_vit.py:

train_dataloader = dict(
        # ... (excerpt; other settings unchanged)
        ann_file='pre-train.txt',  # Change to your pre-training file
)

Pre-training Command:

bash tools/slurm_train_4gpu.sh a100 BEPH  ./TrainConfigs/beitv2_vit.py

Fine-tuning with BEPH weights

To fine-tune BEPH on your own data, follow these steps:

  1. Download the BEPH pre-trained weights (Google Drive or Baidu).

Patch-level tasks

Start fine-tuning (using BreakHis as an example). A fine-tuned checkpoint will be saved during training.

Organise your data into this directory structure:

├── data folder

Train.txt / val.txt (one image path and one integer class label per line):

./images/SOB_M_DC-14-16716-100-022.png 1
./images/SOB_B_TA-14-16184CD-100-003.png 0
./images/SOB_B_TA-14-16184CD-100-031.png 0
./images/SOB_M_DC-14-14946-100-025.png 1
./images/SOB_M_LC-14-12204-100-037.png 1
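The annotation lines above pair an image path with an integer label; in these samples, file names starting with `SOB_B` carry label 0 and those starting with `SOB_M` carry label 1. The following sketch shows one way such lines could be generated and split. The helper names, the 80/20 ratio, and the random per-image split are assumptions for illustration; in particular, it does not group images by patient, which may be needed to avoid leakage between Train.txt and val.txt.

```python
import random


def label_from_name(filename):
    """Map a BreakHis-style file name to a class index (SOB_B -> 0, SOB_M -> 1)."""
    if filename.startswith("SOB_B"):
        return 0
    if filename.startswith("SOB_M"):
        return 1
    raise ValueError(f"Unrecognised file name: {filename}")


def build_annotation_lines(filenames, image_dir="./images"):
    """Return lines of the form '<image_dir>/<name> <label>'."""
    return [f"{image_dir}/{name} {label_from_name(name)}" for name in filenames]


def split_train_val(lines, val_ratio=0.2, seed=0):
    """Shuffle reproducibly and split into (train, val) lists."""
    shuffled = list(lines)
    random.Random(seed).shuffle(shuffled)
    n_val = int(len(shuffled) * val_ratio)
    return shuffled[n_val:], shuffled[:n_val]


names = [
    "SOB_M_DC-14-16716-100-022.png",
    "SOB_B_TA-14-16184CD-100-003.png",
    "SOB_B_TA-14-16184CD-100-031.png",
    "SOB_M_DC-14-14946-100-025.png",
    "SOB_M_LC-14-12204-100-037.png",
]
train_lines, val_lines = split_train_val(build_annotation_lines(names))
print(len(train_lines), len(val_lines))  # 4 1
```

The resulting lists can then be written out, one line each, as Train.txt and val.txt.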

Train Command:

bash ./tools/benchmarks/classification/mim_dist_train.sh  ./FineTuning/beit.py  ./BEPH_backbone.pth

For evaluation (download data and model checkpoints here; change the path below):

bash ./tools/benchmarks/classification/mim_dist_test.sh   ./FineTuning/beit.py ./work_dir/epoch_x.pth

WSI-level tasks:

After pre-training and pre-extracting instance-level features with ViT-base, we use the publicly available CLAM scaffold code, together with several current weakly supervised baselines, to run 10-fold Monte Carlo cross-validation experiments.
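For readers unfamiliar with the term, Monte Carlo cross-validation draws each fold as an independent random partition of the cases, rather than rotating through fixed, non-overlapping folds. The sketch below illustrates the idea only. It is not the split script shipped with CLAM, it does not stratify by label, and the function name `monte_carlo_splits` and the 10%/10% validation and test fractions are assumptions.

```python
import random


def monte_carlo_splits(case_ids, k=10, val_frac=0.1, test_frac=0.1, seed=0):
    """Yield k independent random (train, val, test) partitions of case_ids."""
    rng = random.Random(seed)
    ids = sorted(case_ids)
    n_val = int(len(ids) * val_frac)
    n_test = int(len(ids) * test_frac)
    for _ in range(k):
        shuffled = ids[:]
        rng.shuffle(shuffled)  # a fresh random ordering for every fold
        test = shuffled[:n_test]
        val = shuffled[n_test:n_test + n_val]
        train = shuffled[n_test + n_val:]
        yield train, val, test


cases = [f"case_{i:03d}" for i in range(50)]
splits = list(monte_carlo_splits(cases, k=10))
print(len(splits), [len(part) for part in splits[0]])  # 10 [40, 5, 5]
```

Because every fold is sampled independently, a given case may appear in the test set of several folds and in none of others; splitting by case rather than by slide keeps slides from the same patient on one side of the partition.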

Directory tree:

	├── slide_1.svs
	├── slide_2.svs
	└── ...
	├── masks
		├── slide_1.jpg
		└── ...
	├── patches
		├── slide_1.h5
		└── ...
	├── stitches
		├── slide_1.jpg
		└── ...
	├── process_list_autogen.csv
	└── Step_2.csv
	├── h5_files
		├── slide_1.h5
		└── ...
	└── pt_files
		├── slide_1.pt
		└── ...
	├── splits_0.csv
	└── ...
	├── tcga_brca_subtype
		├── s_0_checkpoint.pt
		├── splits_0.csv
		├── ...
		└──	summary.csv
	└── ...

Feature extraction:

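python create_patches_fp.py \
--source ./DATA_DIRECTORY/ \
--save_dir ./PATCH_DIRECTORY/patch_splits \
--patch_size 224 \
--seg \
--patch

Then keep only the slides that were patched successfully and save them as Step_2.csv:

import os
import pandas as pd

df = pd.read_csv('./PATCH_DIRECTORY/process_list_autogen.csv')  # This csv is generated in the first step
ids1 = [i[:-4] for i in df.slide_id]  # strip the 4-character slide extension (e.g. ".svs")
ids2 = [i[:-3] for i in os.listdir('./PATCH_DIRECTORY/patch_splits/patches/')]  # strip ".h5"
df['slide_id'] = ids1
ids = df['slide_id'].isin(ids2)  # True for slides that have a patch file
# Assumed final step: write the filtered list to the csv passed as --csv_path during feature extraction
df.loc[ids].to_csv('./PATCH_DIRECTORY/patch_splits/Step_2.csv', index=False)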

Extract features (the example below uses the BEPH extractor):

# ImageNet ResNet-50 feature: extract_features_fp.py
# Histopathological image DINO feature: extract_features_dino.py
# BEPH feature: extract_features_BEPH.py

python extract_features_BEPH.py \
--data_h5_dir ./FEATURE_DIRECTORY/patch_splits/ \
--data_slide_dir ./DATA_DIRECTORY/ \
--csv_path ./PATCH_DIRECTORY/patch_splits/Step_2.csv \
--batch_size 2000 \
--slide_ext .svs

Filter out the slides whose features could not be extracted:

import os
import pandas as pd

# wsi_path must be defined beforehand; it is not set in this snippet.
df = pd.read_csv(wsi_path[:-3]+'dataset_csv/label.csv')
df = df[['case_id','slide_id','slide_name','oncotree_code']]
ids1 = [i for i in df.slide_name]
ids2 = [i[:-3] for i in os.listdir(wsi_path[:-3]+'test_time_FEATURES_DIRECTORY/pt_files')]  # strip ".pt"
ids = df['slide_name'].isin(ids2)  # True for slides that have a feature file
df = df.loc[ids]
df.columns = ['case_id','slide_id','slide_name','label']
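The filter above strips file extensions by slicing a fixed number of characters (`i[:-3]` for `.pt`), which only works when every extension has the expected length. A slightly more defensive variant is sketched here with `pathlib`. The helper names are hypothetical, and rows are plain dictionaries rather than a DataFrame so that the example has no dependencies.

```python
from pathlib import Path


def stems(filenames):
    """Return the set of file names with only their final extension removed."""
    return {Path(name).stem for name in filenames}


def filter_rows_with_features(rows, feature_files, key="slide_name"):
    """Keep only the rows whose slide name has a matching feature file."""
    available = stems(feature_files)
    return [row for row in rows if row[key] in available]


rows = [
    {"case_id": "c1", "slide_name": "slide_1"},
    {"case_id": "c2", "slide_name": "slide_2"},
]
# In practice feature_files would come from os.listdir(".../pt_files").
kept = filter_rows_with_features(rows, ["slide_1.pt"])
print([row["slide_name"] for row in kept])  # ['slide_1']
```

Only the last suffix is removed, so slide names that themselves contain dots are left intact.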

Train Command (taking the clam_sb model for breast cancer subtype classification as an example):

python CLAM_SB_BEPH.py \
--data_root_dir DATA_DIRECTORY/ \
--model_type clam_sb \
--task tcga_brca_subtype \
--splits SPLITS/ \
--lr 2e-4 \
--seed 123 \
--feature_path FEATURES_DIRECTORY/ \
--csv_path DATASET_CSV/datasets.csv \
--k 10 \
--k_start 0 \
--results_dir RESULTS/tcga_brca_subtype

For evaluation:

python eval.py --data_root_dir  DATA_DIRECTORY/ \
--model_type clam_sb \
--task tcga_brca_subtype \
--splits  SPLITS/ \
--feature_path FEATURES_DIRECTORY/ \
--weights_path ../weights/tcga_brca_subtype/ \
--csv_path DATASET_CSV/label.csv \
--k 10 \
--k_start 0 \
--results_dir RESULTS/tcga_brca_subtype

Analogously, we extend the CLAM scaffold code to survival prediction:

Train Command:

python ./survival/CLAM_survival_BEPH.py --data_root_dir DATA_DIRECTORY/ \
--model_type clam_sb \
--task tcga_crc_subtype \
--max_epoch 20 \
--k 5 \
--k_start 0 \
--lr  2e-4 \
--seed 123 \
--results_dir ./RESULTS/tcga_crc_survival

For evaluation:

python ./survival/eval_survival.py --data_root_dir DATA_DIRECTORY/ \
--model_type clam_sb \
--task tcga_crc_subtype \
--results_dir ./RESULTS/tcga_crc_survival/test