PyTorch implementation of 3D U-Net based on:
3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation, Özgün Çiçek, Ahmed Abdulkadir, Soeren S. Lienkamp, Thomas Brox, Olaf Ronneberger
- Linux
- NVIDIA GPU
- CUDA, cuDNN
- pytorch (0.4.1+)
- torchvision (0.2.1+)
- tensorboardx (1.6+)
- h5py
- scipy
- scikit-image
- pytest
Set up a new conda environment with the required dependencies via:
conda create -n 3dunet pytorch torchvision tensorboardx h5py scipy scikit-image pyyaml pytest -c conda-forge -c pytorch
Activate the newly created conda environment via:
source activate 3dunet
For a detailed explanation of the supported loss functions below, see: Generalised Dice overlap as a deep learning loss function for highly unbalanced segmentations, Carole H. Sudre, Wenqi Li, Tom Vercauteren, Sebastien Ourselin, M. Jorge Cardoso
- wce - WeightedCrossEntropyLoss (see 'Weighted cross-entropy (WCE)' in the above paper for a detailed explanation)
- ce - CrossEntropyLoss (one can specify class weights via --loss-weight <w_1 ... w_k>)
- pce - PixelWiseCrossEntropyLoss (one can specify not only class weights but also per-pixel weights in order to give more/less gradient in some regions of the ground truth)
- bce - BCELoss (one can specify class weights via --loss-weight <w_1 ... w_k>)
- dice - DiceLoss, the standard Dice loss (see 'Dice Loss' in the above paper for a detailed explanation). Note: if the labels in your training dataset are not heavily imbalanced (heavy imbalance being, e.g., one class having at least 3 orders of magnitude more voxels than the others), use this instead of GDL, since it worked better in my experiments.
- gdl - GeneralizedDiceLoss (one can specify class weights via --loss-weight <w_1 ... w_k>; see 'Generalized Dice Loss (GDL)' in the above paper for a detailed explanation)
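As a rough illustration of the two Dice-based losses, here is a minimal NumPy sketch of the formulas from the Sudre et al. paper. It is not the repository's DiceLoss/GeneralizedDiceLoss implementation; the function names, the eps clamping and the (C, D, H, W) array layout are assumptions. GDL weights each class by the inverse of its squared label volume, which is what makes it less dominated by large classes than plain Dice.

```python
import numpy as np

def dice_loss(pred, target, eps=1e-6):
    """Soft Dice loss averaged over channels (hypothetical helper).

    pred:   (C, D, H, W) per-channel probabilities, e.g. after a sigmoid
    target: (C, D, H, W) binary masks, one per channel
    """
    axes = tuple(range(1, pred.ndim))  # sum over the spatial axes only
    intersect = (pred * target).sum(axis=axes)
    denom = pred.sum(axis=axes) + target.sum(axis=axes)
    dice_per_channel = 2.0 * intersect / np.maximum(denom, eps)
    return 1.0 - dice_per_channel.mean()

def generalized_dice_loss(pred, target, eps=1e-6):
    """Generalized Dice loss with class weights w_l = 1 / (sum_n r_ln)^2."""
    axes = tuple(range(1, pred.ndim))
    target_volume = target.sum(axis=axes)
    # clamp so that classes absent from the target do not divide by zero
    weights = 1.0 / np.maximum(target_volume ** 2, eps)
    intersect = (weights * (pred * target).sum(axis=axes)).sum()
    denom = (weights * (pred.sum(axis=axes) + target.sum(axis=axes))).sum()
    return 1.0 - 2.0 * intersect / max(denom, eps)
```

A perfect prediction gives a loss of 0 for both functions, and a prediction with no overlap gives a Dice loss of 1.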
Supported evaluation metrics:
- iou - Mean intersection over union
- dice - Dice Coefficient (computes the per-channel Dice Coefficient and returns the average)
- ap - Average Precision (normally used for evaluating instance segmentation; however, it can also be used when the 3D U-Net is used to predict the boundary signal from the instance segmentation ground truth)
- rand - Adjusted Rand Score
If not specified, iou will be used by default.
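A minimal NumPy sketch of the iou and dice metrics on binary (C, D, H, W) masks, computing each score per channel and then averaging. The function names and the eps guard are hypothetical; the repository's own metric classes may handle thresholding of probabilities and ignored labels differently.

```python
import numpy as np

def dice_coefficient(pred, target, eps=1e-6):
    """Per-channel Dice coefficient, averaged over channels (hypothetical helper)."""
    pred, target = pred.astype(bool), target.astype(bool)
    axes = tuple(range(1, pred.ndim))  # spatial axes
    intersect = np.logical_and(pred, target).sum(axis=axes)
    denom = pred.sum(axis=axes) + target.sum(axis=axes)
    return float(np.mean(2 * intersect / np.maximum(denom, eps)))

def mean_iou(pred, target, eps=1e-6):
    """Per-channel intersection over union, averaged over channels."""
    pred, target = pred.astype(bool), target.astype(bool)
    axes = tuple(range(1, pred.ndim))
    intersect = np.logical_and(pred, target).sum(axis=axes)
    union = np.logical_or(pred, target).sum(axis=axes)
    return float(np.mean(intersect / np.maximum(union, eps)))
```

If the network outputs probabilities, binarize them first, e.g. with pred = probs > 0.5 (the 0.5 threshold is an assumption).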
For example, to fit the network to the randomly generated 3D volume and random segmentation mask from random_label3D.h5, run:
python train.py --config resources/train_config.yaml
See the train_config.yaml for more info.
To train on your own data, just provide the paths to your HDF5 training and validation datasets in train_config.yaml.
The HDF5 files should contain the raw/label datasets in the following axis order: DHW (for 3D) or CDHW (for 4D).
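A minimal h5py sketch for preparing your own data in the expected axis order. The dataset names raw and label are an assumption here; check train_config.yaml for the internal paths the loader actually reads. write_training_file is a hypothetical helper, not part of the repository.

```python
import h5py
import numpy as np

def write_training_file(path, raw, label, raw_key="raw", label_key="label"):
    """Write raw/label volumes (DHW or CDHW) to an HDF5 file.

    raw_key / label_key are assumed dataset names; match them to your config.
    """
    if raw.ndim not in (3, 4) or label.ndim not in (3, 4):
        raise ValueError("expected DHW (3D) or CDHW (4D) arrays")
    # the spatial axes D, H, W are always the last three
    if raw.shape[-3:] != label.shape[-3:]:
        raise ValueError("raw and label must share the same D, H, W")
    with h5py.File(path, "w") as f:
        f.create_dataset(raw_key, data=raw, compression="gzip")
        f.create_dataset(label_key, data=label, compression="gzip")
```

For 4D label data (one binary mask per channel), pass a CDHW label array; only the trailing D, H, W axes are checked against raw.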
Monitor progress with TensorBoard (you need tensorboard installed in your conda env):
tensorboard --logdir ./3dunet/logs/ --port 8666
In order to train with BinaryCrossEntropy, DiceLoss or GeneralizedDiceLoss, the label data has to be 4D (one target binary mask per channel). In the case of DiceLoss and GeneralizedDiceLoss, the final score is the average across channels.
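If your ground truth is a single DHW volume of integer class ids, it can be expanded into the 4D one-mask-per-channel layout these losses expect. A minimal NumPy sketch; to_binary_masks is a hypothetical helper, and it assumes the class ids are 0..num_classes-1.

```python
import numpy as np

def to_binary_masks(label, num_classes):
    """Convert a DHW integer label volume into a CDHW stack of binary masks."""
    label = np.asarray(label)
    if label.ndim != 3:
        raise ValueError("expected a DHW label volume")
    if label.min() < 0 or label.max() >= num_classes:
        raise ValueError("label values must lie in [0, num_classes)")
    # (1, D, H, W) == (C, 1, 1, 1) broadcasts to (C, D, H, W)
    classes = np.arange(num_classes).reshape(-1, 1, 1, 1)
    return (label[np.newaxis] == classes).astype(np.uint8)
```

Each voxel is set in exactly one channel, so summing the result over the channel axis gives an all-ones volume.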
final_sigmoid=True has to be present in the config when training the network with any of the above losses (and similarly, final_sigmoid=True has to be passed to predict.py if the network was trained with final_sigmoid=True).
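The reason for final_sigmoid is the choice of output activation: with one binary mask per channel, each channel should be squashed independently by a sigmoid, whereas mutually exclusive classes (cross-entropy) call for a softmax across channels. A minimal NumPy sketch of that switch, assuming final_sigmoid toggles between the two as its name suggests; final_activation is a hypothetical function, not the repository's code.

```python
import numpy as np

def final_activation(logits, final_sigmoid):
    """Map raw network outputs of shape (C, D, H, W) to probabilities."""
    if final_sigmoid:
        # independent per-channel binary masks (BCE, Dice, GDL)
        return 1.0 / (1.0 + np.exp(-logits))
    # mutually exclusive classes: softmax over the channel axis,
    # shifted by the channel-wise max for numerical stability
    shifted = logits - logits.max(axis=0, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=0, keepdims=True)
```

With softmax the channel probabilities of every voxel sum to 1; with sigmoid they generally do not, which is why a mismatch between training and prediction settings produces inconsistent masks.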
To test on the randomly generated 3D volume from random_label3D.h5 (just for demonstration purposes), run:
python predict.py --config resources/test_config.yaml
Prediction masks will be saved to resources/random_label3D_probabilities.h5.
To predict on your own raw dataset, provide the path to your model as well as the paths to your HDF5 test datasets in test_config.yaml.
To avoid block artifacts in the output prediction masks, the patch predictions are averaged, so make sure that the patch/stride params lead to overlapping blocks, e.g. patch: [64, 128, 128] and stride: [32, 96, 96] will give you a 'halo' of 32 voxels in each direction.
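The halo is simply patch minus stride along each axis (64 - 32 = 32 and 128 - 96 = 32). Below is a minimal NumPy sketch of overlapping-patch prediction with averaging, assuming a single-channel DHW volume and a hypothetical predict_fn that maps a patch to a same-shaped prediction; patch_starts and predict_with_averaging are hypothetical helpers, and the repository's predictor may handle channels, padding and borders differently.

```python
import numpy as np

def patch_starts(size, patch, stride):
    """Start offsets along one axis so that the patches cover [0, size)."""
    if size <= patch:
        return [0]
    starts = list(range(0, size - patch + 1, stride))
    if starts[-1] + patch < size:
        starts.append(size - patch)  # extra patch flush with the border
    return starts

def predict_with_averaging(volume, predict_fn, patch, stride):
    """Average overlapping patch predictions over a DHW volume."""
    if any(s > p for s, p in zip(stride, patch)):
        raise ValueError("stride must not exceed patch size, or voxels are skipped")
    out = np.zeros(volume.shape, dtype=np.float64)
    count = np.zeros(volume.shape, dtype=np.float64)  # predictions per voxel
    for z in patch_starts(volume.shape[0], patch[0], stride[0]):
        for y in patch_starts(volume.shape[1], patch[1], stride[1]):
            for x in patch_starts(volume.shape[2], patch[2], stride[2]):
                sl = (slice(z, z + patch[0]),
                      slice(y, y + patch[1]),
                      slice(x, x + patch[2]))
                out[sl] += predict_fn(volume[sl])
                count[sl] += 1
    return out / count
```

Voxels inside the halo receive predictions from several patches, so a patch's less reliable border predictions are blended with the interior predictions of its neighbours.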