Volumetric Segmentation and Characterisation of the Paracingulate Sulcus on MRI Scans
The implementation of my dissertation, the work is an extension of pytorch-3dunet.
Getting Started
Dependencies
- pytorch (0.4.1+)
- torchvision (0.2.1+)
- tensorboardx (1.6+)
- h5py
Setup a new conda environment with the required dependencies via:
conda create -n 3dunet pytorch torchvision tensorboardx h5py -c conda-forge -c pytorch
Activate newly created conda environment via:
source activate 3dunet
TL;DR
Download the pre-trained coordinate detector and the segmentation network model:
To predict on the NIfTI file, run predict.py
with the following command:
python predict.py --model-path ./seg.pytorch --cdmodel-path ./cd.pytorch --test-path ./file.nii.gz
The command will create two masks for LH and RH in the path ./output/
.
Data Preparation
Data Preprocessing
With a given scan, if the size does not match with the intended size, preprocessing can be performed. The script is available in ./datasets/align_acpc.py
. The scan is resampled to 192 x 224 x 192 and 1 x 1 x 1 mm voxel first, and is then aligned to the ACPC plane using a 6 DOF alignment via FSL commands based on the template file ./datasets/template.nii.gz
. The preprocessing step will be automatically performed during prediction, where the output segmentation mask will be based on the processed scan. However, for preparing the training data, the implemented function needs to be manually added to the conversion script mentioned in the upcoming section.
Data Conversion
NIfTI files need to be converted to HDF5 format first for training. An example can be found in ./datasets/h5_converter.py
, which can be executed to generate h5 file for all provided MRI scans. Similar conversion can be done by modifying the file to enable conversion on other files.
Each generated h5 file has three datasets, raw
, label
, and coor
. The raw
and label
sets correspond to the raw image and the label of the scan, and zero-padding is performed to resize the image to 192 x 224 x 192 (divisble by 32). Lastly, the coor
dataset contains the centre coordinate (x, y, z)
for each scan. As not all scans have the PCS, a threshold is applied to only generate the coor
dataset for scans have more PCS voxels than the threshold.
Data Augmentation
After conversion, the augmentation can be performed:
python ./datasets/augmentation.py --input-path ./PCS_h5_files/* --output-path ./aug_PCS_files/ --interval 4
The input-path
and output-path
arguments indicate the path for input and output, and interval
is the step size of shift offset used when performing, which can indirectly affect the augmented sample size.
For example, when an interval of 4 is applied, shifts will be generated from range(0, 25, 4), and augmentation will be made on moving the centre coordinate towards different directions by 0, 4, 8, ... voxels, and crop the surrounding image based on the sliding window size.
IMPORTANT
The file to be augmented must contain the coor
dataset.
Supported Losses
Loss functions
- wce - WeightedCrossEntropyLoss (see 'Weighted cross-entropy (WCE)' in the above paper for a detailed explanation)
- ce - CrossEntropyLoss (one can specify class weights via
--loss-weight <w_1 ... w_k>
) - pce - PixelWiseCrossEntropyLoss (once can specify not only class weights but also per pixel weights in order to give more/less gradient in some regions of the ground truth)
- bce - BCELoss (one can specify class weights via
--loss-weight <w_1 ... w_k>
) - dice - DiceLoss standard Dice loss (see 'Dice Loss' in the above paper for a detailed explanation). Note: if your labels in the training dataset are not very imbalance
e.g. one class having at lease 3 orders of magnitude more voxels than the other use this instead of
GDL
since it worked better in my experiments. - gdl - GeneralizedDiceLoss (one can specify class weights via
--loss-weight <w_1 ... w_k>
)(see 'Generalized Dice Loss (GDL)' in the above paper for a detailed explanation)
Train
usage: train.py [-h] [--checkpoint-dir CHECKPOINT_DIR] [--in-channels
IN_CHANNELS] [--out-channels OUT_CHANNELS]
[--init-channel-number INIT_CHANNEL_NUMBER]
[--layer-order LAYER_ORDER] [--loss LOSS]
[--loss-weight LOSS_WEIGHT [LOSS_WEIGHT ...]]
[--ignore-index IGNORE_INDEX] [--curriculum] [--final-sigmoid]
[--epochs EPOCHS] [--iters ITERS] [--patience PATIENCE]
[--learning-rate LEARNING_RATE] [--weight-decay WEIGHT_DECAY]
[--validate-after-iters VALIDATE_AFTER_ITERS]
[--log-after-iters LOG_AFTER_ITERS] [--resume RESUME]
--train-path TRAIN_PATH [TRAIN_PATH ...]
--train-patch TRAIN_PATCH [TRAIN_PATCH ...]
--train-stride TRAIN_STRIDE [TRAIN_STRIDE ...]
[--raw-internal-path RAW_INTERNAL_PATH]
[--label-internal-path LABEL_INTERNAL_PATH]
[--transformer TRANSFORMER]
[--network NETWORK]
UNet3D training
optional arguments:
-h, --help show this help message and exit
--checkpoint-dir CHECKPOINT_DIR
checkpoint directory
--in-channels IN_CHANNELS
number of input channels (default: 1)
--out-channels OUT_CHANNELS
number of output channels (default: 2)
--init-channel-number INIT_CHANNEL_NUMBER
Initial number of feature maps in the encoder path
which gets doubled on every stage (default: 64)
--layer-order LAYER_ORDER
Conv layer ordering, e.g. 'crg' ->
Conv3D+ReLU+GroupNorm
--loss LOSS Which loss function to use for segmentation network.
Possible values: [bce, ce, wce, dice]. Where bce -
BinaryCrossEntropyLoss (binary classification only),
ce - CrossEntropyLoss (multi-class classification),
wce - WeightedCrossEntropyLoss (multi-class
classification), dice - GeneralizedDiceLoss
(multi-class classification)
--loss-weight LOSS_WEIGHT [LOSS_WEIGHT ...]
A manual rescaling weight given to each class. Can be
used with CrossEntropy or BCELoss. E.g. --loss-weight
0.3 0.3 0.4
--ignore-index IGNORE_INDEX
Specifies a target value that is ignored and does not
contribute to the input gradient
--curriculum use simple Curriculum Learning scheme if ignore_index
is present
--final-sigmoid if True apply element-wise nn.Sigmoid after the last
layer otherwise apply nn.Softmax
--epochs EPOCHS max number of epochs (default: 500)
--iters ITERS max number of iterations (default: 1e5)
--patience PATIENCE number of epochs with no loss improvement after which
the training will be stopped (default: 20)
--learning-rate LEARNING_RATE
initial learning rate (default: 0.0002)
--weight-decay WEIGHT_DECAY
weight decay (default: 0.0001)
--validate-after-iters VALIDATE_AFTER_ITERS
how many iterations between validations (default: 100)
--log-after-iters LOG_AFTER_ITERS
how many iterations between tensorboard logging
(default: 100)
--resume RESUME path to latest checkpoint (default: none); if provided
the training will be resumed from that checkpoint
--train-path TRAIN_PATH [TRAIN_PATH ...]
paths to the training datasets, e.g. --train-path <path1> <path2>
--train-patch TRAIN_PATCH [TRAIN_PATCH ...]
Patch shape for used for training
--train-stride TRAIN_STRIDE [TRAIN_STRIDE ...]
Patch stride for used for training
--raw-internal-path RAW_INTERNAL_PATH
--label-internal-path LABEL_INTERNAL_PATH
--transformer TRANSFORMER
data augmentation class
--network NETWORK
which network to train, cd for coordinate detector
and seg for segmentation network.
For direct training on the whole image without using patch based training, simply assign the train-patch
and train-stride
arguments as the dimension of the image.
Train on coordinate detector:
python train.py --checkpoint-dir ./ckpt/ --epoch 50 --learning-rate 0.0002 --train-path ./PCS_data_h5/* --train-patch 192 224 192 --train-stride 192 224 192 --label-internal-path coor --network cd
Train on segmentation network using Dice loss:
python train.py --checkpoint-dir ./ckpt/ --epoch 10 --learning-rate 0.0002 --train-path ./PCS_data_h5/* --train-patch 63 77 93 --train-stride 63 77 93 --network seg --loss dice
To resume training the segmentation from the last checkpoint:
python train.py --resume ./ckpt/seg_ckpt.pytorch --epoch 10 --learning-rate 0.0002 --train-path ./PCS_data_h5/* --train-patch 63 77 93 --train-stride 63 77 93 --network seg --loss dice
IMPORTANT
In order to train with BinaryCrossEntropy
the label data has to be 4D! (one target binary mask per channel). --final-sigmoid
has to be given when training the network with BinaryCrossEntropy
(and similarly --final-sigmoid
has to be passed to the predict.py
if the network was trained with --final-sigmoid
)
DiceLoss
and GeneralizedDiceLoss
support both 3D and 4D target (if the target is 3D it will be automatically expanded to 4D, i.e. each class in separate channel, before applying the loss).
Test
usage: predict.py [-h] --cdmodel-path MODEL_PATH --model-path MODEL_PATH
[--in-channels IN_CHANNELS] [--out-channels OUT_CHANNELS]
[--init-channel-number INIT_CHANNEL_NUMBER]
[--layer-order LAYER_ORDER] [--final-sigmoid] --test-path
TEST_PATH [--raw-internal-path RAW_INTERNAL_PATH] --patch
PATCH [PATCH ...] --stride STRIDE [STRIDE ...]
[--report-metrics] [--output-path OUTPUT_PATH]
3D U-Net predictions
optional arguments:
-h, --help show this help message and exit
--cdmodel-path MODEL_PATH
path to the coordinate detector model
--model-path MODEL_PATH
path to the segmentation model
--in-channels IN_CHANNELS
number of input channels (default: 1)
--out-channels OUT_CHANNELS
number of output channels (default: 2)
--init-channel-number INIT_CHANNEL_NUMBER
Initial number of feature maps in the encoder path
which gets doubled on every stage (default: 64)
--layer-order LAYER_ORDER
Conv layer ordering, e.g. 'crg' ->
Conv3D+ReLU+GroupNorm
--final-sigmoid if True apply element-wise nn.Sigmoid after the last
layer otherwise apply nn.Softmax
--test-path TEST_PATH
path to the test dataset
--raw-internal-path RAW_INTERNAL_PATH
--patch PATCH [PATCH ...]
Patch shape for used for prediction on the test set
--stride STRIDE [STRIDE ...]
Patch stride for used for prediction on the test set
--report-metrics
Whether to print metrics for each prediction
--output-path OUTPUT_PATH
The output path to generate the nifti file
To predict and test on h5 files, the following command can be executed to report metrics (e.g. Dice score, loss...) and save predicted segmentation as two NIfTI files for each scan:
python predict.py --model-path ./ckpt/seg_ckpt.pytorch --cdmodel-path ./ckpt/cd_ckpt.pytorch --test-path ./PCS_data_h5/* --report-metrics
To simply generate NIfTI file of prediction:
python predict.py --model-path ./seg.pytorch --cdmodel-path ./cd.pytorch --test-path ./file.nii.gz
IMPORTANT
Image preprocessing is performed when the given NIfTI file has a mismatched dimension, along with the predicted segmentation mask for LH and RH, upon preprocessing the processed NIfTI scan file will also be generated in the output path.