Balanced Contrastive Representation Learning for Long-tailed Medical Image Classification

The long-tailed data distribution problem in machine learning biases models: most of the data is concentrated in a few head classes, leading to misclassification of the minority (tail) classes. In medical datasets this problem is particularly challenging, as misclassifications can have serious consequences. To address it, we propose Balanced Medical Net (BMN), a novel approach that balances supervised contrastive learning using class averaging and class complements to alleviate the long-tailed distribution problem in medical datasets. In BMN, class averaging averages the instances of each class in a minibatch, reducing the contribution of head classes and emphasizing the tail classes, while class complements ensure that every class is represented in each minibatch. We evaluate our approach on two long-tailed medical datasets, ISIC2018 and APTOS2019, and find that it outperforms or matches state-of-the-art methods in classification accuracy and F1-score. Our method has the potential to improve diagnosis and treatment for patients with possibly fatal diseases, and it addresses an important issue in medical datasets.
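To make the class-averaging idea concrete, here is a minimal sketch of how a per-anchor contrastive denominator can be built so that each class contributes one averaged term regardless of how many samples it has in the batch. This is an illustrative simplification, not the BalSCL implementation in loss/contrastive.py; the function name and temperature value are assumptions.

```python
import torch
import torch.nn.functional as F

def class_averaged_denominator(features, labels, temperature=0.1):
    """Illustrative sketch of class averaging: for each anchor, the
    exp-similarities of samples belonging to the same class are averaged,
    so a head class with many batch samples contributes one term to the
    contrastive denominator, just like a tail class with a single sample."""
    features = F.normalize(features, dim=1)          # cosine similarities
    exp_sim = torch.exp(features @ features.t() / temperature)  # (B, B)
    denom = torch.zeros(features.size(0))
    for c in labels.unique():
        mask = labels == c
        # Average within class c instead of summing over all its samples.
        denom += exp_sim[:, mask].mean(dim=1)
    return denom
```

Without the per-class mean, a head class with 100 batch samples would dominate the denominator 100-to-1 over a tail class with one sample; with it, both contribute a single averaged term.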

This repository is a PyTorch implementation of our project.

We adopt the codebase of BCL.

Weights for best models can be found here.

Datasets

ISIC2018: Download the data for Task 3 of the ISIC 2018 challenge (Training Data, Training Ground Truth, Validation Data, and Validation Ground Truth). The text files in data/ISIC2018 are the processed labels, ready for the Dataset class to read.

APTOS2019: Download the train images and the train.csv file from the Kaggle competition. The text files containing the train/val split we adopted can be found in data/APTOS2019.
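The label text files are consumed by the Dataset class defined in dataset/dataset.py. As a rough sketch of what such a reader might look like, assuming each line holds an image filename followed by an integer class label (the class name and exact file format here are hypothetical; consult dataset/dataset.py for the real definitions):

```python
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset

class ImageListDataset(Dataset):
    """Hypothetical reader for 'filename label' text files; the actual
    class lives in dataset/dataset.py and its format may differ."""

    def __init__(self, txt_file, img_root, transform=None):
        self.samples = []
        with open(txt_file) as f:
            for line in f:
                path, label = line.strip().rsplit(" ", 1)
                self.samples.append((path, int(label)))
        self.img_root = Path(img_root)
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        img = Image.open(self.img_root / path).convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        return img, label
```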

Dependencies

Install the dependencies before running the code: pip install -r requirements.txt

Code Structure

  • dataset/:
    • dataset.py: defines the dataset class used for both datasets
  • loss/:
    • contrastive.py: includes the definition of the SCL (supervised contrastive loss) and BalSCL (balanced contrastive loss) classes
    • logitadjust.py: includes the definition of the losses used in the classification branch, e.g. LogitAdjust, FocalLoss, EQLv2, and LabelSmoothingCrossEntropy
  • models/:
    • resnext.py: defines the BCLModel, as well as the ResNet and ResNeXt backbones
  • randaugment.py: includes the implementation of AutoAugment and RandAugment
  • main.py: main entry point that is run for training
  • utils.py: includes the definitions of some utility functions
  • train-isic.sh: contains the training command for the ISIC dataset with all the arguments matching our best experiment. The arguments data, val_data, txt, and val_txt are the paths to the training images, validation images, training labels, and validation labels, respectively; you need to set these directories. You also need to set user_name, the wandb username under which the experiments will be logged.
  • train-aptos.sh: contains the training command for the APTOS dataset with all the arguments matching our best experiment. The arguments data, val_data, txt, and val_txt are the paths to the training images, validation images, training labels, and validation labels, respectively; you need to set these directories. You also need to set user_name, the wandb username under which the experiments will be logged.
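As a hedged illustration of what a training invocation might look like once the paths are filled in, here is a hypothetical example. The flag names follow the argument names mentioned above (data, val_data, txt, val_txt, user_name); every path and filename below is a placeholder, and the authoritative command with all remaining flags lives in train-isic.sh and train-aptos.sh.

```shell
# Hypothetical example; see train-isic.sh for the full, authoritative command.
python main.py \
  --data /path/to/ISIC2018/train_images \
  --val_data /path/to/ISIC2018/val_images \
  --txt data/ISIC2018/<train_labels>.txt \
  --val_txt data/ISIC2018/<val_labels>.txt \
  --user_name <your_wandb_username>
```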