Multi-Label Image Classification of Chest X-Rays in PyTorch
- torch >= 0.4
- torchvision >= 0.2.2
- opencv-python
- numpy >= 1.7.3
- matplotlib
- tqdm
The NIH Chest X-ray Dataset is used for multi-label disease classification of chest X-rays. There are 15 classes in total (14 diseases and one for "No findings"). An image is labeled either "No findings" or with one or more of the following disease classes:
- Atelectasis
- Consolidation
- Infiltration
- Pneumothorax
- Edema
- Emphysema
- Fibrosis
- Effusion
- Pneumonia
- Pleural_thickening
- Cardiomegaly
- Nodule
- Mass
- Hernia
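Because an image can carry several disease labels at once, each target is naturally a multi-hot vector with one slot per class. The sketch below shows one way to build such a vector. The class order, the `|` separator, and the helper name `encode_labels` are assumptions made for illustration and may differ from what this repository does.

```python
# Class names follow the list above (Nodule and Mass are separate classes).
# The ordering here is an assumption, not necessarily the repository's.
CLASSES = [
    "Atelectasis", "Consolidation", "Infiltration", "Pneumothorax",
    "Edema", "Emphysema", "Fibrosis", "Effusion", "Pneumonia",
    "Pleural_thickening", "Cardiomegaly", "Nodule", "Mass", "Hernia",
    "No findings",
]


def encode_labels(label_string, classes=CLASSES, sep="|"):
    """Turn a separator-joined label string into a multi-hot list of 0/1."""
    present = {name.strip() for name in label_string.split(sep)}
    unknown = present - set(classes)
    if unknown:
        raise ValueError(f"Unknown labels: {sorted(unknown)}")
    return [1 if name in present else 0 for name in classes]


# An image with two findings gets two active slots.
target = encode_labels("Effusion|Cardiomegaly")
```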
There are 112,120 X-ray images of size 1024x1024 pixels, of which 86,524 are used for training and 25,596 for testing.
A pretrained ResNet50 model is used for transfer learning on this new image dataset.
Two loss functions are available:
- Focal Loss (default)
- Binary Cross Entropy Loss (BCE Loss)
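To make the difference between the two losses concrete, here is a minimal sketch of a multi-label focal loss built on top of BCE. It is not taken from this repository: the class name `MultiLabelFocalLoss` and the `alpha` and `gamma` defaults are assumptions chosen for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiLabelFocalLoss(nn.Module):
    """Hypothetical focal loss for multi-label logits (illustrative only)."""

    def __init__(self, alpha=1.0, gamma=2.0):
        super().__init__()
        self.alpha = alpha  # overall scaling factor (assumed default)
        self.gamma = gamma  # focusing parameter (assumed default)

    def forward(self, logits, targets):
        # Per-element BCE, computed from raw logits for numerical stability.
        bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
        # exp(-bce) is the probability the model assigns to the true label.
        p_t = torch.exp(-bce)
        # Down-weight easy examples (p_t close to 1) by (1 - p_t) ** gamma.
        loss = self.alpha * (1.0 - p_t) ** self.gamma * bce
        return loss.mean()
```

With `gamma=0` the modulating factor is 1 and the expression reduces to plain BCE, which is a convenient sanity check when switching between the two options.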
The following layers are set to trainable:
- layer2
- layer3
- layer4
- fc
Terminal Code:
python main.py
A saved checkpoint needs to be loaded. It is a dictionary containing:
- epochs (number of epochs the model has been trained so far)
- model (architecture and learned weights of the model)
- lr_scheduler_state_dict (state_dict of the lr_scheduler)
- losses_dict (a dictionary containing the following losses)
  - mean train epoch losses for all the epochs
  - mean val epoch losses for all the epochs
  - batch train loss for all the training batches
  - batch val loss for all the val batches
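As a rough illustration of that checkpoint layout, the sketch below builds such a dictionary, saves it, and loads it back. The four top-level keys come from the description above. The key names inside `losses_dict`, the stand-in model, the scheduler type, and the loss values are all placeholders invented for the example.

```python
import io

import torch
import torch.nn as nn

model = nn.Linear(4, 15)  # small stand-in for the ResNet50
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)

checkpoint = {
    "epochs": 2,
    "model": model,  # the whole module: architecture plus weights
    "lr_scheduler_state_dict": scheduler.state_dict(),
    "losses_dict": {
        # Hypothetical key names and dummy values, for illustration only.
        "epoch_train_loss": [0.5, 0.4],
        "epoch_val_loss": [0.6, 0.5],
        "batch_train_loss": [0.55, 0.45, 0.42, 0.38],
        "batch_val_loss": [0.62, 0.58, 0.52, 0.48],
    },
}

# An in-memory buffer stands in for a file such as checkpoint_file.pth.
buffer = io.BytesIO()
torch.save(checkpoint, buffer)
buffer.seek(0)

# Recent PyTorch versions need weights_only=False to unpickle a full module;
# only do this for checkpoints from a trusted source.
loaded = torch.load(buffer, weights_only=False)
scheduler.load_state_dict(loaded["lr_scheduler_state_dict"])
start_epoch = loaded["epochs"]
```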
Different layers of the model are frozen or unfrozen in different stages, defined at the end of this README.md file, to fit the model well on the data. The 'stage' parameter can be passed from the terminal using the argument --stage STAGE.
Terminal Code:
python main.py --resume --ckpt checkpoint_file.pth --stage 2
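One way to picture the stage mechanism is a lookup table from stage number to the layers that stay trainable and the learning rate. The sketch below uses the values listed at the end of this README and assumes the stages are numbered 1 to 4 in the order they are listed. The names `STAGES` and `get_stage_config` are hypothetical.

```python
# Stage table transcribed from the results listed at the end of this README.
# Numbering the stages 1-4 in listed order is an assumption.
STAGES = {
    1: {"lr": 1e-5, "layers": ("layer2", "layer3", "layer4", "fc")},
    2: {"lr": 3e-4, "layers": ("layer3", "layer4", "fc")},
    3: {"lr": 1e-3, "layers": ("layer4", "fc")},
    4: {"lr": 1e-3, "layers": ("fc",)},
}


def get_stage_config(stage):
    """Return the learning rate and trainable layers for a stage number."""
    if stage not in STAGES:
        raise ValueError(f"Unknown stage {stage}; expected one of {sorted(STAGES)}")
    return STAGES[stage]


config = get_stage_config(2)
```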
Training the model will create a models directory and will save the checkpoints in there.
A saved checkpoint needs to be loaded using the --ckpt argument, and the --test argument needs to be passed to activate test mode.
Terminal Code:
python main.py --test --ckpt checkpoint_file.pth
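The flags used in the commands above could be parsed along these lines. This is a sketch of a plausible command-line interface and not the actual contents of main.py: the flag names come from this README, while the defaults, types, and the check that --ckpt accompanies --resume or --test are assumptions.

```python
import argparse


def build_parser():
    parser = argparse.ArgumentParser(description="Chest X-ray multi-label training")
    parser.add_argument("--resume", action="store_true",
                        help="resume training from a saved checkpoint")
    parser.add_argument("--test", action="store_true",
                        help="run in test mode")
    parser.add_argument("--ckpt", type=str, default=None,
                        help="checkpoint file to load")
    parser.add_argument("--stage", type=int, default=1,
                        help="training stage (assumed default: 1)")
    return parser


def parse_args(argv=None):
    parser = build_parser()
    args = parser.parse_args(argv)
    # Both resuming and testing need a checkpoint to load.
    if (args.resume or args.test) and args.ckpt is None:
        parser.error("--ckpt is required with --resume or --test")
    return args


args = parse_args(["--resume", "--ckpt", "checkpoint_file.pth", "--stage", "2"])
```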
The model achieved an average ROC AUC score of 0.73241 over all classes (excluding the "No findings" class) after training in the following stages:
- Stage 1
  - Loss Function: FocalLoss
  - lr: 1e-5
  - Training Layers: layer2, layer3, layer4, fc
  - Epochs: 2
- Stage 2
  - Loss Function: FocalLoss
  - lr: 3e-4
  - Training Layers: layer3, layer4, fc
  - Epochs: 1
- Stage 3
  - Loss Function: FocalLoss
  - lr: 1e-3
  - Training Layers: layer4, fc
  - Epochs: 3
- Stage 4
  - Loss Function: FocalLoss
  - lr: 1e-3
  - Training Layers: fc
  - Epochs: 2
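The average ROC AUC reported above is a per-class score averaged over the disease classes. The sketch below shows how such an average can be computed while skipping "No findings". It uses a simple pairwise AUC written in plain Python for clarity. The function names and the toy data are made up for the example, and the repository may well compute the metric differently (for instance with scikit-learn's `roc_auc_score`).

```python
def roc_auc(y_true, y_score):
    """AUC as the fraction of (positive, negative) pairs ranked correctly."""
    pos = [s for t, s in zip(y_true, y_score) if t == 1]
    neg = [s for t, s in zip(y_true, y_score) if t == 0]
    if not pos or not neg:
        raise ValueError("ROC AUC needs at least one positive and one negative sample")
    # Ties between a positive and a negative score count as half a win.
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def mean_roc_auc(targets, scores, class_names, exclude=("No findings",)):
    """Average the per-class AUC over every class not listed in `exclude`."""
    aucs = {}
    for j, name in enumerate(class_names):
        if name in exclude:
            continue
        aucs[name] = roc_auc([row[j] for row in targets], [row[j] for row in scores])
    return sum(aucs.values()) / len(aucs), aucs


# Toy data: 4 images, 3 classes (made-up values, not model outputs).
class_names = ["Effusion", "Hernia", "No findings"]
targets = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 1, 0]]
scores = [[0.9, 0.2, 0.1], [0.3, 0.8, 0.2], [0.1, 0.4, 0.9], [0.6, 0.3, 0.1]]
mean_auc, per_class = mean_roc_auc(targets, scores, class_names)
```

The pairwise form is quadratic in the number of samples, so it is only suitable for small checks. A rank-based implementation is the better choice for the full test set.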