NIH-Chest-X-rays-Multi-Label-Image-Classification-In-Pytorch

Multi-Label Image Classification of Chest X-Rays in PyTorch

Requirements

  • torch >= 0.4
  • torchvision >= 0.2.2
  • opencv-python
  • numpy >= 1.7.3
  • matplotlib
  • tqdm

Dataset

The NIH Chest X-ray Dataset is used for multi-label disease classification of chest X-rays. There are 15 classes in total (14 diseases and one for "No findings"). Each image is labeled either "No findings" or with one or more of the following disease classes:

  • Atelectasis
  • Consolidation
  • Infiltration
  • Pneumothorax
  • Edema
  • Emphysema
  • Fibrosis
  • Effusion
  • Pneumonia
  • Pleural_thickening
  • Cardiomegaly
  • Nodule
  • Mass
  • Hernia
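Because an image can carry several labels at once, each target is naturally a 15-element 0/1 vector rather than a single class index. The sketch below shows one way to build such a vector from a "|"-separated label string like the ones in the sample images. The class order and the `encode_labels` helper are assumptions made for illustration; the repository may order or encode classes differently.

```python
# Class order is an assumption for illustration only.
CLASSES = [
    "Atelectasis", "Consolidation", "Infiltration", "Pneumothorax",
    "Edema", "Emphysema", "Fibrosis", "Effusion", "Pneumonia",
    "Pleural_thickening", "Cardiomegaly", "Nodule", "Mass", "Hernia",
    "No Finding",
]
# Lower-cased lookup so "Pleural_Thickening" and "Pleural_thickening" both match.
CLASS_TO_INDEX = {name.lower(): i for i, name in enumerate(CLASSES)}


def encode_labels(label_string):
    """Turn a '|'-separated label string into a 15-element 0/1 target vector."""
    target = [0.0] * len(CLASSES)
    for raw in label_string.split("|"):
        name = raw.strip().lower()
        if name:
            target[CLASS_TO_INDEX[name]] = 1.0
    return target


print(encode_labels("Cardiomegaly | Edema | Effusion"))
```

An unknown label raises a `KeyError`, which is usually preferable to silently dropping it.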

There are 112,120 X-ray images of size 1024x1024 pixels, of which 86,524 are for training and 25,596 are for testing.

Sample X-Ray Images

Atelectasis
Cardiomegaly | Edema | Effusion
No Finding

Model

A pretrained ResNet50 model is used for transfer learning on this dataset.

Loss Function

Either of two loss functions can be used:

  • Focal Loss (default)
  • Binary Cross Entropy Loss or BCE Loss
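To illustrate how the two options differ, the sketch below computes both losses for a single predicted probability in plain Python. Focal loss multiplies the BCE term by (1 - p_t)^gamma, so confidently correct predictions contribute very little. The values `gamma=2.0` and `alpha=1.0` and the single scalar `alpha` are assumptions for illustration; the repository's FocalLoss may use different settings.

```python
import math


def bce_loss(prob, target):
    """Binary cross entropy for one predicted probability and a 0/1 target."""
    p_t = prob if target == 1 else 1.0 - prob  # probability given to the true label
    return -math.log(p_t)


def focal_loss(prob, target, gamma=2.0, alpha=1.0):
    """Focal loss: BCE scaled by (1 - p_t) ** gamma so easy examples count less."""
    p_t = prob if target == 1 else 1.0 - prob
    return -alpha * (1.0 - p_t) ** gamma * math.log(p_t)


easy = (0.95, 1)  # confident and correct
hard = (0.10, 1)  # confident and wrong
for name, (p, y) in (("easy", easy), ("hard", hard)):
    print(name, round(bce_loss(p, y), 4), round(focal_loss(p, y), 4))
```

With `gamma=0` and `alpha=1` the focal loss reduces to BCE, which is a convenient sanity check when switching between the two.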

Training

  • From Scratch

    The following layers are set to trainable:

    • layer2
    • layer3
    • layer4
    • fc

    Terminal Code:

    python main.py
    
  • Resuming From a Saved Checkpoint

    A saved checkpoint needs to be loaded. It is a dictionary containing:

    • epochs (number of epochs the model has been trained for so far)

    • model (architecture and the learnt weights of the model)

    • lr_scheduler_state_dict (state_dict of the lr_scheduler)

    • losses_dict (a dictionary containing the following losses)

      • mean train epoch losses for all the epochs
      • mean val epoch losses for all the epochs
      • batch train loss for all the training batches
      • batch val loss for all the val batches
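Before resuming, it can help to check that a loaded checkpoint has the expected shape. The sketch below validates the four top-level keys listed above and summarizes the loss history. The `describe_checkpoint` helper, the key names inside `losses_dict`, and all numeric values are made up for illustration; only the top-level key names come from this README.

```python
REQUIRED_KEYS = ("epochs", "model", "lr_scheduler_state_dict", "losses_dict")


def describe_checkpoint(ckpt):
    """Check the top-level keys listed above and return a short summary."""
    missing = [key for key in REQUIRED_KEYS if key not in ckpt]
    if missing:
        raise KeyError("checkpoint is missing: " + ", ".join(missing))
    losses = ckpt["losses_dict"]
    return {
        "epochs_trained": ckpt["epochs"],
        "loss_series": {name: len(values) for name, values in losses.items()},
    }


# Stand-in values; a real checkpoint would hold the model and scheduler state.
example = {
    "epochs": 2,
    "model": None,
    "lr_scheduler_state_dict": {},
    "losses_dict": {  # hypothetical key names and placeholder numbers
        "epoch_train_loss": [0.21, 0.18],
        "epoch_val_loss": [0.23, 0.20],
        "batch_train_loss": [0.30, 0.25, 0.22, 0.19],
        "batch_val_loss": [0.28, 0.24],
    },
}
print(describe_checkpoint(example))
```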

Different layers of the model are frozen or unfrozen in different stages, defined at the end of this README, to fit the model well to the data. The stage can be passed from the terminal using the argument --stage STAGE.

Terminal Code:

python main.py --resume --ckpt checkpoint_file.pth --stage 2

Training the model creates a models directory and saves the checkpoints there.
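The stage-dependent freezing can be sketched as a lookup from stage number to trainable blocks, using the stage definitions listed in the Result section. The helper names `is_trainable` and `apply_stage` are hypothetical, and the sketch assumes parameter names that begin with the block name (as in torchvision's ResNet, e.g. `layer3.0.conv1.weight`); the repository may implement this differently.

```python
# Trainable top-level blocks per stage, copied from the stage list in this README.
STAGE_LAYERS = {
    1: ("layer2", "layer3", "layer4", "fc"),
    2: ("layer3", "layer4", "fc"),
    3: ("layer4", "fc"),
    4: ("fc",),
}


def is_trainable(param_name, stage):
    """True if the parameter's top-level block is trained in this stage."""
    return param_name.split(".")[0] in STAGE_LAYERS[stage]


def apply_stage(model, stage):
    """Freeze or unfreeze every parameter of a torch.nn.Module in place."""
    for name, param in model.named_parameters():
        param.requires_grad = is_trainable(name, stage)


print(is_trainable("layer3.0.conv1.weight", 2))  # True
print(is_trainable("layer2.0.conv1.weight", 2))  # False
```

After calling `apply_stage`, only parameters with `requires_grad` set to True should be handed to the optimizer.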

Testing

A saved checkpoint needs to be loaded using the --ckpt argument, and the --test argument needs to be passed to activate test mode.

Terminal Code:

python main.py --test --ckpt checkpoint_file.pth
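For orientation, the flags used in the commands in this README (--resume, --test, --ckpt, --stage) could be parsed along the following lines. This is a minimal sketch, not the repository's main.py: the defaults, help strings, and the check that --ckpt is present are assumptions.

```python
import argparse


def build_parser():
    parser = argparse.ArgumentParser(description="Train or test the chest X-ray classifier")
    parser.add_argument("--resume", action="store_true", help="continue training from --ckpt")
    parser.add_argument("--test", action="store_true", help="run in test mode using --ckpt")
    parser.add_argument("--ckpt", type=str, default=None, help="path to a saved checkpoint")
    parser.add_argument("--stage", type=int, default=1, help="training stage (1 to 4)")
    return parser


def parse_args(argv=None):
    parser = build_parser()
    args = parser.parse_args(argv)
    # Both resume and test modes read a checkpoint, so require one up front.
    if (args.resume or args.test) and args.ckpt is None:
        parser.error("--ckpt is required with --resume or --test")
    return args


args = parse_args(["--resume", "--ckpt", "checkpoint_file.pth", "--stage", "2"])
print(args.resume, args.test, args.ckpt, args.stage)
```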

Result

The model achieved an average ROC AUC score of 0.73241 over all classes (excluding the "No findings" class) after training in the following stages:

STAGE 1

  • Loss Function: FocalLoss
  • lr: 1e-5
  • Training Layers: layer2, layer3, layer4, fc
  • Epochs: 2

STAGE 2

  • Loss Function: FocalLoss
  • lr: 3e-4
  • Training Layers: layer3, layer4, fc
  • Epochs: 1

STAGE 3

  • Loss Function: FocalLoss
  • lr: 1e-3
  • Training Layers: layer4, fc
  • Epochs: 3

STAGE 4

  • Loss Function: FocalLoss
  • lr: 1e-3
  • Training Layers: fc
  • Epochs: 2
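The average ROC AUC reported in this section is a per-class AUC averaged over the disease classes. A minimal pure-Python sketch of that computation is shown below, using the rank interpretation of AUC (the chance that a positive example scores higher than a negative one). The function names and the toy data are made up for illustration; the repository may rely on a library routine instead.

```python
def roc_auc(y_true, y_score):
    """ROC AUC as the chance that a positive outranks a negative (ties count half)."""
    pos = [s for t, s in zip(y_true, y_score) if t == 1]
    neg = [s for t, s in zip(y_true, y_score) if t == 0]
    if not pos or not neg:
        raise ValueError("ROC AUC needs at least one positive and one negative")
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def mean_auc(targets, scores, skip_columns=()):
    """Average per-class AUC over columns, skipping the given column indices."""
    num_classes = len(targets[0])
    aucs = [
        roc_auc([row[c] for row in targets], [row[c] for row in scores])
        for c in range(num_classes)
        if c not in skip_columns
    ]
    return sum(aucs) / len(aucs)


# Toy data: 4 images, 3 columns; the last column stands in for "No findings".
targets = [[1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1]]
scores = [[0.9, 0.2, 0.1], [0.3, 0.8, 0.2], [0.6, 0.4, 0.1], [0.1, 0.5, 0.7]]
print(mean_auc(targets, scores, skip_columns={2}))  # 0.875
```

The pairwise loop is quadratic in the number of images, which is fine for a sketch but too slow for the full test set; a sort-based implementation would be used in practice.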