The long-tailed data distribution problem in machine learning causes model bias: most of the data is concentrated in a few head classes, leading to misclassifications in the minority classes. In medical datasets this problem is particularly challenging, as misclassifications can have serious consequences. To address it, we propose Balanced Medical Net (BMN), a novel approach that balances supervised contrastive learning with class averaging and class complement to alleviate the long-tailed distribution problem in medical datasets. In BMN, class averaging takes the average of the instances of each class in a minibatch, reducing the contribution of head classes and emphasizing the importance of tail classes, while class complements ensure that all classes are represented in the minibatch. We evaluate our approach on two long-tailed medical datasets, ISIC2018 and APTOS2019, and find that it outperforms or matches state-of-the-art methods in classification accuracy and F1-score. Our method has the potential to improve diagnosis and treatment for patients with possibly fatal diseases and addresses an important issue in medical datasets.
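The class-averaging idea can be sketched in a few lines of PyTorch. This is an illustrative snippet, not the actual BMN/BCL implementation; the function name and tensor shapes are our own for the example:

```python
import torch
import torch.nn.functional as F

def class_averaged_features(features, labels):
    """Average the (normalized) features of each class present in the
    minibatch, so a head class with many instances contributes a single
    prototype, just like a tail class."""
    classes = labels.unique()  # sorted unique class ids in the batch
    protos = torch.stack([features[labels == c].mean(dim=0) for c in classes])
    return F.normalize(protos, dim=1), classes

# toy batch: 3 samples of class 0 (head), 2 of class 1, 1 of class 2 (tail)
features = F.normalize(torch.randn(6, 128), dim=1)
labels = torch.tensor([0, 0, 0, 1, 1, 2])
protos, classes = class_averaged_features(features, labels)
print(protos.shape)  # torch.Size([3, 128]) -- one prototype per class
```

In a balanced contrastive loss, these per-class prototypes then replace the raw per-instance contributions, so each class weighs equally in the minibatch regardless of how many samples it has.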
This is a PyTorch implementation of our project.
We adopt the codebase of BCL.
Weights for best models can be found here.
Download the ISIC 2018 challenge data related to the 3rd task (Training Data, Training Ground Truth, Validation Data, and Validation Ground Truth). The text files in `data/ISIC2018` are the processed labels, ready for the Dataset class to read.
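The exact format of the label files is defined in `dataset/dataset.py`. As an illustration only, assuming each line holds an image filename followed by an integer class label, a minimal map-style dataset could look like this (the class name `LTDataset` is our own for the sketch):

```python
import os
from PIL import Image
from torch.utils.data import Dataset

class LTDataset(Dataset):
    """Minimal sketch: reads a txt file whose lines are '<image> <label>'."""
    def __init__(self, root, txt, transform=None):
        self.root, self.transform = root, transform
        self.samples = []
        with open(txt) as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue  # skip blank lines
                path, label = line.split()
                self.samples.append((path, int(label)))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        img = Image.open(os.path.join(self.root, path)).convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        return img, label
```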
Download the train images and the `train.csv` file from the APTOS 2019 Kaggle competition. The text files that define the train/val split we adopted can be found in `data/APTOS2019`.
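The provided split files can be used directly. If you want to regenerate a per-class (stratified) train/val split from `train.csv` yourself, the idea is simply to split each class independently so the validation set keeps the long-tailed class proportions. A stdlib-only sketch (the function name and `val_frac` are our own; the column names follow the Kaggle `train.csv`, which has `id_code` and `diagnosis`):

```python
import csv
import random
from collections import defaultdict

def stratified_split(rows, val_frac=0.2, seed=0):
    """Split (image_id, label) pairs per class, so every class --
    including the tail ones -- appears in both train and val."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for img_id, label in rows:
        by_class[label].append(img_id)
    train, val = [], []
    for label, ids in by_class.items():
        rng.shuffle(ids)
        n_val = max(1, int(len(ids) * val_frac))  # at least 1 val sample per class
        val += [(i, label) for i in ids[:n_val]]
        train += [(i, label) for i in ids[n_val:]]
    return train, val

# usage with the Kaggle csv (columns: id_code, diagnosis):
# with open("train.csv") as f:
#     rows = [(r["id_code"], int(r["diagnosis"])) for r in csv.DictReader(f)]
# train, val = stratified_split(rows)
```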
Run the following command to install dependencies before running the code: `pip install -r requirements.txt`
- `dataset/`
  - `dataset.py`: defines the dataset class used for both datasets
- `loss/`
  - `contrastive.py`: includes the definitions of the `SCL` (supervised contrastive loss) and `BalSCL` (balanced contrastive loss) classes
  - `logitadjust.py`: includes the definitions of the losses used in the classification branch, e.g. `LogitAdjust`, `FocalLoss`, `EQLv2`, and `LabelSmoothingCrossEntropy`
- `models/`
  - `resnext.py`: defines the `BCLModel`, as well as the `ResNet` and `ResNeXt` backbones
  - `randaugment.py`: includes the implementations of `AutoAugment` and `RandAugment`
- `main.py`: main file that is run for training
- `utils.py`: includes the definitions of some utility functions
- `train-isic.sh`: includes the command to run training on the ISIC dataset with all the arguments that match our best experiment. The arguments `data`, `val_data`, `txt`, and `val_txt` are the paths to the training images, validation images, training labels, and validation labels; you need to specify these directories. You also need to specify `user_name`, the wandb username where the experiments will be logged.
- `train-aptos.sh`: includes the equivalent command to run training on the APTOS dataset with all the arguments that match our best experiment; the same arguments (`data`, `val_data`, `txt`, `val_txt`, and `user_name`) need to be specified.
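The class-complement idea from the abstract, making sure every class is represented in each minibatch, can be sketched as a step that tops up a randomly drawn batch with one example from every missing class. This is an illustrative stdlib-only sketch, not the actual BMN sampler; the function name is our own:

```python
import random
from collections import defaultdict

def complement_batch(batch_indices, labels, rng=random):
    """Given a randomly drawn batch (as dataset indices) and the full label
    list, append one randomly chosen index for every class that is absent,
    so all classes end up represented (the 'class complement')."""
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    present = {labels[i] for i in batch_indices}
    extra = [rng.choice(by_class[c]) for c in sorted(by_class) if c not in present]
    return list(batch_indices) + extra

# toy long-tailed labels: 8 samples of class 0, 3 of class 1, 1 of class 2
labels = [0] * 8 + [1] * 3 + [2]
batch = [0, 1, 8]                       # random batch covering classes {0, 1}
full_batch = complement_batch(batch, labels)
# full_batch now also contains an index of the missing tail class 2
```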