Hit star ⭐ if you find my work useful.
- UNet 3+ for Image Segmentation in TensorFlow Keras.
Requirements
- Python >= 3.6
- TensorFlow >= 2.4
- CUDA 9.2, 10.0, 10.1, 10.2, 11.0
This code base has been tested with the Python and TensorFlow versions listed above, and it is expected to work with later versions as well.
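To confirm an environment meets the minimum versions above before installing the rest, a small standalone check can help. This is a minimal sketch, not part of the repo; the file and function names (`parse_version`, `meets_minimum`, `check_environment`) are hypothetical, and the CUDA versions are not checked.

```python
import sys


def parse_version(version):
    """Turn '2.4.1' or '2.10.0rc1' into a comparable (major, minor, patch) tuple."""
    parts = []
    for piece in version.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break  # stop at suffixes such as 'rc1' or '-dev'
            digits += ch
        parts.append(int(digits) if digits else 0)
    while len(parts) < 3:
        parts.append(0)  # '2.4' -> (2, 4, 0)
    return tuple(parts)


def meets_minimum(version, minimum):
    """True if `version` is at least `minimum`, compared numerically."""
    return parse_version(version) >= parse_version(minimum)


def check_environment():
    """Return a list of problems; an empty list means the minimums are met."""
    problems = []
    if sys.version_info < (3, 6):
        problems.append("Python >= 3.6 required, found %d.%d" % sys.version_info[:2])
    try:
        import tensorflow as tf
    except ImportError:
        problems.append("TensorFlow is not installed")
    else:
        if not meets_minimum(tf.__version__, "2.4"):
            problems.append("TensorFlow >= 2.4 required, found " + tf.__version__)
    return problems


if __name__ == "__main__":
    issues = check_environment()
    print("\n".join(issues) if issues else "Python and TensorFlow versions look OK")
```

Comparing numeric tuples avoids the classic string-comparison bug where "2.10" would sort before "2.4".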
- Clone the code
git clone https://github.com/hamidriasat/UNet-3-Plus.git UNet3P
cd UNet3P
- Install the other requirements.
pip install -r requirements.txt
- checkpoint: Model checkpoint and logs directory
- configs: Configuration file
- data: Dataset files (see Data Preparation for more details)
- data_preparation: For LiTS data preparation and data verification
- losses: Implementations of the UNet3+ hybrid loss function and the Dice coefficient
- models: UNet3+ model files
- utils: Generic utility functions
- data_generator.py: Data generator for training, validation and testing
- evaluate.py: Evaluation script to validate the accuracy of a trained model
- predict.ipynb: Prediction notebook used to visualize model output inside a notebook (helpful for visualization on a remote server)
- predict.py: Prediction script used to visualize model output
- train.py: Training script
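The losses directory includes a Dice coefficient. As a rough reference for what that metric computes, here is a minimal NumPy sketch of the common soft Dice formulation. The function names and the smoothing constant `smooth=1.0` are assumptions and may differ from the repo's own implementation.

```python
import numpy as np


def dice_coefficient(y_true, y_pred, smooth=1.0):
    """Soft Dice: (2 * |A ∩ B| + smooth) / (|A| + |B| + smooth)."""
    y_true = np.asarray(y_true, dtype=np.float64).ravel()
    y_pred = np.asarray(y_pred, dtype=np.float64).ravel()
    intersection = np.sum(y_true * y_pred)  # soft overlap for probability maps
    return (2.0 * intersection + smooth) / (np.sum(y_true) + np.sum(y_pred) + smooth)


def dice_loss(y_true, y_pred, smooth=1.0):
    """Loss form: 0 for a perfect match, approaching 1 for no overlap."""
    return 1.0 - dice_coefficient(y_true, y_pred, smooth)
```

The `smooth` term keeps the ratio defined when both mask and prediction are empty, at the cost of slightly inflating scores on very small objects.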
Configurations are passed through a YAML file. For more details on the config file, read here.
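Command-line arguments such as `MODEL.TYPE=unet3plus` override individual keys of that YAML config; the repo uses Hydra for this. The sketch below only illustrates the idea of dotted-key overrides on a nested dict. It is not Hydra's implementation, and the function names and the example config contents are hypothetical.

```python
import ast
import copy


def parse_value(text):
    """Read 'True', '-1', '[4,5,7]', 'None' as Python literals; keep anything else as a string."""
    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError):
        return text


def apply_overrides(config, overrides):
    """Return a copy of the nested dict `config` with 'SECTION.KEY=value' overrides applied."""
    result = copy.deepcopy(config)
    for item in overrides:
        key, sep, raw = item.partition("=")
        if not sep:
            raise ValueError("override must look like KEY=VALUE: %r" % item)
        *parents, leaf = key.split(".")
        node = result
        for name in parents:
            node = node[name]  # KeyError if the section does not exist
        if leaf not in node:
            raise KeyError("unknown config key: " + key)
        node[leaf] = parse_value(raw)
    return result
```

Unlike Hydra, this sketch leaves unquoted string lists such as `[a.png,b.png]` as plain strings, and it rejects keys that are not already in the config.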
- This code can be used to reproduce UNet3+ paper results on LiTS - Liver Tumor Segmentation Challenge.
- You can also use it to train UNet3+ on a custom dataset.
For dataset preparation read here.
This repo contains all three versions of UNet3+.
# | Description | Model Name | Training Supported |
---|---|---|---|
1 | UNet3+ Base model | unet3plus | ✓ |
2 | UNet3+ with Deep Supervision | unet3plus_deepsup | ✓ |
3 | UNet3+ with Deep Supervision and Classification Guided Module | unet3plus_deepsup_cgm | ✗ |
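In the deep supervision variants, the UNet 3+ paper attaches a side output to each decoder stage, upsamples it to full resolution and supervises it against the ground-truth mask. The NumPy sketch below shows only that aggregation step under simplifying assumptions: square maps whose sizes divide the mask size, nearest-neighbour upsampling in place of the paper's bilinear upsampling, and binary cross-entropy in place of the hybrid loss. All names are hypothetical, and this is not how the repo's models directory implements it.

```python
import numpy as np


def upsample_nearest(prob_map, size):
    """Nearest-neighbour upsample a square (h, h) map to (size, size) by integer repetition."""
    factor = size // prob_map.shape[0]
    return np.repeat(np.repeat(prob_map, factor, axis=0), factor, axis=1)


def bce(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy; predictions are clipped to avoid log(0)."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    return float(-np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred)))


def deep_supervision_loss(y_true, side_outputs, loss_fn=bce):
    """Sum the loss of every side output after resizing it to the mask resolution."""
    size = y_true.shape[0]
    return sum(loss_fn(y_true, upsample_nearest(out, size)) for out in side_outputs)
```

Summing per-stage losses gives every decoder stage its own gradient signal, which is the point of deep supervision.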
Here you can find UNet3+ hybrid loss.
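The UNet 3+ paper defines the hybrid segmentation loss as the sum of a focal loss, an MS-SSIM loss and an IoU loss (ℓ_seg = ℓ_fl + ℓ_ms-ssim + ℓ_iou). Below is a minimal NumPy sketch of the focal and IoU terms only, for binary masks. The defaults `alpha=0.25` and `gamma=2.0` are common focal-loss choices assumed here, the function names are hypothetical, and this is not the repo's implementation.

```python
import numpy as np

EPS = 1e-7


def focal_loss(y_true, y_pred, alpha=0.25, gamma=2.0):
    """Binary focal loss averaged over pixels: -alpha_t * (1 - p_t)^gamma * log(p_t)."""
    y_pred = np.clip(y_pred, EPS, 1.0 - EPS)
    p_t = np.where(y_true == 1, y_pred, 1.0 - y_pred)       # probability of the true class
    alpha_t = np.where(y_true == 1, alpha, 1.0 - alpha)     # class-balancing weight
    return float(np.mean(-alpha_t * (1.0 - p_t) ** gamma * np.log(p_t)))


def iou_loss(y_true, y_pred):
    """Soft IoU loss: 1 - sum(y*p) / (sum(y) + sum(p) - sum(y*p))."""
    intersection = np.sum(y_true * y_pred)
    union = np.sum(y_true) + np.sum(y_pred) - intersection
    return float(1.0 - intersection / (union + EPS))


def partial_hybrid_loss(y_true, y_pred):
    # The MS-SSIM term is omitted; in TensorFlow it could be built on tf.image.ssim_multiscale.
    return focal_loss(y_true, y_pred) + iou_loss(y_true, y_pred)
```

The focal term down-weights pixels that are already classified well, while the IoU term scores the mask as a whole; the MS-SSIM term, omitted here, adds a structural comparison around boundaries.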
To train a model, call train.py with the required model type and configurations.
e.g. to train the base model, run
python train.py MODEL.TYPE=unet3plus
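If you launch runs from Python (for example to sweep over model types), a small helper can build the argument list and refuse model types the table marks as not trainable. This is a hypothetical sketch, not part of the repo; `TRAINABLE` simply mirrors the table, and `build_command` is an illustrative name.

```python
import sys

# Model names and training support as listed in the table above.
TRAINABLE = {
    "unet3plus": True,
    "unet3plus_deepsup": True,
    "unet3plus_deepsup_cgm": False,
}


def build_command(script, model_type, *overrides):
    """Return an argv list such as [python, 'train.py', 'MODEL.TYPE=unet3plus', ...]."""
    if model_type not in TRAINABLE:
        raise ValueError("unknown MODEL.TYPE: " + model_type)
    if script == "train.py" and not TRAINABLE[model_type]:
        raise ValueError(model_type + " does not support training yet")
    # sys.executable reuses the current interpreter instead of whatever 'python' is on PATH.
    return [sys.executable, script, "MODEL.TYPE=" + model_type, *overrides]
```

From the repository root, `subprocess.run(build_command("train.py", "unet3plus"), check=True)` would start the same run as the command above. Passing an argv list avoids shell quoting and line-continuation differences between Windows and Linux.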
To evaluate a trained model, call evaluate.py.
e.g. to calculate the accuracy of the trained UNet3+ Base model on validation data, run
python evaluate.py MODEL.TYPE=unet3plus
Multi-GPU Training
Our code supports multi-GPU training using TensorFlow's distributed MirroredStrategy.
By default, training and evaluation run on a single GPU. To enable multiple GPUs, you have to explicitly set the USE_MULTI_GPUS values.
e.g. to train on all available GPUs, run
python train.py ... USE_MULTI_GPUS.VALUE=True USE_MULTI_GPUS.GPU_IDS=-1
For GPU_IDS, two options are available: it can be either an integer or a list of integers.
- Integer:
- If the value is -1, all available GPUs are used.
- If it is a positive number, that many GPUs are used.
- List of integers: each integer is treated as a GPU id, e.g. [4,5,7] means the GPUs with ids 4, 5 and 7 (the 5th, 6th and 8th GPUs, since ids start at 0) are used for training/evaluation.
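The GPU_IDS rules above can be sketched as a function that turns the setting into concrete device names. This is a hypothetical illustration, not the repo's code; in particular, it assumes a positive integer n selects the first n GPUs, which the description does not specify.

```python
def resolve_gpu_ids(gpu_ids, available):
    """Translate a GPU_IDS setting into a list of concrete GPU indices.

    gpu_ids:   -1 (all GPUs), a positive int n (assumed: the first n GPUs),
               or a list of zero-based GPU ids such as [4, 5, 7].
    available: number of GPUs visible on the machine.
    """
    if isinstance(gpu_ids, int):
        if gpu_ids == -1:
            return list(range(available))
        if 0 < gpu_ids <= available:
            return list(range(gpu_ids))
        raise ValueError("integer GPU_IDS must be -1 or between 1 and %d" % available)
    ids = list(gpu_ids)
    missing = [i for i in ids if not 0 <= i < available]
    if missing:
        raise ValueError("GPU ids not available on this machine: %s" % missing)
    return ids


def device_names(ids):
    """Format ids as TensorFlow device strings, e.g. [4, 5] -> ['/gpu:4', '/gpu:5']."""
    return ["/gpu:%d" % i for i in ids]


# With TensorFlow installed, the result could feed a MirroredStrategy:
#   import tensorflow as tf
#   n = len(tf.config.list_physical_devices("GPU"))
#   strategy = tf.distribute.MirroredStrategy(devices=device_names(resolve_gpu_ids(-1, n)))
#   with strategy.scope():
#       model = ...  # build and compile the model inside the scope
```

Validating ids up front gives a clear error message instead of a device-placement failure partway through strategy setup.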
For visualization, two options are available. In both cases the mask is optional.
You can visualize results through the predict.ipynb notebook, or you can override these settings
through the command line and call predict.py.
- Visualize from directory
When visualizing from a directory, the script makes predictions for all images in the given directory and shows them. Override the validation data paths and make sure the directory paths are relative to the project base/root path, e.g.
python predict.py MODEL.TYPE=unet3plus ^
DATASET.VAL.IMAGES_PATH=/data/val/images/ ^
DATASET.VAL.MASK_PATH=/data/val/mask/
- Visualize from list
When visualizing from a list, each list element should contain the absolute path of an image/mask. With absolute paths, the training data naming convention does not matter; you can use whatever naming convention you have, just make sure each image and its corresponding mask are at the same index.
e.g. to visualize model results on two images along with their corresponding masks, run
python predict.py MODEL.TYPE=unet3plus ^
DATASET.VAL.IMAGES_PATH=[^
H:\\Projects\\UNet3P\\data\\val\\images\\image_0_48.png,^
H:\\Projects\\UNet3P\\data\\val\\images\\image_0_21.png^
] DATASET.VAL.MASK_PATH=[^
H:\\Projects\\UNet3P\\data\\val\\mask\\mask_0_48.png,^
H:\\Projects\\UNet3P\\data\\val\\mask\\mask_0_21.png^
]
When visualizing your own data, set SHOW_CENTER_CHANNEL_IMAGE=False. It should be set to True only for UNet3+ LiTS data.
These commands have been tested on Windows. On Linux, replace ^ with \ and replace H:\\Projects with your own base path.
Note: Don't add spaces between list elements; they cause problems with Hydra.
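Building the list overrides programmatically makes it easy to keep images and masks aligned by index and to avoid stray spaces. The sketch below is hypothetical: the helper names are illustrative, it assumes the `image_<id>.png` / `mask_<id>.png` naming from the example, and it writes paths with forward slashes via `as_posix()` on the assumption that predict.py accepts them; this has not been checked against Hydra on Windows.

```python
import sys
from pathlib import Path


def hydra_list(paths):
    """Join paths into a list override with no spaces, e.g. [a.png,b.png]."""
    return "[" + ",".join(Path(p).as_posix() for p in paths) + "]"


def paired_paths(image_dir, mask_dir, ids):
    """Build image and mask lists whose entries line up by index."""
    images = [Path(image_dir) / ("image_%s.png" % i) for i in ids]
    masks = [Path(mask_dir) / ("mask_%s.png" % i) for i in ids]
    return images, masks


def predict_command(images, masks=None, model_type="unet3plus"):
    """Return an argv list for predict.py; masks=None sets MASK_PATH=None."""
    mask_arg = hydra_list(masks) if masks else "None"
    return [
        sys.executable, "predict.py",
        "MODEL.TYPE=" + model_type,
        "DATASET.VAL.IMAGES_PATH=" + hydra_list(images),
        "DATASET.VAL.MASK_PATH=" + mask_arg,
    ]
```

Passing the resulting list to `subprocess.run` sidesteps the `^` versus `\` line-continuation difference entirely, because no shell parses the command.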
In both cases, if a mask is not available, just set the mask path to None
python predict.py DATASET.VAL.IMAGES_PATH=... DATASET.VAL.MASK_PATH=None
This branch is under active development, so changes are expected.
TODO List
- Complete README.md
- Add requirements file
- Add Data augmentation
- Add multiprocessing in LiTS data preprocessing
- Load data through NVIDIA DALI
We appreciate any feedback; reporting problems and asking questions are welcome here.
Licensed under the MIT License.