This is the official PyTorch implementation of the paper:
Huang, Y., Lin, L., Cheng, P., Lyu, J. and Tang, X., 2021, September. Lesion-Based Contrastive Learning for Diabetic Retinopathy Grading from Fundus Images. In International Conference on Medical Image Computing and Computer-Assisted Intervention (pp. 113-123). Springer, Cham. [link] [arxiv]
Two publicly accessible datasets are used in this work.
- IDRiD for training the lesion detection network [link].
- EyePACS for contrastive learning and evaluation [link].
The object detection toolbox MMDetection is used for lesion detection. A trained model and its predicted results can be downloaded here. Note that, because IDRiD provides only a limited number of training samples, the model generalizes relatively poorly and cannot precisely predict lesions in EyePACS fundus images. If you want to train your own detection model, please follow the instructions here.
Recommended environment:
- python 3.8+
- pytorch 1.5.1
- torchvision 0.6.1
- tensorboard 2.2.1
- tqdm
To install the dependencies, run:
$ git clone https://github.com/YijinHuang/Lesion-based-Contrastive-Learning.git
$ cd Lesion-based-Contrastive-Learning
$ pip install -r requirements.txt
- Download the EyePACS dataset. Then use tools/crop.py to remove the black borders of the images and resize them to 512 x 512.
- Rename all images as 'id_eyeSide.jpeg', where 'id' is the image id given by EyePACS and 'eyeSide' is 'left' or 'right'. Then move all images into one folder.
- Download the provided lesion predictions, which is a pickle file containing a dict as follows:
predictions = {
'train': {
'id_eyeSide.jpeg': [(x1, y1, x2, y2), ..., (x1, y1, x2, y2)],
'id_eyeSide.jpeg': [(x1, y1, x2, y2), ..., (x1, y1, x2, y2)],
'id_eyeSide.jpeg': [(x1, y1, x2, y2), ..., (x1, y1, x2, y2)],
...
},
'val': {
'id_eyeSide.jpeg': [(x1, y1, x2, y2), ..., (x1, y1, x2, y2)],
'id_eyeSide.jpeg': [(x1, y1, x2, y2), ..., (x1, y1, x2, y2)],
...
}
}
- Update 'data_path' and 'data_index' in config.py, where 'data_path' is the folder containing the preprocessed images and 'data_index' is the pickle file containing the lesion predictions. You can also update other training configurations and hyper-parameters in config.py for your own dataset.
- Run to train:
$ CUDA_VISIBLE_DEVICES=x python main.py
where 'x' is the id of your GPU.
- You can monitor the training progress at 127.0.0.1:6006 by running:
$ tensorboard --logdir=/path/to/your/log --port=6006
- All trained models are stored in 'save_path' in config.py. The default path is './checkpoints'. Our final trained models on EyePACS can be downloaded here.
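The border-removal step handled by tools/crop.py can be approximated as below. This is a minimal sketch, not the repository's script: it assumes the fundus disc is the only region brighter than a hypothetical `threshold`, keeps the bounding box of that region, pads it to a square with black, and resizes to 512 x 512. The function name and parameters are illustrative.

```python
import numpy as np
from PIL import Image


def crop_black_border(img, threshold=10, size=512):
    """Hypothetical approximation of the preprocessing step (not tools/crop.py).

    Keeps the bounding box of pixels brighter than `threshold`, pads it to a
    square black canvas, and resizes the result to `size` x `size`.
    """
    gray = np.asarray(img.convert('L'))
    mask = gray > threshold
    if not mask.any():
        # Completely dark image: nothing to crop, just resize.
        return img.resize((size, size), Image.BICUBIC)
    rows = np.where(mask.any(axis=1))[0]
    cols = np.where(mask.any(axis=0))[0]
    box = (int(cols[0]), int(rows[0]), int(cols[-1]) + 1, int(rows[-1]) + 1)
    cropped = img.crop(box)
    # Pad to a square so resizing does not distort the aspect ratio.
    w, h = cropped.size
    side = max(w, h)
    canvas = Image.new(img.mode, (side, side))  # filled with black (0)
    canvas.paste(cropped, ((side - w) // 2, (side - h) // 2))
    return canvas.resize((size, size), Image.BICUBIC)
```

For example, `crop_black_border(Image.open('10_left.jpeg')).save(...)` would process one image; the threshold may need tuning for dark or overexposed photographs.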
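The lesion-prediction pickle ('data_index') is a dict of splits mapping each 'id_eyeSide.jpeg' file name to a list of (x1, y1, x2, y2) boxes. The helpers below are hypothetical and only illustrate reading that structure; they are not the repository's data pipeline. They assume the boxes are pixel coordinates in the preprocessed 512 x 512 images.

```python
import pickle


def load_predictions(path):
    """Load the lesion-prediction pickle: {'train': {...}, 'val': {...}}."""
    with open(path, 'rb') as f:
        return pickle.load(f)


def parse_name(filename):
    """Split 'id_eyeSide.jpeg' into (id, eye_side)."""
    stem = filename.rsplit('.', 1)[0]
    image_id, eye_side = stem.rsplit('_', 1)
    return image_id, eye_side


def lesion_patches(image, boxes):
    """Crop every predicted (x1, y1, x2, y2) box out of a PIL image."""
    return [image.crop((int(x1), int(y1), int(x2), int(y2)))
            for x1, y1, x2, y2 in boxes]


def summarize(predictions):
    """Return {split: (number of images, total number of predicted boxes)}."""
    return {split: (len(items), sum(len(boxes) for boxes in items.values()))
            for split, items in predictions.items()}
```

Running `summarize(load_predictions('/path/to/predictions.pkl'))` is a quick way to check how many images in each split have no predicted lesions before training.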
The 2D image classification framework pytorch-classification is used to perform linear evaluation and transfer capacity evaluation. Please follow the instructions in that repository for evaluation. The model fine-tuned on the full training set (kappa of 0.8322 on the test set) can be downloaded here. The training configurations can be found in our other paper.
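The kappa reported above is, for EyePACS, conventionally the quadratic weighted Cohen's kappa; the text does not name the variant, so that is an assumption here. A minimal NumPy sketch of the metric for integer grades 0 to 4:

```python
import numpy as np


def quadratic_weighted_kappa(y_true, y_pred, num_classes=5):
    """Quadratic weighted Cohen's kappa for integer grades 0..num_classes-1."""
    y_true = np.asarray(y_true, dtype=int)
    y_pred = np.asarray(y_pred, dtype=int)
    # Observed matrix O[i, j]: number of samples with true grade i, predicted j.
    observed = np.zeros((num_classes, num_classes), dtype=float)
    np.add.at(observed, (y_true, y_pred), 1)
    # Expected matrix under independence: outer product of the two histograms.
    expected = np.outer(observed.sum(axis=1), observed.sum(axis=0)) / observed.sum()
    # Quadratic disagreement weights w[i, j] = (i - j)^2 / (K - 1)^2.
    idx = np.arange(num_classes)
    weights = (idx[:, None] - idx[None, :]) ** 2 / (num_classes - 1) ** 2
    return 1.0 - (weights * observed).sum() / (weights * expected).sum()
```

It should agree with `sklearn.metrics.cohen_kappa_score(y_true, y_pred, weights='quadratic')`; perfect agreement gives 1.0 and a fully reversed grading of a uniform set gives -1.0.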
The model fine-tuned on the full EyePACS dataset and the models trained by lesion-based contrastive learning are provided here. TorchVision is used to build the models, so please install the package versions listed in requirements.txt to avoid unexpected errors.
To load the fine-tuned model:
import torch
import torch.nn as nn
from torchvision import models
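# map_location='cpu' allows the checkpoint to be loaded on machines without a GPU.
weights = torch.load('resnet50_128_08_100.pt', map_location='cpu')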
model = models.resnet50()
# Our model outputs the score of DR for classification. See https://arxiv.org/pdf/2110.14160.pdf for more details.
model.fc = nn.Linear(model.fc.in_features, 1)
model.load_state_dict(weights, strict=True)
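Because the fine-tuned model outputs a single continuous DR score, predictions have to be mapped back to grades 0 to 4. The sketch below assumes the score lies on the grade scale and uses hypothetical cut points at 0.5, 1.5, 2.5 and 3.5 (rounding to the nearest grade); see the paper linked in the code comment for the scheme actually used.

```python
import numpy as np

# Hypothetical cut points: assumes the regression score is on the 0-4 grade
# scale and that each score is assigned to the nearest grade.
DEFAULT_THRESHOLDS = (0.5, 1.5, 2.5, 3.5)


def score_to_grade(scores, thresholds=DEFAULT_THRESHOLDS):
    """Map continuous DR scores to integer grades 0..len(thresholds)."""
    scores = np.asarray(scores, dtype=float)
    # np.digitize returns i such that thresholds[i-1] <= score < thresholds[i],
    # so scores below 0.5 map to 0 and scores >= 3.5 map to 4.
    return np.digitize(scores, thresholds)
```

With the model above, the scores for a batch `x` could be obtained via `model.eval()` and `model(x).squeeze(1).detach().cpu().numpy()` before calling `score_to_grade`.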
To load the models trained by lesion-based contrastive learning:
import torch
from torchvision import models
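# map_location='cpu' allows the checkpoint to be loaded on machines without a GPU.
weights = torch.load('resnet50_128_08.pt', map_location='cpu')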
model = models.resnet50()
# The weights of the fully connected layer are not included in the file, so strict is set to False.
model.load_state_dict(weights, strict=False)
- Update the configurations for training the DR grading network.
- The lesion-based contrastive learning model trained on the full EyePACS dataset.
- Build dataset using better lesion detection models.
Thanks to SupContrast for the implementation of the contrastive loss, and to the Kaggle team o_O for the fundus image preprocessing code.
@inproceedings{huang2021lesion,
title={Lesion-Based Contrastive Learning for Diabetic Retinopathy Grading from Fundus Images},
author={Huang, Yijin and Lin, Li and Cheng, Pujin and Lyu, Junyan and Tang, Xiaoying},
booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
pages={113--123},
year={2021},
organization={Springer}
}