This is an unofficial implementation of the paper "An Interpretable and Accurate Deep-Learning Diagnosis Framework Modeled With Fully and Semi-Supervised Reciprocal Learning" (IEEE TMI 2023). The paper presents a deep interpretable framework for diagnosing diseases from medical images, e.g., breast cancer in mammograms, retinopathy in optical coherence tomography (OCT) images, and brain tumors in magnetic resonance (MR) images.
The method integrates an interpretable prototype-based classifier (ProtoPNet) with an existing deep global image classifier (GlobalNet), as shown in (a) above. The two branches are optimised with a two-stage reciprocal student-teacher learning paradigm: (b) the student ProtoPNet learns from suitable pseudo labels produced by the teacher GlobalNet, while the GlobalNet is trained on the ProtoPNet’s performance feedback; and (c) the teacher GlobalNet is further trained using the pseudo labels produced by the accurate student ProtoPNet.
The interpretable ProtoPNet branch, as shown above, classifies images by calculating similarities with class-specific image prototypes that are learned from training data.
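The prototype-based classification described above can be illustrated with a minimal, dependency-free sketch. It assumes the log-ratio similarity from the original ProtoPNet paper, `log((d + 1) / (d + eps))`, applied to the smallest squared L2 distance between a prototype and any image patch feature. The helper names (`prototype_similarity`, `class_scores`), the toy feature vectors, and the plain per-class sum are illustrative assumptions. In ProtoPNet the class scores come from a learned fully connected layer over the similarities, and the code in this repository may differ.

```python
import math

EPSILON = 1e-4  # small constant that keeps the log-ratio finite at distance 0


def squared_l2(a, b):
    """Squared Euclidean distance between two equal-length feature vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def prototype_similarity(patches, prototype):
    """Similarity between an image and one prototype.

    The closest patch decides the score: a small distance gives a large
    similarity, a large distance gives a similarity near zero.
    """
    d_min = min(squared_l2(patch, prototype) for patch in patches)
    return math.log((d_min + 1.0) / (d_min + EPSILON))


def class_scores(patches, prototypes_by_class):
    """Sum the similarities of each class's prototypes (a simplification:
    ProtoPNet uses a learned linear layer here instead of a plain sum)."""
    return {
        label: sum(prototype_similarity(patches, p) for p in protos)
        for label, protos in prototypes_by_class.items()
    }


# Toy example: three 2-D patch features and one prototype per class.
patches = [[0.0, 0.0], [1.0, 1.0], [0.9, 0.1]]
prototypes_by_class = {
    "cancer": [[1.0, 1.0]],      # coincides with one of the patches
    "non_cancer": [[5.0, 5.0]],  # far from every patch
}

scores = class_scores(patches, prototypes_by_class)
predicted = max(scores, key=scores.get)
print(predicted, scores)
```

In this toy case the image is assigned to the class whose prototype lies closest to one of its patches, which mirrors the "this looks like that" reasoning of the ProtoPNet branch.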
We re-implement the paper with PyTorch 1.9.1+cu111 on 2 NVIDIA A40 GPUs, and our trained model weights (with EfficientNet-B0 as the backbone) are provided alongside the code. These models were trained on our private mammogram database (about 67,000 mammography images).
- Pre-training the GlobalNet branch:
python pretrain_globalnet.py
- Training the ProtoPNet branch, using data samples optimally pseudo-labeled by the teacher GlobalNet:
python train_protopnet.py
- Retraining the GlobalNet branch, using pseudo labels produced by the accurate ProtoPNet:
python retrain_globalnet.py
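The three stages above must run in order, since each one depends on the output of the previous one. A small wrapper such as the following could chain them and stop at the first failure. This is only a sketch: the script names are taken from the commands above, while `build_commands` and `run_all` are hypothetical helpers that are not part of this repository. It assumes the scripts need no required command-line arguments and are launched from the repository root.

```python
import subprocess
import sys

# Training stages in the order listed above.
STAGES = [
    "pretrain_globalnet.py",   # 1. pre-train the teacher GlobalNet
    "train_protopnet.py",      # 2. train the student ProtoPNet
    "retrain_globalnet.py",    # 3. retrain GlobalNet with ProtoPNet pseudo labels
]


def build_commands(python=sys.executable):
    """Return one argument list per stage, ready for subprocess.run."""
    return [[python, script] for script in STAGES]


def run_all(dry_run=True):
    """Print each command; execute it too when dry_run is False.

    check=True makes subprocess.run raise CalledProcessError if a stage
    exits with a non-zero status, so later stages are not started.
    """
    for cmd in build_commands():
        print(" ".join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)


# Dry run: only prints the commands that would be executed.
run_all(dry_run=True)
```

Calling `run_all(dry_run=False)` would actually launch the three scripts one after another.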
The cancer localisation heatmap can be computed using:
python test_cancer_loc_heatmap.py
The authors give visualised prototypes in their paper, together with the corresponding source training images and self-activated similarity maps.
As illustrated by the authors, the method classifies a test mammogram as cancer when the abnormality present in the image looks more similar to the cancer prototypes than to the non-cancer ones, as evidenced by higher similarity scores with the cancer prototypes.
Please remember to respect the authors and cite their work properly:
@article{wang2023interpretable,
title={An Interpretable and Accurate Deep-learning Diagnosis Framework Modelled with Fully and Semi-supervised Reciprocal Learning},
author={Wang, Chong and Chen, Yuanhong and Liu, Fengbei and Elliott, Michael and Kwok, Chun Fung and Pe{\~n}a-Solorzano, Carlos and Frazer, Helen and McCarthy, Davis James and Carneiro, Gustavo},
journal={IEEE Transactions on Medical Imaging},
year={2023},
publisher={IEEE}
}
This implementation builds on the following repository:
Prototypical part network (ProtoPNet)
I'm happy for you to contribute to and improve the code; any suggestions are welcome.