PyTorch implementation of a standard 3D U-Net based on:
3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation, Özgün Çiçek et al.
as well as a Residual 3D U-Net based on:
Superhuman Accuracy on the SNEMI3D Connectomics Challenge, Kisuk Lee et al.
Prerequisites:
- Linux
- NVIDIA GPU
- CUDA and cuDNN
- pytorch (0.4.1+)
- torchvision (0.2.1+)
- tensorboardx (1.6+)
- h5py
- scipy
- scikit-image
- pytest
Set up a new conda environment with the required dependencies via:
conda create -n 3dunet pytorch torchvision tensorboardx h5py scipy scikit-image pyyaml pytest -c conda-forge -c pytorch
Activate the newly created conda environment via:
source activate 3dunet
- To train the standard 3D U-Net, specify name: UNet3D in the model section of the config file.
- To train the Residual 3D U-Net, specify name: ResidualUNet3D in the model section of the config file.
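Since the architecture is chosen purely by the name field of the model section, the sketch below shows one way such name-based dispatch can work. It is an illustration under stated assumptions: UNet3D and ResidualUNet3D here are empty placeholder classes, and MODEL_CLASSES, build_model and the extra keys in_channels/out_channels are hypothetical, not the repository's actual code or config schema.

```python
# Hypothetical stand-ins: the real UNet3D / ResidualUNet3D are PyTorch modules
# defined in the repository; these placeholders only illustrate the dispatch.
class UNet3D:
    def __init__(self, **kwargs):
        self.kwargs = kwargs


class ResidualUNet3D:
    def __init__(self, **kwargs):
        self.kwargs = kwargs


# Hypothetical registry mapping the config's model name to a class.
MODEL_CLASSES = {cls.__name__: cls for cls in (UNet3D, ResidualUNet3D)}


def build_model(model_config):
    """Instantiate the class named by model_config['name'].

    All remaining keys are passed through as constructor arguments.
    """
    params = dict(model_config)  # copy so the caller's dict is not mutated
    name = params.pop("name")
    try:
        model_class = MODEL_CLASSES[name]
    except KeyError:
        raise ValueError(f"Unsupported model name: {name!r}") from None
    return model_class(**params)


# Equivalent of a model section containing name: ResidualUNet3D
# (the extra keys are illustrative assumptions, not the repository's schema).
model = build_model({"name": "ResidualUNet3D", "in_channels": 1, "out_channels": 2})
print(type(model).__name__)  # ResidualUNet3D
```

A registry keyed by class name keeps the config as the single switch between architectures, and an unknown name fails loudly instead of silently falling back to a default.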
The supported loss functions are listed below. For a detailed explanation of the loss functions used, see: Generalised Dice overlap as a deep learning loss function for highly unbalanced segmentations, Carole H. Sudre, Wenqi Li, Tom Vercauteren, Sebastien Ourselin, M. Jorge Cardoso.
- WeightedCrossEntropyLoss (see 'Weighted cross-entropy (WCE)' in the above paper for a detailed explanation; one can specify class weights via weight: [w_1, ..., w_k] in the loss section of the config)
- CrossEntropyLoss (one can specify class weights via weight: [w_1, ..., w_k] in the loss section of the config)
- PixelWiseCrossEntropyLoss (one can specify not only class weights but also per-pixel weights, in order to give more or less gradient to some regions of the ground truth)
- BCEWithLogitsLoss
- DiceLoss - standard Dice loss (see 'Dice Loss' in the above paper for a detailed explanation)
- GeneralizedDiceLoss (see 'Generalized Dice Loss (GDL)' in the above paper for a detailed explanation; one can specify class weights via weight: [w_1, ..., w_k] in the loss section of the config). Note: use this loss function only if the labels in the training dataset are very imbalanced, e.g. one class has at least 3 orders of magnitude more voxels than the others. Otherwise use the standard DiceLoss, which works better than GDL most of the time.
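To make the difference between DiceLoss and GeneralizedDiceLoss concrete, here is a minimal pure-Python sketch of both formulas as described in the Sudre et al. paper, using GDL's default weighting w_l = 1 / (sum_n r_ln)^2. It operates on flat per-channel lists rather than PyTorch tensors, and the function names and the eps smoothing term are assumptions for this example. The repository's implementations work on tensors and may normalize or weight differently.

```python
def dice_loss(probs, targets, eps=1e-6):
    """Soft Dice loss computed per channel, then averaged over channels.

    probs, targets: lists of channels, each a flat list of per-voxel values
    (predicted probabilities in [0, 1] and binary ground truth, respectively).
    """
    per_channel = []
    for p, r in zip(probs, targets):
        intersection = sum(pi * ri for pi, ri in zip(p, r))
        denominator = sum(p) + sum(r)
        per_channel.append(1.0 - 2.0 * intersection / (denominator + eps))
    return sum(per_channel) / len(per_channel)


def generalized_dice_loss(probs, targets, eps=1e-6):
    """Generalized Dice loss with class weights w_l = 1 / (sum_n r_ln)^2."""
    numerator = 0.0
    denominator = 0.0
    for p, r in zip(probs, targets):
        w = 1.0 / (sum(r) ** 2 + eps)  # rare classes receive large weights
        numerator += w * sum(pi * ri for pi, ri in zip(p, r))
        denominator += w * (sum(p) + sum(r))
    return 1.0 - 2.0 * numerator / (denominator + eps)


# Channel 0 is a large class (9 voxels), channel 1 a rare one (1 voxel);
# the prediction assigns every voxel to channel 0 and misses the rare class.
targets = [[1] * 9 + [0], [0] * 9 + [1]]
probs = [[1.0] * 10, [0.0] * 10]
print(round(dice_loss(probs, targets), 2))              # 0.53
print(round(generalized_dice_loss(probs, targets), 2))  # 0.82
```

In this toy case GDL penalizes missing the rare class more heavily than the channel-averaged Dice loss, illustrating why it can help when one class vastly outnumbers the others.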
Supported evaluation metrics:
- MeanIoU - Mean intersection over union
- DiceCoefficient - Dice Coefficient (computes the per-channel Dice Coefficient and returns the average)
- BoundaryAveragePrecision - Average Precision (normally used for evaluating instance segmentation; however, it can also be used when the 3D U-Net predicts the boundary signal derived from the instance segmentation ground truth)
- AdaptedRandError - Adapted Rand Error (see http://brainiac2.mit.edu/SNEMI3D/evaluation for a detailed explanation)
If not specified, MeanIoU will be used by default.
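The two overlap metrics, MeanIoU and DiceCoefficient, can be sketched as follows for masks that are already binarized. This is an assumption: the repository's metric classes take network outputs as tensors and may threshold or treat empty channels differently. Each channel here is a flat list of 0/1 voxels, and the function names are hypothetical.

```python
def mean_iou(pred_masks, target_masks):
    """Intersection over union per channel, averaged over channels."""
    ious = []
    for p, t in zip(pred_masks, target_masks):
        intersection = sum(1 for pi, ti in zip(p, t) if pi and ti)
        union = sum(1 for pi, ti in zip(p, t) if pi or ti)
        # Assumption: a channel empty in both prediction and target scores 1.
        ious.append(intersection / union if union else 1.0)
    return sum(ious) / len(ious)


def dice_coefficient(pred_masks, target_masks):
    """Dice coefficient 2|P and T| / (|P| + |T|) per channel, averaged."""
    scores = []
    for p, t in zip(pred_masks, target_masks):
        intersection = sum(1 for pi, ti in zip(p, t) if pi and ti)
        total = sum(p) + sum(t)
        # Same empty-channel assumption as in mean_iou.
        scores.append(2 * intersection / total if total else 1.0)
    return sum(scores) / len(scores)


pred = [[1, 1, 0, 0], [0, 0, 1, 1]]
target = [[1, 0, 0, 0], [0, 1, 1, 1]]
print(round(mean_iou(pred, target), 3))          # 0.583
print(round(dice_coefficient(pred, target), 3))  # 0.733
```

Dice is always at least as large as IoU for the same masks, so the two numbers are not interchangeable when comparing runs.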
For example, to fit a randomly generated 3D volume and random segmentation mask from random_label3D.h5, run:
python train.py --config resources/train_config_ce.yaml # train with CrossEntropyLoss
or:
python train.py --config resources/train_config_dice.yaml # train with DiceLoss
See train_config_ce.yaml for more info.
To train on your own data, just provide the paths to your HDF5 training and validation datasets in train_config_ce.yaml.
The HDF5 files should contain the raw/label datasets in the following axis order: DHW (in case of 3D) or CDHW (in case of 4D).
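For label data stored in DHW order, one way to produce the CDHW layout with one binary mask per channel is a one-hot expansion over class ids. The sketch below uses plain nested Python lists to stay dependency-free. The helper names to_channel_masks and shape are hypothetical, and with real data you would more likely do the same on NumPy arrays before writing the HDF5 file.

```python
def to_channel_masks(labels, num_classes):
    """Convert a DHW integer label volume into CDHW binary masks.

    labels: nested lists indexed as labels[d][h][w] holding class ids.
    Returns masks with masks[c][d][h][w] == 1 where labels[d][h][w] == c, else 0.
    """
    return [
        [[[1 if v == c else 0 for v in row] for row in plane] for plane in labels]
        for c in range(num_classes)
    ]


def shape(volume):
    """Return the shape of a nested-list volume, e.g. (C, D, H, W)."""
    dims = []
    while isinstance(volume, list):
        dims.append(len(volume))
        volume = volume[0] if volume else None
    return tuple(dims)


labels = [[[0, 1], [2, 1]]]  # D=1, H=2, W=2 with class ids 0..2
masks = to_channel_masks(labels, num_classes=3)
print(shape(labels))  # (1, 2, 2)
print(shape(masks))   # (3, 1, 2, 2)
```

For a single foreground/background mask, the result would have just one channel, which is the same as adding a leading axis to the 3D label.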
Monitor progress with TensorBoard:
tensorboard --logdir ./3dunet/logs/ --port 8666
(you need tensorflow installed in your conda env).
- In order to train with BCEWithLogitsLoss, DiceLoss or GeneralizedDiceLoss, the label data has to be 4D (one target binary mask per channel). If you have 3D binary data (foreground/background), you can just change the ToTensor transform for the label to contain expand_dims: true; see e.g. train_config_dice.yaml.
- When training with binary-based losses (BCEWithLogitsLoss, DiceLoss, GeneralizedDiceLoss), final_sigmoid=True has to be present in the training config, since every output channel gives the probability of the foreground. When training with cross-entropy based losses (WeightedCrossEntropyLoss, CrossEntropyLoss, PixelWiseCrossEntropyLoss), set final_sigmoid=False so that Softmax normalization is applied to the output.
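The difference between the two final_sigmoid settings comes down to how a voxel's channel logits are turned into probabilities. This pure-Python sketch for a single voxel is only an illustration (the network applies the equivalent PyTorch operations to whole tensors): sigmoid treats every channel independently, whereas softmax normalizes across channels so that they sum to 1.

```python
import math


def sigmoid(logits):
    """Independent per-channel probabilities, as with final_sigmoid=True."""
    return [1.0 / (1.0 + math.exp(-z)) for z in logits]


def softmax(logits):
    """Probabilities normalized across channels, as with final_sigmoid=False."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


voxel_logits = [2.0, 0.5, -1.0]  # one voxel, three output channels
print(sum(sigmoid(voxel_logits)))  # greater than 1: channels are independent
print(sum(softmax(voxel_logits)))  # 1 (up to rounding): mutually exclusive classes
```

This is why binary losses pair with sigmoid (each channel is its own foreground mask), while cross-entropy losses pair with softmax (each voxel belongs to exactly one class).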
Test on a randomly generated 3D volume (just for demonstration purposes) from random_label3D.h5:
python predict.py --config resources/test_config_ce.yaml
or, if you trained with DiceLoss:
python predict.py --config resources/test_config_dice.yaml
Prediction masks will be saved to resources/random_label3D_probabilities.h5.
To run prediction on your own raw dataset, provide the path to your model as well as the paths to the HDF5 test datasets in test_config_ce.yaml.
In order to avoid block artifacts in the output prediction masks, the patch predictions are averaged, so make sure that the patch/stride params lead to overlapping blocks. E.g. patch: [64 128 128] with stride: [32 96 96] will give you a 'halo' of 32 voxels in each direction, since consecutive patches overlap by patch minus stride along every axis.
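The overlap rule can be checked one axis at a time. The following minimal 1D sketch computes sliding-window start offsets and averages the overlapping predictions. patch_starts, average_patches_1d and the edge-handling rule (adding one final patch flush with the volume boundary) are assumptions for illustration, and the repository's own slicing and blending logic may differ.

```python
def patch_starts(size, patch, stride):
    """Start offsets of sliding-window patches along one axis.

    Assumption: if the last regular patch stops short of the volume edge,
    one extra patch flush with the edge is added so every voxel is covered.
    """
    if patch > size:
        raise ValueError("patch must not be larger than the volume")
    starts = list(range(0, size - patch + 1, stride))
    if starts[-1] + patch < size:
        starts.append(size - patch)
    return starts


def average_patches_1d(size, patch, stride, predict):
    """Blend per-patch predictions along one axis by averaging overlaps.

    predict(start) must return `patch` values for voxels start..start+patch-1.
    """
    total = [0.0] * size
    count = [0] * size
    for start in patch_starts(size, patch, stride):
        for offset, value in enumerate(predict(start)):
            total[start + offset] += value
            count[start + offset] += 1
    return [t / c for t, c in zip(total, count)]


# patch: [64 128 128] with stride: [32 96 96] -> overlap of patch - stride per axis
overlaps = [p - s for p, s in zip([64, 128, 128], [32, 96, 96])]
print(overlaps)                    # [32, 32, 32]
print(patch_starts(256, 128, 96))  # [0, 96, 128]

# Toy "prediction": every voxel of a patch gets the patch's start offset.
blended = average_patches_1d(256, 128, 96, lambda start: [float(start)] * 128)
print(blended[100])  # 48.0: voxel 100 is covered by the patches starting at 0 and 96
```

If stride equals patch, every voxel is covered by exactly one patch and no averaging happens, which is exactly the setting that produces visible block seams.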
If you want to contribute back, please make a pull request.
If you use this code for your research, please cite as:
Adrian Wolny. (2019, May 7). wolny/pytorch-3dunet: PyTorch implementation of 3D U-Net (Version v1.0.0). Zenodo. http://doi.org/10.5281/zenodo.2671581