UNet 3+: A Full-Scale Connected UNet for Medical Image Segmentation

Give this repo a star ⭐ if you find my work useful.

Installation

Requirements

  • Python >= 3.6
  • TensorFlow >= 2.4
  • CUDA 9.2, 10.0, 10.1, 10.2, 11.0

This code base has been tested against the Python and TensorFlow versions listed above, and it is expected to work with later versions as well.
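
To confirm that an environment meets the minimums listed above, a small check like the following can help. It is a sketch, not part of this repository: the helper names (parse_version, version_at_least, check_environment) are hypothetical, and the script only reads the installed TensorFlow version and counts the GPUs TensorFlow can see.

```python
import re
import sys


def parse_version(text):
    """Turn '2.4.1' or '2.15.0rc1' into an int tuple such as (2, 4, 1) or (2, 15, 0)."""
    parts = []
    for piece in text.split(".")[:3]:
        match = re.match(r"\d+", piece)  # keep the leading digits, drop suffixes like 'rc1'
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def version_at_least(found, minimum):
    """Compare versions numerically, so '2.10' counts as newer than '2.4'."""
    return parse_version(found) >= parse_version(minimum)


def check_environment():
    python_version = ".".join(str(n) for n in sys.version_info[:3])
    print(f"Python {python_version}:", "OK" if sys.version_info >= (3, 6) else "needs >= 3.6")
    try:
        import tensorflow as tf
    except ImportError:
        print("TensorFlow: not installed")
        return
    ok = version_at_least(tf.__version__, "2.4")
    print(f"TensorFlow {tf.__version__}:", "OK" if ok else "needs >= 2.4")
    if ok:
        print("Visible GPUs:", len(tf.config.list_physical_devices("GPU")))


if __name__ == "__main__":
    check_environment()
```

Versions are compared as integer tuples rather than as strings, because a plain string comparison would wrongly rank "2.10" below "2.4".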

  • Clone the repository:
git clone https://github.com/hamidriasat/UNet-3-Plus.git UNet3P
cd UNet3P
  • Install the remaining requirements:
pip install -r requirements.txt

Code Structure

  • checkpoint: Model checkpoints and logs directory
  • configs: Configuration file
  • data: Dataset files (see Data Preparation for more details)
  • data_preparation: LiTS data preparation and data verification
  • losses: Implementations of the UNet3+ hybrid loss function and the Dice coefficient
  • models: UNet3+ model files
  • utils: Generic utility functions
  • data_generator.py: Data generator for training, validation and testing
  • evaluate.py: Evaluation script that measures the accuracy of a trained model
  • predict.ipynb: Prediction notebook for visualizing model output inside a notebook (helpful for visualization on a remote server)
  • predict.py: Prediction script for visualizing model output
  • train.py: Training script

Config

Configurations are passed through a YAML file. For more details on the config file, read here.
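
Command-line arguments such as MODEL.TYPE=unet3plus override entries in that YAML file by their dotted path. The sketch below only illustrates that idea on a plain nested dict. It is not Hydra's parser (Hydra's override grammar has its own typing and list rules), and the default values in EXAMPLE_CONFIG are hypothetical: only the key names are taken from this README.

```python
from copy import deepcopy

# Hypothetical defaults: the key names appear in this README, the values are made up.
EXAMPLE_CONFIG = {
    "MODEL": {"TYPE": "unet3plus"},
    "USE_MULTI_GPUS": {"VALUE": False, "GPU_IDS": -1},
}


def parse_value(text):
    """Tiny subset of value typing: True/False, integers, otherwise plain strings."""
    if text in ("True", "False"):
        return text == "True"
    try:
        return int(text)
    except ValueError:
        return text


def apply_overrides(config, overrides):
    """Return a copy of `config` with KEY.SUB=VALUE overrides applied."""
    result = deepcopy(config)
    for override in overrides:
        key, sep, raw = override.partition("=")
        if not sep:
            raise ValueError(f"expected KEY=VALUE, got {override!r}")
        *parents, leaf = key.split(".")
        node = result
        for name in parents:  # walk down the nested groups, e.g. USE_MULTI_GPUS
            child = node.get(name)
            if not isinstance(child, dict):
                raise KeyError(f"unknown config group {name!r} in {key!r}")
            node = child
        # Rejecting unknown keys catches typos such as MODEL.TPYE early.
        if leaf not in node:
            raise KeyError(f"unknown config key {key!r}")
        node[leaf] = parse_value(raw)
    return result


if __name__ == "__main__":
    print(apply_overrides(EXAMPLE_CONFIG, ["MODEL.TYPE=unet3plus_deepsup",
                                           "USE_MULTI_GPUS.VALUE=True"]))
```

Returning a copy instead of mutating EXAMPLE_CONFIG keeps the defaults intact, so the same base config can be reused for several runs.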

Data Preparation

For dataset preparation, read here.

Models

This repo contains all three versions of UNet3+.

#  Description                                                     Model Name
1  UNet3+ Base model                                               unet3plus
2  UNet3+ with Deep Supervision                                    unet3plus_deepsup
3  UNet3+ with Deep Supervision and Classification Guided Module   unet3plus_deepsup_cgm
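
All three variants share the full-scale skip connections that give UNet 3+ its name. As described in the UNet 3+ paper, each decoder stage fuses feature maps from every scale: higher-resolution encoder maps are max-pooled down, the same-scale encoder map is used directly, and deeper maps are bilinearly upsampled. Each of the five sources is then reduced to 64 channels and concatenated into 320 channels. The sketch below only does the bookkeeping for that wiring with five levels. It is not this repository's model code, and the function names and input sizes are illustrative.

```python
DEPTH = 5                # encoder levels En1..En5; En5 is the bottleneck
FILTERS_PER_SOURCE = 64  # channels each source is projected to before concatenation


def level_size(input_size, level):
    """Spatial side length at a level (level 1 = full input resolution)."""
    return input_size // 2 ** (level - 1)


def decoder_inputs(level, depth=DEPTH):
    """List (source, operation, factor) for the maps fused at decoder level `level`."""
    sources = []
    for j in range(1, depth + 1):
        if j < level:     # higher-resolution encoder map: max-pool down
            sources.append((f"En{j}", "maxpool", 2 ** (level - j)))
        elif j == level:  # same-scale encoder map: used as is
            sources.append((f"En{j}", "identity", 1))
        else:             # deeper map: upsample (the deepest one is En5 itself)
            name = f"En{j}" if j == depth else f"De{j}"
            sources.append((name, "upsample", 2 ** (j - level)))
    return sources


def fused_shape(level, input_size, depth=DEPTH):
    """Return (side length, channels) of the concatenated tensor at a decoder level."""
    sizes = set()
    for name, op, factor in decoder_inputs(level, depth):
        side = level_size(input_size, int(name[2:]))
        if op == "maxpool":
            side //= factor
        elif op == "upsample":
            side *= factor
        sizes.add(side)
    if len(sizes) != 1:
        raise ValueError(f"sources disagree on spatial size: {sorted(sizes)}")
    return sizes.pop(), FILTERS_PER_SOURCE * depth


if __name__ == "__main__":
    for level in range(DEPTH - 1, 0, -1):  # the decoder runs De4 -> De1
        print(f"De{level}", decoder_inputs(level), fused_shape(level, 256))
```

The size check only works when the input side is divisible by 2 ** (DEPTH - 1) = 16, which is why inputs for this kind of network are usually chosen as multiples of 16.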

The UNet3+ hybrid loss can be found here.
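
In the UNet 3+ paper, the hybrid segmentation loss is the sum of a focal loss, an MS-SSIM loss and an IoU loss. The NumPy sketch below shows the shape of that sum for a single binary probability map, together with a soft Dice coefficient of the kind listed under losses. It makes simplifying assumptions: the SSIM term is computed over the whole image as one window at a single scale rather than as true multi-scale SSIM, and the focal-loss parameters (alpha=0.25, gamma=2) and the Dice smoothing constant are common defaults, not values taken from this repository. It is not the repository's TensorFlow implementation.

```python
import numpy as np

EPS = 1e-7


def focal_loss(y_true, y_pred, alpha=0.25, gamma=2.0):
    """Binary focal loss averaged over pixels."""
    p = np.clip(y_pred, EPS, 1.0 - EPS)
    p_t = np.where(y_true == 1, p, 1.0 - p)  # probability assigned to the true class
    alpha_t = np.where(y_true == 1, alpha, 1.0 - alpha)
    return float(np.mean(-alpha_t * (1.0 - p_t) ** gamma * np.log(p_t)))


def iou_loss(y_true, y_pred):
    """Soft IoU (Jaccard) loss: 1 - sum(y*p) / sum(y + p - y*p)."""
    inter = np.sum(y_true * y_pred)
    union = np.sum(y_true + y_pred - y_true * y_pred)
    return float(1.0 - (inter + EPS) / (union + EPS))


def ssim_loss(y_true, y_pred, c1=0.01 ** 2, c2=0.03 ** 2):
    """1 - SSIM over the whole image as one window (simplified stand-in for MS-SSIM)."""
    mu_x, mu_y = y_pred.mean(), y_true.mean()
    var_x, var_y = y_pred.var(), y_true.var()
    cov = np.mean((y_pred - mu_x) * (y_true - mu_y))
    ssim = ((2 * mu_x * mu_y + c1) * (2 * cov + c2)) / (
        (mu_x ** 2 + mu_y ** 2 + c1) * (var_x + var_y + c2))
    return float(1.0 - ssim)


def hybrid_loss(y_true, y_pred):
    """Unweighted sum of the three terms, as in the paper's hybrid loss."""
    return focal_loss(y_true, y_pred) + ssim_loss(y_true, y_pred) + iou_loss(y_true, y_pred)


def dice_coefficient(y_true, y_pred, smooth=1.0):
    """Soft Dice: (2*|A and B| + smooth) / (|A| + |B| + smooth)."""
    inter = np.sum(y_true * y_pred)
    return float((2.0 * inter + smooth) / (np.sum(y_true) + np.sum(y_pred) + smooth))
```

For training, these terms would need to be written with TensorFlow ops so that gradients flow through them; NumPy is used here only to keep the sketch small and easy to check.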

Training & Evaluation

To train a model, call train.py with the required model type and configurations.

For example, to train the base model, run:

python train.py MODEL.TYPE=unet3plus

To evaluate a trained model, call evaluate.py.

For example, to calculate the accuracy of the trained UNet3+ base model on the validation data, run:

python evaluate.py MODEL.TYPE=unet3plus

Multi-GPU Training

Our code supports multi-GPU training using TensorFlow's distributed MirroredStrategy. By default, training and evaluation run on a single GPU. To enable multiple GPUs, explicitly set the USE_MULTI_GPUS values. For example, to train on all available GPUs, run:

python train.py ... USE_MULTI_GPUS.VALUE=True USE_MULTI_GPUS.GPU_IDS=-1 

GPU_IDS accepts either an integer or a list of integers.

  • If it is an integer:
    • -1 uses all available GPUs.
    • A positive number uses that many GPUs.
  • If it is a list of integers, each integer is treated as a GPU ID, e.g. [4,5,7] means GPUs 4, 5 and 7 (the 5th, 6th and 8th devices, since IDs start at 0) are used for training/evaluation.
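
The GPU_IDS rules above can be expressed as a small resolver. This is a sketch, not the repository's code: resolve_gpu_ids and make_strategy are hypothetical names, the README does not say which GPUs a positive count selects (the sketch assumes the first n), and the device strings assume that GPU ID i in the config corresponds to TensorFlow's device /gpu:i.

```python
def resolve_gpu_ids(gpu_ids, num_available):
    """Turn a GPU_IDS config value into an explicit list of 0-based GPU ids.

    -1          -> every available GPU
    n > 0       -> the first n GPUs (assumption: the README does not say which n)
    [a, b, ...] -> exactly those ids
    """
    if isinstance(gpu_ids, bool):  # bool is a subclass of int, so reject it explicitly
        raise TypeError("GPU_IDS must be an int or a list of ints")
    if isinstance(gpu_ids, int):
        if gpu_ids == -1:
            return list(range(num_available))
        if 0 < gpu_ids <= num_available:
            return list(range(gpu_ids))
        raise ValueError(f"cannot use {gpu_ids} GPUs; {num_available} available")
    ids = [int(i) for i in gpu_ids]
    invalid = [i for i in ids if not 0 <= i < num_available]
    if invalid:
        raise ValueError(f"invalid GPU ids {invalid}; valid range is 0..{num_available - 1}")
    return ids


def make_strategy(gpu_ids):
    """Build a MirroredStrategy over the selected GPUs (needs TensorFlow installed)."""
    import tensorflow as tf  # imported lazily so resolve_gpu_ids works without TF

    available = len(tf.config.list_physical_devices("GPU"))
    devices = [f"/gpu:{i}" for i in resolve_gpu_ids(gpu_ids, available)]
    if not devices:
        raise RuntimeError("no GPUs are visible to TensorFlow")
    return tf.distribute.MirroredStrategy(devices=devices)


if __name__ == "__main__":
    print(resolve_gpu_ids(-1, 8))
    print(resolve_gpu_ids(2, 8))
    print(resolve_gpu_ids([4, 5, 7], 8))
```

The model would then be built and compiled inside `with make_strategy(...).scope():` so that its variables are mirrored across the chosen GPUs.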

Inference Demo

Two options are available for visualization:

  1. Visualize from directory
  2. Visualize from list

In both cases, the mask is optional.

You can visualize results through the predict.ipynb notebook, or override these settings from the command line and call predict.py.

  1. Visualize from directory

When visualizing from a directory, the script makes predictions for and shows all images in the given directory. Override the validation data paths and make sure the directory paths are relative to the project base/root path, e.g.

python predict.py MODEL.TYPE=unet3plus ^
DATASET.VAL.IMAGES_PATH=/data/val/images/ ^
DATASET.VAL.MASK_PATH=/data/val/mask/
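
Directory mode has to match every image with its mask. This README does not spell out how, so the sketch below assumes the convention suggested by the example file names used in the list example (image_0_48.png pairs with mask_0_48.png) and matches files by name. pair_images_with_masks is a hypothetical helper for checking a dataset before running predict.py, not part of the repository.

```python
import sys
from pathlib import Path

IMAGE_EXTS = {".png", ".jpg", ".jpeg"}


def pair_images_with_masks(images_dir, masks_dir=None):
    """Return (image_path, mask_path or None) pairs, matched by file name.

    Assumption: image_<a>_<b>.png has its mask at mask_<a>_<b>.png.
    """
    images = sorted(p for p in Path(images_dir).iterdir()
                    if p.suffix.lower() in IMAGE_EXTS)
    pairs = []
    for image in images:
        mask = None
        if masks_dir is not None:
            # Swap the first "image" for "mask"; names without "image" keep their name.
            candidate = Path(masks_dir) / image.name.replace("image", "mask", 1)
            if not candidate.exists():
                raise FileNotFoundError(f"no mask for {image.name}: expected {candidate}")
            mask = candidate
        pairs.append((image, mask))
    return pairs


if __name__ == "__main__" and len(sys.argv) > 1:
    # Usage: python pair_check.py <images_dir> [<masks_dir>]
    masks = sys.argv[2] if len(sys.argv) > 2 else None
    for image, mask in pair_images_with_masks(sys.argv[1], masks):
        print(image, "->", mask)
```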

  2. Visualize from list

When visualizing from a list, each list element should contain the absolute path of an image or mask. With absolute paths, the training data naming convention does not matter: you can use whatever naming convention you have, as long as each image and its corresponding mask are at the same index.

For example, to visualize model results on two images along with their corresponding masks, run:

python predict.py MODEL.TYPE=unet3plus ^
DATASET.VAL.IMAGES_PATH=[^
H:\\Projects\\UNet3P\\data\\val\\images\\image_0_48.png,^
H:\\Projects\\UNet3P\\data\\val\\images\\image_0_21.png^
] DATASET.VAL.MASK_PATH=[^
H:\\Projects\\UNet3P\\data\\val\\mask\\mask_0_48.png,^
H:\\Projects\\UNet3P\\data\\val\\mask\\mask_0_21.png^
]

When visualizing your own data, set SHOW_CENTER_CHANNEL_IMAGE=False. It should be set to True only for the UNet3+ LiTS data.

These commands were tested on Windows. On Linux, replace ^ with \, replace H:\\Projects with your own base path, and use / as the path separator.

Note: Don't add spaces between list elements; they cause problems with Hydra.

In both cases, if masks are not available, set the mask path to None:

python predict.py DATASET.VAL.IMAGES_PATH=... DATASET.VAL.MASK_PATH=None
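
Building these list overrides by hand is error-prone, especially the no-spaces rule and the shell-specific ^ or \ line continuations. One option, sketched below with hypothetical helper names, is to assemble the arguments in Python and hand them to subprocess as a list, which avoids shell line continuation altogether. The sketch assumes it is run from the project root; the example paths are placeholders, and paths containing spaces or other characters that are special to Hydra's override grammar may still need quoting.

```python
import subprocess
import sys


def hydra_list(values):
    """Format values as a Hydra list literal with no spaces between elements."""
    return "[" + ",".join(str(v) for v in values) + "]"


def predict_args(images, masks=None, model_type="unet3plus"):
    """Build a predict.py command line for list mode; masks=None passes None."""
    if masks is not None and len(masks) != len(images):
        raise ValueError("images and masks are paired by index, so lengths must match")
    mask_value = "None" if masks is None else hydra_list(masks)
    return [
        sys.executable, "predict.py",
        f"MODEL.TYPE={model_type}",
        f"DATASET.VAL.IMAGES_PATH={hydra_list(images)}",
        f"DATASET.VAL.MASK_PATH={mask_value}",
    ]


def run_predict(args):
    """Launch predict.py; call this from the project root."""
    subprocess.run(args, check=True)


if __name__ == "__main__":
    # Placeholder paths: replace them with your own absolute paths.
    print(predict_args(
        ["/home/user/UNet3P/data/val/images/image_0_48.png",
         "/home/user/UNet3P/data/val/images/image_0_21.png"],
        ["/home/user/UNet3P/data/val/mask/mask_0_48.png",
         "/home/user/UNet3P/data/val/mask/mask_0_21.png"],
    ))
```

Because each override is a separate list element, the same script works unchanged on Windows and Linux; only the paths themselves differ.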

This branch is under active development, so changes are expected.

TODO List

  • Complete README.md
  • Add requirements file
  • Add Data augmentation
  • Add multiprocessing in LiTS data preprocessing
  • Load data through NVIDIA DALI

We appreciate any feedback, so reporting problems and asking questions are welcome here.

Licensed under the MIT License.