
RETFound - A foundation model for retinal imaging

Official repo for RETFound: a foundation model for generalizable disease detection from retinal images, built on MAE (masked autoencoders).

Please contact ykzhoua@gmail.com or yukun.zhou.19@ucl.ac.uk if you have questions.

Keras version implemented by Yuka Kihara can be found here

📝Key features

  • RETFound is pre-trained on 1.6 million retinal images with self-supervised learning
  • RETFound has been validated in multiple disease detection tasks
  • RETFound can be efficiently adapted to customised tasks
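RETFound's self-supervised pre-training follows MAE: most image patches are hidden, and the model learns to reconstruct them from the visible rest. The sketch below illustrates only the random masking step, in plain Python. The 16-pixel patches and the 0.75 mask ratio are assumptions borrowed from MAE's defaults. `sample_visible_patches` is a hypothetical helper, not code from this repo.

```python
import random
from typing import List, Tuple


def sample_visible_patches(
    image_size: int = 224,
    patch_size: int = 16,
    mask_ratio: float = 0.75,  # assumed MAE default
    seed: int = 0,
) -> Tuple[List[int], List[int]]:
    """Split patch indices into visible (encoder input) and masked (to reconstruct)."""
    num_patches = (image_size // patch_size) ** 2  # 14 * 14 = 196 for 224 / 16
    len_keep = int(num_patches * (1 - mask_ratio))  # number of patches the encoder sees
    rng = random.Random(seed)
    order = list(range(num_patches))
    rng.shuffle(order)  # a fresh random permutation per image
    visible = sorted(order[:len_keep])
    masked = sorted(order[len_keep:])
    return visible, masked


visible, masked = sample_visible_patches()
print(len(visible), len(masked))  # 49 147
```

MAE itself draws random noise per patch and argsorts it on GPU tensors. Shuffling an index list produces the same kind of uniform random split and is easier to read.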


🔧Install environment

  1. Create an environment with conda:
conda create -n retfound python=3.7.5 -y
conda activate retfound
  2. Install dependencies:
git clone https://github.com/rmaphoh/RETFound_MAE/
cd RETFound_MAE
pip install -r requirement.txt

🌱Fine-tuning with RETFound weights

To fine-tune RETFound on your own data, follow these steps:

  1. Download the RETFound pre-trained weights (ViT-Large):
       Colour fundus image: download
       OCT: download
  2. Organise your data into this directory structure (public datasets used in this study can be downloaded here):
├── data folder
    ├──train
        ├──class_a
        ├──class_b
        ├──class_c
    ├──val
        ├──class_a
        ├──class_b
        ├──class_c
    ├──test
        ├──class_a
        ├──class_b
        ├──class_c
  3. Start fine-tuning (using IDRiD as an example). A fine-tuned checkpoint is saved during training, and evaluation is run after training:
python -m torch.distributed.launch --nproc_per_node=1 --master_port=48798 main_finetune.py \
    --batch_size 16 \
    --world_size 1 \
    --model vit_large_patch16 \
    --epochs 50 \
    --blr 5e-3 --layer_decay 0.65 \
    --weight_decay 0.05 --drop_path 0.2 \
    --nb_classes 5 \
    --data_path ./IDRiD_data/ \
    --task ./finetune_IDRiD/ \
    --finetune ./RETFound_cfp_weights.pth \
    --input_size 224

  4. For evaluation only (download the data and model checkpoints here, then change the paths below):
python -m torch.distributed.launch --nproc_per_node=1 --master_port=48798 main_finetune.py \
    --eval --batch_size 16 \
    --world_size 1 \
    --model vit_large_patch16 \
    --epochs 50 \
    --blr 5e-3 --layer_decay 0.65 \
    --weight_decay 0.05 --drop_path 0.2 \
    --nb_classes 5 \
    --data_path ./IDRiD_data/ \
    --task ./internal_IDRiD/ \
    --resume ./finetune_IDRiD/checkpoint-best.pth \
    --input_size 224

# This command evaluates a fine-tuned Vision Transformer (ViT) checkpoint on a single GPU.
# - `--eval` runs evaluation only, without training.
# - `--batch_size 16` sets the evaluation batch size.
# - `--world_size 1` specifies single-process (single-GPU) execution.
# - `--model vit_large_patch16` selects the ViT-Large model with 16x16 patches.
# - `--epochs 50`, `--blr 5e-3`, `--layer_decay 0.65`, `--weight_decay 0.05`, and `--drop_path 0.2` are training hyperparameters kept from the fine-tuning command; they have no effect in evaluation mode.
# - `--nb_classes 5` configures the classification head for 5 classes and must match the fine-tuned checkpoint.
# - `--data_path ./IDRiD_data/` specifies the dataset directory.
# - `--task ./internal_IDRiD/` names the output directory for logs and evaluation results.
# - `--resume ./finetune_IDRiD/checkpoint-best.pth` loads the fine-tuned checkpoint saved by the training run above.
# - `--input_size 224` resizes input images to 224x224 pixels, the resolution used for fine-tuning.
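Before launching a run, it can help to check that the data folder matches the train/val/test layout described above. The sketch below counts images per class folder in each split and checks that all splits contain the same classes. It assumes the data are read with torchvision's `ImageFolder`, as this one-folder-per-class layout suggests. `ImageFolder` assigns label indices to class folders in sorted name order. The helper names and the list of image extensions are hypothetical choices for illustration.

```python
from pathlib import Path
from typing import Dict

SPLITS = ("train", "val", "test")
# Assumed image extensions; torchvision's ImageFolder accepts a similar, larger set.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"}


def summarise_layout(root: str) -> Dict[str, Dict[str, int]]:
    """Count images per class folder in each split and check the splits agree."""
    summary: Dict[str, Dict[str, int]] = {}
    for split in SPLITS:
        split_dir = Path(root) / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"missing split folder: {split_dir}")
        counts: Dict[str, int] = {}
        for class_dir in sorted(split_dir.iterdir()):
            if class_dir.is_dir():
                counts[class_dir.name] = sum(
                    1 for f in class_dir.iterdir()
                    if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
                )
        summary[split] = counts
    # Every split must hold the same class folders, or label indices will not line up.
    class_sets = {split: frozenset(counts) for split, counts in summary.items()}
    if len(set(class_sets.values())) != 1:
        details = {split: sorted(names) for split, names in class_sets.items()}
        raise ValueError(f"class folders differ between splits: {details}")
    return summary


def class_to_idx(summary: Dict[str, Dict[str, int]]) -> Dict[str, int]:
    """Map class folder names to label indices in sorted order, as ImageFolder does."""
    return {name: i for i, name in enumerate(sorted(summary["train"]))}
```

For example, `len(class_to_idx(summarise_layout("./IDRiD_data/")))` should equal the value passed to `--nb_classes` (5 in the IDRiD commands above).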
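The fine-tuning command above passes a base learning rate (`--blr`) rather than an absolute one. Suppose RETFound keeps the conventions of the MAE fine-tuning script it builds on. Then the absolute rate scales linearly with the effective batch size, and `--layer_decay` shrinks it geometrically for layers closer to the input. The sketch below reproduces that arithmetic under that assumption. The helper names and the `accum_iter` (gradient accumulation) argument are hypothetical, not flags documented in this README.

```python
from typing import List


def absolute_lr(blr: float, batch_size: int, accum_iter: int = 1, world_size: int = 1) -> float:
    """Linear scaling rule (assumed, as in MAE): lr = blr * effective batch size / 256."""
    eff_batch_size = batch_size * accum_iter * world_size
    return blr * eff_batch_size / 256


def layer_lr_scales(num_blocks: int, layer_decay: float) -> List[float]:
    """Learning-rate multipliers for layer-wise decay, from input side to head.

    Index 0: patch embedding, class token and position embedding.
    Indices 1..num_blocks: transformer blocks 0..num_blocks-1.
    Last index: final norm and classification head (multiplier 1.0).
    """
    num_layers = num_blocks + 1
    return [layer_decay ** (num_layers - i) for i in range(num_layers + 1)]


# Values from the IDRiD fine-tuning command; ViT-Large has 24 transformer blocks.
lr = absolute_lr(blr=5e-3, batch_size=16, accum_iter=1, world_size=1)
scales = layer_lr_scales(num_blocks=24, layer_decay=0.65)
print("absolute lr: {:.3e}".format(lr))  # absolute lr: 3.125e-04
for layer_id in (0, 12, 25):
    print("layer {:2d}: lr = {:.3e}".format(layer_id, lr * scales[layer_id]))
```

With `--layer_decay 0.65` and 24 blocks, the head keeps the full rate. The embedding layers get 0.65^25 of it, roughly 2.1e-5.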

Load the model and weights (if you want to use the model in your own code):

import torch
import models_vit
from util.pos_embed import interpolate_pos_embed
from timm.models.layers import trunc_normal_

# build ViT-Large/16 with a new head; set num_classes to the number of classes in your task
model = models_vit.__dict__['vit_large_patch16'](
    num_classes=2,
    drop_path_rate=0.2,
    global_pool=True,
)

# load RETFound weights
checkpoint = torch.load('RETFound_cfp_weights.pth', map_location='cpu')
checkpoint_model = checkpoint['model']
state_dict = model.state_dict()
for k in ['head.weight', 'head.bias']:
    if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
        print(f"Removing key {k} from pretrained checkpoint")
        del checkpoint_model[k]

# interpolate position embedding
interpolate_pos_embed(model, checkpoint_model)

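# load the pre-trained weights; strict=False tolerates keys the checkpoint does not contain
msg = model.load_state_dict(checkpoint_model, strict=False)

# with global_pool=True, only the new head and the fc_norm layer should be missing
assert set(msg.missing_keys) == {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'}

# manually initialize the classification head
trunc_normal_(model.head.weight, std=2e-5)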

print("Model = %s" % str(model))
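`interpolate_pos_embed` only changes anything when the fine-tuning resolution differs from the pre-training resolution. The checkpoint stores one position embedding per 16x16 patch plus one for the class token. The patch part is resized from the old square grid to the new one; the MAE utility this function comes from uses bicubic interpolation. The sketch below follows that logic but computes only the grid sizes involved. It assumes a checkpoint pre-trained at 224x224, which means 197 stored tokens. `pos_embed_grid` is a hypothetical helper and does not perform the resize itself.

```python
from typing import Tuple


def pos_embed_grid(
    checkpoint_tokens: int,
    input_size: int,
    patch_size: int = 16,
    num_extra_tokens: int = 1,  # the class token
) -> Tuple[int, int, bool]:
    """Return (old grid side, new grid side, whether interpolation is needed)."""
    # The stored embedding holds grid_side ** 2 patch positions plus the extra tokens.
    orig_size = int((checkpoint_tokens - num_extra_tokens) ** 0.5)
    new_size = input_size // patch_size
    return orig_size, new_size, orig_size != new_size


print(pos_embed_grid(197, 224))  # (14, 14, False)
print(pos_embed_grid(197, 384))  # (14, 24, True)
```

With `--input_size 224`, the grids already match and the embeddings are loaded unchanged. A larger input such as 384 would stretch the 14x14 grid to 24x24.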

📃Citation

If you find this repository useful, please consider citing this paper:

@article{zhou2023foundation,
  title={A foundation model for generalizable disease detection from retinal images},
  author={Zhou, Yukun and Chia, Mark A and Wagner, Siegfried K and Ayhan, Murat S and Williamson, Dominic J and Struyven, Robbert R and Liu, Timing and Xu, Moucheng and Lozano, Mateo G and Woodward-Court, Peter and others},
  journal={Nature},
  volume={622},
  number={7981},
  pages={156--163},
  year={2023},
  publisher={Nature Publishing Group UK London}
}