Official repository for RETFound, a foundation model for generalizable disease detection from retinal images, based on MAE (masked autoencoders).
Please contact ykzhoua@gmail.com or yukun.zhou.19@ucl.ac.uk if you have questions.
A Keras version implemented by Yuka Kihara can be found here.
- RETFound is pre-trained on 1.6 million retinal images with self-supervised learning
- RETFound has been validated in multiple disease detection tasks
- RETFound can be efficiently adapted to customised tasks
- 🐉2024/01: The feature vector notebook is now online!
- 🐉2024/01: Data split and model checkpoints for public datasets are now online!
- 🎄2023/12: Colab notebook is now online - free GPU & simple operation!
- 2023/10: The input_size hyperparameter can now be changed to support any image size
- 2023/09: A visualisation demo has been added
- Create environment with conda:
```bash
conda create -n retfound python=3.7.5 -y
conda activate retfound
```
- Install dependencies
```bash
git clone https://github.com/rmaphoh/RETFound_MAE/
cd RETFound_MAE
pip install -r requirement.txt
```
To fine-tune RETFound on your own data, follow these steps:
- Download the RETFound pre-trained weights
| | ViT-Large |
|---|---|
| Colour fundus image | download |
| OCT | download |
- Organise your data into the following directory structure (the public datasets used in this study can be downloaded here):
```
├── data folder
    ├── train
        ├── class_a
        ├── class_b
        ├── class_c
    ├── val
        ├── class_a
        ├── class_b
        ├── class_c
    ├── test
        ├── class_a
        ├── class_b
        ├── class_c
```
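Before launching a run, it can help to confirm that the folder layout is consistent. The sketch below is a hypothetical helper (`check_layout` and the list of accepted image extensions are assumptions, not part of this repository) that checks that `train`, `val` and `test` all contain the same class folders and counts the images in each. It assumes one folder per class, ImageFolder-style, so the number of class folders should equal `--nb_classes`.

```python
from pathlib import Path

# Hypothetical helper: checks the train/val/test layout shown above.
# The accepted extensions are an assumption; adjust them to your data.
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"}
SPLITS = ("train", "val", "test")


def check_layout(root):
    """Return {split: {class_name: n_images}}; raise if a split is missing
    or the splits do not share the same class folders."""
    root = Path(root)
    counts = {}
    for split in SPLITS:
        split_dir = root / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"missing split folder: {split_dir}")
        # Count image files inside each class folder of this split.
        counts[split] = {
            class_dir.name: sum(
                1
                for f in class_dir.iterdir()
                if f.is_file() and f.suffix.lower() in IMAGE_EXTS
            )
            for class_dir in sorted(split_dir.iterdir())
            if class_dir.is_dir()
        }
    # Every split must use exactly the same class folder names as train.
    reference = set(counts["train"])
    for split in SPLITS:
        if set(counts[split]) != reference:
            raise ValueError(
                f"{split} classes {sorted(counts[split])} != train classes {sorted(reference)}"
            )
    return counts
```

For example, `check_layout("./IDRiD_data/")` should report five class folders per split if the data is consistent with `--nb_classes 5` in the IDRiD command.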
- Start fine-tuning (using IDRiD as an example). A fine-tuned checkpoint is saved during training, and evaluation runs once training has finished.
```bash
python -m torch.distributed.launch --nproc_per_node=1 --master_port=48798 main_finetune.py \
    --batch_size 16 \
    --world_size 1 \
    --model vit_large_patch16 \
    --epochs 50 \
    --blr 5e-3 --layer_decay 0.65 \
    --weight_decay 0.05 --drop_path 0.2 \
    --nb_classes 5 \
    --data_path ./IDRiD_data/ \
    --task ./finetune_IDRiD/ \
    --finetune ./RETFound_cfp_weights.pth \
    --input_size 224
```
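`--blr` is a base learning rate rather than the learning rate actually used. Assuming `main_finetune.py` keeps the convention of the MAE code it is derived from (absolute lr = blr x effective batch size / 256, with effective batch size = batch_size x accum_iter x number of processes, and `accum_iter` defaulting to 1), and that `--layer_decay` applies MAE-style layer-wise scaling (scale = layer_decay ** (num_layers - layer_id), num_layers = depth + 1), the sketch below reproduces that arithmetic for the IDRiD command. Both conventions are assumptions to verify against the script.

```python
def effective_lr(blr, batch_size, accum_iter=1, world_size=1):
    """Assumed MAE-style linear scaling: lr = blr * eff_batch_size / 256."""
    eff_batch_size = batch_size * accum_iter * world_size
    return blr * eff_batch_size / 256


def layer_lr_scales(depth, layer_decay):
    """Assumed MAE-style layer-wise decay multipliers, indexed by layer_id:
    0          -> patch_embed, cls_token, pos_embed
    k + 1      -> transformer block k
    depth + 1  -> everything else (fc_norm, head)
    """
    num_layers = depth + 1
    return [layer_decay ** (num_layers - i) for i in range(num_layers + 1)]


lr = effective_lr(blr=5e-3, batch_size=16)
scales = layer_lr_scales(depth=24, layer_decay=0.65)  # ViT-Large has 24 blocks

print(f"head lr        = {lr * scales[-1]:.3e}")
print(f"last block lr  = {lr * scales[24]:.3e}")
print(f"patch_embed lr = {lr * scales[0]:.3e}")
```

Under these assumptions the IDRiD settings give a peak learning rate of 5e-3 x 16 / 256 = 3.125e-4 for the head (before any learning-rate schedule is applied), and each step down towards the patch embedding multiplies it by a further 0.65, so early layers change far less than the head.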
- For evaluation only (download the data and model checkpoints here, and change the paths below accordingly):
```bash
python -m torch.distributed.launch --nproc_per_node=1 --master_port=48798 main_finetune.py \
    --eval --batch_size 16 \
    --world_size 1 \
    --model vit_large_patch16 \
    --epochs 50 \
    --blr 5e-3 --layer_decay 0.65 \
    --weight_decay 0.05 --drop_path 0.2 \
    --nb_classes 5 \
    --data_path ./IDRiD_data/ \
    --task ./internal_IDRiD/ \
    --resume ./finetune_IDRiD/checkpoint-best.pth \
    --input_size 224

# This command evaluates a fine-tuned Vision Transformer (ViT) model on a dataset using a single GPU.
# - `--eval` runs evaluation only, without training.
# - `--batch_size 16` sets the evaluation batch size.
# - `--world_size 1` specifies single-process (single-GPU) execution.
# - `--model vit_large_patch16` selects the ViT-Large model with 16x16 patches.
# - `--epochs 50`, `--blr 5e-3`, `--layer_decay 0.65`, `--weight_decay 0.05` and `--drop_path 0.2` are training hyperparameters kept from the fine-tuning command; they have no effect on evaluation.
# - `--nb_classes 5` configures the model for 5-class classification.
# - `--data_path ./IDRiD_data/` specifies the dataset directory.
# - `--task ./internal_IDRiD/` sets the task directory where logs and outputs for this run are written.
# - `--resume ./finetune_IDRiD/checkpoint-best.pth` loads the fine-tuned checkpoint saved by the fine-tuning run.
# - `--input_size 224` resizes input images to 224x224 pixels, matching the input size the model is configured for.
```
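With 16x16 patches, `--input_size` fixes the patch grid and therefore the length of the position embedding: 224 / 16 = 14, so a 14x14 grid of 196 patch tokens plus one class token gives 197 positions. When a checkpoint's grid differs from the one implied by `--input_size`, its position embedding must be resized; in the MAE code this repository builds on, that is the job of `interpolate_pos_embed`, which keeps the class-token entry and interpolates the patch grid. The helpers below are a hypothetical stdlib sketch of that bookkeeping (the names and the divisibility check are assumptions), not the repository's code.

```python
import math


def pos_embed_layout(input_size, patch_size=16, num_extra_tokens=1):
    """Return (grid side, total positions) for a square ViT input."""
    if input_size % patch_size != 0:
        raise ValueError(f"input_size {input_size} is not a multiple of patch_size {patch_size}")
    grid = input_size // patch_size
    return grid, grid * grid + num_extra_tokens


def needs_interpolation(checkpoint_positions, input_size, patch_size=16, num_extra_tokens=1):
    """True if the checkpoint's patch grid differs from the one implied by input_size."""
    ckpt_grid = int(round(math.sqrt(checkpoint_positions - num_extra_tokens)))
    new_grid, _ = pos_embed_layout(input_size, patch_size, num_extra_tokens)
    return ckpt_grid != new_grid


for size in (224, 384, 512):
    grid, positions = pos_embed_layout(size)
    print(f"input_size={size}: {grid}x{grid} patches, {positions} positions")
```

For instance, a checkpoint with 197 positions can be used directly at `--input_size 224`, but would need its position embedding interpolated for 512x512 inputs (a 32x32 grid).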
- To call the model in your own code, build it and load the RETFound weights:
```python
import torch

import models_vit
from util.pos_embed import interpolate_pos_embed
from timm.models.layers import trunc_normal_

# build the model
model = models_vit.__dict__['vit_large_patch16'](
    num_classes=2,
    drop_path_rate=0.2,
    global_pool=True,
)

# load RETFound weights
checkpoint = torch.load('RETFound_cfp_weights.pth', map_location='cpu')
checkpoint_model = checkpoint['model']
state_dict = model.state_dict()

# drop classification-head weights whose shape does not match this model
for k in ['head.weight', 'head.bias']:
    if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
        print(f"Removing key {k} from pretrained checkpoint")
        del checkpoint_model[k]

# interpolate the position embedding to the model's input size
interpolate_pos_embed(model, checkpoint_model)

# load pre-trained weights; only the new head and fc_norm should be missing
msg = model.load_state_dict(checkpoint_model, strict=False)
assert set(msg.missing_keys) == {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'}

# manually initialize the classification head
trunc_normal_(model.head.weight, std=2e-5)

print("Model = %s" % str(model))
```
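The assertion on `msg.missing_keys` holds because `load_state_dict(strict=False)` reports model parameters absent from the checkpoint as missing keys and checkpoint entries the model lacks as unexpected keys, while a shape mismatch still raises an error, which is why mismatched head weights are deleted first. The stdlib sketch below mimics that bookkeeping on plain name-to-shape dictionaries. The function name, the listed keys and the 1000-class head in the checkpoint are hypothetical and illustrative only; it also assumes, as in the MAE ViT, that `global_pool=True` replaces the final `norm` layer with `fc_norm`.

```python
def load_report(model_shapes, ckpt_shapes, drop_on_mismatch=("head.weight", "head.bias")):
    """Mimic the key bookkeeping of load_state_dict(strict=False) on name->shape dicts."""
    ckpt = dict(ckpt_shapes)  # copy so the caller's dict is untouched
    # Drop classifier keys whose shape differs from the model's.
    for k in drop_on_mismatch:
        if k in ckpt and k in model_shapes and ckpt[k] != model_shapes[k]:
            del ckpt[k]
    # Any remaining shape mismatch would make PyTorch raise, even with strict=False.
    mismatched = sorted(k for k in ckpt if k in model_shapes and ckpt[k] != model_shapes[k])
    if mismatched:
        raise RuntimeError(f"size mismatch for {mismatched}")
    missing = set(model_shapes) - set(ckpt)
    unexpected = set(ckpt) - set(model_shapes)
    return missing, unexpected


# Toy subset of keys for ViT-Large (embed_dim 1024) with a 2-class head and global pooling.
model_shapes = {
    "cls_token": (1, 1, 1024),
    "pos_embed": (1, 197, 1024),
    "fc_norm.weight": (1024,),
    "fc_norm.bias": (1024,),
    "head.weight": (2, 1024),
    "head.bias": (2,),
}
# Hypothetical checkpoint: final `norm` layer instead of `fc_norm`, 1000-class head.
ckpt_shapes = {
    "cls_token": (1, 1, 1024),
    "pos_embed": (1, 197, 1024),
    "norm.weight": (1024,),
    "norm.bias": (1024,),
    "head.weight": (1000, 1024),
    "head.bias": (1000,),
}
missing, unexpected = load_report(model_shapes, ckpt_shapes)
print("missing:", sorted(missing))
print("unexpected:", sorted(unexpected))
```

A mismatched `pos_embed` would trigger the size-mismatch error in this sketch, which is the situation `interpolate_pos_embed` is meant to resolve before loading.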
If you find this repository useful, please consider citing this paper:
```bibtex
@article{zhou2023foundation,
  title={A foundation model for generalizable disease detection from retinal images},
  author={Zhou, Yukun and Chia, Mark A and Wagner, Siegfried K and Ayhan, Murat S and Williamson, Dominic J and Struyven, Robbert R and Liu, Timing and Xu, Moucheng and Lozano, Mateo G and Woodward-Court, Peter and others},
  journal={Nature},
  volume={622},
  number={7981},
  pages={156--163},
  year={2023},
  publisher={Nature Publishing Group UK London}
}
```