chest-X-rays-multiclass-densenet-attention

PyTorch code for multi-label classification of chest X-rays with DenseNet, ResNet and DenseNet-121 with an attention mechanism. Heatmap generation using Grad-CAM has also been implemented for the DenseNet-121 and DenseNet-121-with-attention models.

Primary language: Python

Multi-Label Classification and Visual Highlight of Chest X-ray Images using Neural Networks with Attention Mechanism and Grad-CAM

Marcus Hwai Yik Tan, Xiaohan Tian, Wing Chan, Joshua Ceaser
University of Illinois, Urbana-Champaign

Quick Start

  • Set ALL_IMAGE_DIR to the folder containing the X-ray images
  • Set BASE_PATH_LABELS to the folder containing the lists of training, validation and test image file names
  • Run any one of the following notebooks or scripts: t01-multilabel-main-test.ipynb, t01-multilabel-non_image_features-main-test.ipynb, t01-multilabel-main-val.ipynb, t01-multilabel-non_image_features-main-val.ipynb, t01-multilabel-main-test.py, t01-multilabel-non_image_features-main-test.py, t01-multilabel-main-val.py, t01-multilabel-non_image_features-main-val.py
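Before running anything, it can help to check that both settings point at usable locations. The following is a minimal sketch, assuming the two settings are plain path variables near the top of the notebooks or scripts; the example path values and the helper `missing_inputs` are hypothetical and not part of the repository. The list file names are the ones described in the Guide.

```python
from pathlib import Path
from typing import List

# Hypothetical example values: point these at your own copies of the data.
ALL_IMAGE_DIR = Path("images")
BASE_PATH_LABELS = Path("labels")

# List files named in the Guide: the "test" workflow uses train_val_A/test_A,
# the "val" workflow uses the three train_A_x/val_A_x splits.
LIST_FILES = {
    "test": ["train_val_A.csv", "test_A.csv"],
    "val": [f"{kind}_A_{x}.csv" for x in (1, 2, 3) for kind in ("train", "val")],
}


def missing_inputs(image_dir: Path, labels_dir: Path, mode: str) -> List[str]:
    """Return a list of problems; an empty list means the paths look usable."""
    problems = []
    if not image_dir.is_dir():
        problems.append(f"image folder not found: {image_dir}")
    for name in LIST_FILES[mode]:
        if not (labels_dir / name).is_file():
            problems.append(f"missing list file: {labels_dir / name}")
    return problems


if __name__ == "__main__":
    for problem in missing_inputs(ALL_IMAGE_DIR, BASE_PATH_LABELS, "test"):
        print(problem)
```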

Guide

For multi-label classification:

  • p02-dataset-selection-multilabel.ipynb: selects a subset of images for the training, validation and test lists and generates multiple training/validation splits. Its default folder is "labels", which also holds Data_Entry_2017_v2020.csv. This notebook can be skipped, since the lists of images selected for the final report are already included in the "labels" folder: train_val_A.csv, train_A_x.csv (x=1,2,3), val_A_x.csv (x=1,2,3) and test_A.csv.
  • t01-multilabel-main-val.ipynb: train and evaluate the model on multiple training/validation splits
  • t01-multilabel-main-val.py: Python version of t01-multilabel-main-val.ipynb
  • t01-multilabel-main-test.ipynb: train the model on the combined training+validation dataset and evaluate it on the test dataset
  • t01-multilabel-main-test.py: Python version of t01-multilabel-main-test.ipynb
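For context on how a multi-label target can be built from the metadata file, here is a sketch that turns the pipe-separated "Finding Labels" column of a Data_Entry-style CSV into a 0/1 vector. It assumes the CSV has "Image Index" and "Finding Labels" columns, as in the NIH ChestX-ray14 release, and uses the 14 ChestX-ray14 finding names; the label set and order used by the repository's notebooks may differ, and `multi_hot` and `read_targets` are hypothetical helpers.

```python
import csv
from typing import List, Sequence, TextIO, Tuple

# Assumption: the 14 ChestX-ray14 findings; the repository may use a subset or another order.
LABELS = (
    "Atelectasis", "Cardiomegaly", "Effusion", "Infiltration", "Mass",
    "Nodule", "Pneumonia", "Pneumothorax", "Consolidation", "Edema",
    "Emphysema", "Fibrosis", "Pleural_Thickening", "Hernia",
)


def multi_hot(finding_labels: str, labels: Sequence[str] = LABELS) -> List[int]:
    """Turn 'Effusion|Mass' into a 0/1 vector; 'No Finding' maps to all zeros."""
    found = set(finding_labels.split("|"))
    return [1 if name in found else 0 for name in labels]


def read_targets(csv_file: TextIO) -> List[Tuple[str, List[int]]]:
    """Return (image file name, multi-hot target) for each row of a Data_Entry-style CSV."""
    reader = csv.DictReader(csv_file)
    return [(row["Image Index"], multi_hot(row["Finding Labels"])) for row in reader]
```

A row labelled "No Finding" becomes an all-zero vector, which is why multi-label models of this kind typically use one sigmoid output per label rather than a softmax over mutually exclusive classes.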

For multi-label classification with non-image features:

  • p02-dataset-add_non_image_features.ipynb: append non-image features to the existing training, validation and test lists generated by p02-dataset-selection-multilabel.ipynb
  • t01-multilabel-non_image_features-main-val.ipynb: same as t01-multilabel-main-val.ipynb, but with non-image features as additional inputs
  • t01-multilabel-non_image_features-main-val.py: Python version of t01-multilabel-non_image_features-main-val.ipynb
  • t01-multilabel-non_image_features-main-test.ipynb: same as t01-multilabel-main-test.ipynb, but with non-image features as additional inputs
  • t01-multilabel-non_image_features-main-test.py: Python version of t01-multilabel-non_image_features-main-test.ipynb
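This README does not say which non-image features are appended or how the model combines them with the image. As an illustration only, the sketch below assumes the features are patient age, patient gender and view position from Data_Entry_2017_v2020.csv, and that they are concatenated with the image feature vector before the classifier (late fusion). `encode_non_image`, `fuse` and the age scaling are hypothetical.

```python
from typing import Dict, List, Sequence


def encode_non_image(row: Dict[str, str], max_age: float = 100.0) -> List[float]:
    """Assumed feature set: scaled age, gender flag and view-position flag."""
    age = float(row["Patient Age"]) / max_age             # scale age to roughly [0, 1]
    is_male = 1.0 if row["Patient Gender"] == "M" else 0.0
    is_pa = 1.0 if row["View Position"] == "PA" else 0.0  # PA vs. AP projection
    return [age, is_male, is_pa]


def fuse(image_features: Sequence[float], non_image: Sequence[float]) -> List[float]:
    """Late fusion by concatenation: the classifier sees both feature sets side by side."""
    return list(image_features) + list(non_image)
```

With DenseNet-121, whose globally pooled feature vector has 1024 entries, this kind of fusion would give the classifier a 1027-dimensional input.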

For statistics and evaluation:

  • t01-multilabel-test.py: load a saved model and evaluate it on a test dataset.
  • p02-dataset-stats.ipynb: analyze the chest X-ray dataset and draw statistics charts.
  • pp01-postprocess-performance.ipynb: postprocess the performance statistics stored in the performance directory.
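The README does not name the evaluation metric. A common choice for multi-label chest X-ray classification is the per-label area under the ROC curve (AUROC), averaged over labels. The sketch below computes it in pure Python from the pairwise-ranking definition: the probability that a randomly chosen positive scores higher than a randomly chosen negative, with ties counted as half. `auroc` and `mean_auroc` are hypothetical helpers for illustration; in practice scikit-learn's `roc_auc_score` does the same job more efficiently.

```python
from typing import List, Sequence, Tuple


def auroc(y_true: Sequence[int], y_score: Sequence[float]) -> float:
    """AUROC for one label via pairwise comparison (O(P*N), fine for a sketch)."""
    pos = [s for t, s in zip(y_true, y_score) if t == 1]
    neg = [s for t, s in zip(y_true, y_score) if t == 0]
    if not pos or not neg:
        raise ValueError("AUROC needs at least one positive and one negative example")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5  # ties count as half a correct ranking
    return wins / (len(pos) * len(neg))


def mean_auroc(
    y_true: List[List[int]], y_score: List[List[float]]
) -> Tuple[float, List[float]]:
    """Per-label AUROC for rows of multi-hot targets and scores, plus their mean."""
    n_labels = len(y_true[0])
    per_label = [
        auroc([row[j] for row in y_true], [row[j] for row in y_score])
        for j in range(n_labels)
    ]
    return sum(per_label) / n_labels, per_label
```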

For heatmap generation using Grad-CAM:

  • t03-multilabel-heatmap-densenet121-v2.ipynb: load a saved model and draw a heatmap for a given input image. Note that MODEL_NAME must be densenet121. A DenseNet-121 model trained for 8 epochs on the images in the train_val_A.csv list is provided in the models folder.
  • t03-multilabel-heatmap-densenet121attA-v2.ipynb: load a saved model and draw a heatmap for a given input image. Note that MODEL_NAME must be densenet121attA. A DenseNet-121-attA model trained for 8 epochs on the images in the train_val_A.csv list is provided in the models folder.
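Both heatmap notebooks rely on Grad-CAM. Once the last convolutional feature maps and the gradients of one class score with respect to them have been captured, the core computation is short, as sketched below with NumPy: each channel weight is the spatial average of its gradient, and the map is the ReLU of the weighted sum of feature maps, normalized to [0, 1]. How the notebooks capture these tensors is not described here; in PyTorch this is commonly done with forward and backward hooks (for example `register_forward_hook` and `register_full_backward_hook`) on the final convolutional block, which is an assumption. `grad_cam` is a hypothetical helper.

```python
import numpy as np


def grad_cam(activations: np.ndarray, gradients: np.ndarray) -> np.ndarray:
    """Grad-CAM map for one image and one class.

    activations, gradients: arrays of shape (K, H, W), i.e. K feature maps A^k
    and the gradients of the class score with respect to them.
    """
    if activations.shape != gradients.shape:
        raise ValueError("activations and gradients must have the same shape")
    weights = gradients.mean(axis=(1, 2))              # alpha_k: spatially averaged gradients
    cam = np.tensordot(weights, activations, axes=1)   # sum_k alpha_k * A^k -> (H, W)
    cam = np.maximum(cam, 0.0)                         # ReLU keeps only positive evidence
    peak = cam.max()
    return cam / peak if peak > 0 else cam             # normalize to [0, 1] for display
```

For DenseNet-121 on a 224×224 input the final feature maps are 7×7, so the map is usually upsampled to the image size before being overlaid on the X-ray.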

Python files containing standard and customized models:

Folder description

  • labels: contains "Data_Entry_2017_v2020.csv" and the lists of training, validation and test subsets of images used in the final report
  • models: contains two trained models, DenseNet-121 and DenseNet-121-attA, that can be used to generate the heatmaps
  • heatmaps: output folder for the heatmaps

Tested Environment

Local

CPU: AMD Ryzen 5 4600H
GPU: NVIDIA GTX 1650 / NVIDIA RTX 2060 Max-Q

AWS

c5 series
p2 series

Dependencies

  • Python 3.7+
  • PyTorch, torchvision, Pillow (PIL), NumPy, pandas, scikit-learn and Matplotlib; importlib, datetime and time are also used but ship with the Python standard library
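A quick way to confirm that the third-party dependencies are importable before opening a notebook is to query `importlib.util.find_spec` for each top-level module. Several import names differ from the package names (Pillow is imported as `PIL`, scikit-learn as `sklearn`, PyTorch as `torch`). The helper `missing_modules` below is a hypothetical convenience, not part of the repository.

```python
import importlib.util
from typing import List, Sequence

# Import names of the third-party packages listed above.
REQUIRED = ("torch", "torchvision", "PIL", "numpy", "pandas", "sklearn", "matplotlib")


def missing_modules(names: Sequence[str] = REQUIRED) -> List[str]:
    """Return the top-level modules that cannot be found in the current environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules()
    print("all dependencies found" if not missing else f"missing: {', '.join(missing)}")
```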

Installing

For local

  • Install PyTorch 1.8.x via Miniconda
  • To enable CUDA acceleration, install a CUDA toolkit version that matches your PyTorch build

For AWS

  • Choose an AWS Deep Learning AMI when creating EC2 instances
  • Use any Python environment with PyTorch 1.8.x to run the notebooks
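On either a local machine or an EC2 instance, a sketch like the following can pick the training device: it falls back to the CPU when PyTorch is not installed or no CUDA device is visible (for example on the CPU-only c5 instances). `pick_device` is a hypothetical helper; the notebooks may select the device differently.

```python
import importlib.util


def pick_device() -> str:
    """Return "cuda" when a CUDA-enabled PyTorch build sees a GPU, otherwise "cpu"."""
    if importlib.util.find_spec("torch") is None:
        return "cpu"  # PyTorch is not installed in this environment
    import torch

    return "cuda" if torch.cuda.is_available() else "cpu"


if __name__ == "__main__":
    print("device:", pick_device())
```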

Authors

  • Marcus Hwai Yik Tan
  • Xiaohan Tian
  • Wing Chan
  • Joshua Ceaser