[AAAI'20] Segmenting Medical MRI via Recurrent Decoding Cell (Spotlight)

Primary language: Python. License: MIT.

Recurrent Decoding Cell

This is the PyTorch implementation of the AAAI 2020 paper "Segmenting Medical MRI via Recurrent Decoding Cell" by Ying Wen, Kai Xie, and Lianghua He.

Overview

Recurrent Decoding Cell (RDC) is a novel feature fusion unit for encoder-decoder networks used in MRI segmentation. RDC leverages convolutional RNNs (e.g. ConvLSTM, ConvGRU) to memorize long-term context information from the previous layers in the decoding phase. The RDC-based encoder-decoder network, named Convolutional Recurrent Decoding Network (CRDN), achieves promising segmentation results: 99.34% Dice score on BrainWeb, 91.26% on MRBrainS, and 88.13% on HVSMR. The model is also robust to image noise and intensity non-uniformity in medical MRI.
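To make the idea of recurrent decoding more concrete, here is a minimal sketch of a ConvLSTM cell that carries its state across decoder stages. It is an illustration under stated assumptions, not the code in this repository: the class names `ConvLSTMCell` and `RecurrentDecoderSketch`, the 1x1 projection layers, the bilinear upsampling of the state, and the sharing of one cell across all stages are hypothetical choices made for this example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvLSTMCell(nn.Module):
    """Standard ConvLSTM cell: the LSTM gates are computed with 2-D convolutions."""

    def __init__(self, in_channels, hidden_channels, kernel_size=3):
        super().__init__()
        self.hidden_channels = hidden_channels
        # One convolution produces all four gates (input, forget, candidate, output).
        self.gates = nn.Conv2d(in_channels + hidden_channels, 4 * hidden_channels,
                               kernel_size, padding=kernel_size // 2)

    def forward(self, x, state=None):
        if state is None:
            # Start from an all-zero hidden state and cell state.
            zeros = x.new_zeros(x.size(0), self.hidden_channels, x.size(2), x.size(3))
            state = (zeros, zeros)
        h, c = state
        i, f, g, o = torch.chunk(self.gates(torch.cat([x, h], dim=1)), 4, dim=1)
        c = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
        h = torch.sigmoid(o) * torch.tanh(c)
        return h, c


class RecurrentDecoderSketch(nn.Module):
    """Hypothetical decoder: one cell walks the encoder features from coarse to fine."""

    def __init__(self, feature_channels, hidden_channels, n_classes):
        super().__init__()
        # Assumption: 1x1 convolutions bring every encoder stage to a common width.
        self.project = nn.ModuleList(
            [nn.Conv2d(ch, hidden_channels, kernel_size=1) for ch in feature_channels]
        )
        self.cell = ConvLSTMCell(hidden_channels, hidden_channels)
        self.classify = nn.Conv2d(hidden_channels, n_classes, kernel_size=1)

    def forward(self, features):
        # `features` is ordered from the coarsest (deepest) to the finest encoder stage.
        state = None
        for proj, feat in zip(self.project, features):
            if state is not None:
                # Assumption: upsample the carried state to the current resolution.
                state = tuple(
                    F.interpolate(s, size=feat.shape[2:], mode="bilinear",
                                  align_corners=False)
                    for s in state
                )
            state = self.cell(proj(feat), state)
        # Per-pixel class logits at the finest resolution.
        return self.classify(state[0])


if __name__ == "__main__":
    feats = [torch.randn(2, 64, 8, 8), torch.randn(2, 32, 16, 16), torch.randn(2, 16, 32, 32)]
    model = RecurrentDecoderSketch([64, 32, 16], hidden_channels=8, n_classes=4)
    print(model(feats).shape)
```

Each decoder stage is treated as one time step, so the cell state can carry context from the coarse stages into the fine ones. A ConvGRU cell could be swapped in with the same loop.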

Models Implemented

Environment

  • pytorch == 1.1.0
  • torchvision == 0.2.2.post3
  • matplotlib == 2.1.0
  • numpy == 1.11.3
  • tqdm == 4.31.1

One-line installation

pip install -r requirements.txt

Datasets

Usage

Setup config

model:
    arch: <name> [options: 'FCN, SegNet, UNet, VGG16RNN, ResNet50RNN, UNetRNN, VGGUNet, ResNet50UNet, UNetFCN, ResNet50FCN, UNetSegNet']

data:
    dataset: <name> [options: 'BrainWeb, MRBrainS, HVSMR']
    train_split: train
    val_split: val
    path: <path/to/data>

training:
    gpu_idx: 0
    train_iters: 30000
    batch_size: 1
    val_interval: 300
    n_workers: 4
    print_interval: 100
    optimizer:
        name: <optimizer_name> [options: 'sgd, adam, adamax, asgd, adadelta, adagrad, rmsprop']
        lr: 6.0e-4
        weight_decay: 0.0005
    loss:
        name: 'cross_entropy'
    lr_schedule:
        name: <schedule_type> [options: 'constant_lr, poly_lr, multi_step, cosine_annealing, exp_lr']
        <scheduler_keyarg1>:<value>

    # Resume from checkpoint
    resume: <path_to_checkpoint>
    
    # model save path
    model_dir: <path_to_save_model>

testing:
    # trained model path
    trained_model: <path_to_trained_model>

    # segmentation results save path
    path: <path_to_results>
    
    # whether to show boxplot results
    boxplot: False
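As an illustration of what a `poly_lr` schedule usually means, the sketch below implements the widely used polynomial decay rule, lr = base_lr * (1 - iter / max_iters) ^ power. The function name, its signature, and the default `power=0.9` are assumptions for this example; the scheduler in this repository may use different arguments, so check its source before relying on these values.

```python
def poly_lr(base_lr, iteration, max_iters, power=0.9):
    """Polynomial learning-rate decay (common formulation, assumed here).

    Returns base_lr * (1 - iteration / max_iters) ** power.
    """
    if max_iters <= 0:
        raise ValueError("max_iters must be positive")
    if not 0 <= iteration <= max_iters:
        raise ValueError("iteration must lie in [0, max_iters]")
    return base_lr * (1.0 - iteration / max_iters) ** power


if __name__ == "__main__":
    # Values taken from the config template above: lr 6.0e-4, train_iters 30000.
    for it in (0, 10000, 20000, 30000):
        print(it, poly_lr(6.0e-4, it, 30000))
```

With this rule the learning rate starts at `lr` and falls smoothly to zero at `train_iters`.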

To train the model:

run train.py

To test the model:

run test.py

Results

  • Some visualization results of the proposed CRDN and other encoder-decoder methods.

  • Please refer to the paper for other experiments (ablation study, comparisons, network robustness).
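For readers unfamiliar with the Dice score quoted in the Overview, the following is a minimal sketch of the per-class metric, Dice = 2|A ∩ B| / (|A| + |B|), computed on integer label maps. The function name `dice_score` and the convention of returning 1.0 when a class is absent from both maps are assumptions for this example; the evaluation code in this repository may differ.

```python
import numpy as np


def dice_score(pred, target, label):
    """Dice overlap for one class between two integer label maps."""
    pred_mask = np.asarray(pred) == label
    target_mask = np.asarray(target) == label
    denom = pred_mask.sum() + target_mask.sum()
    if denom == 0:
        # Assumption: a class missing from both maps counts as a perfect match.
        return 1.0
    intersection = np.logical_and(pred_mask, target_mask).sum()
    return 2.0 * intersection / denom


if __name__ == "__main__":
    pred = [[0, 1], [1, 1]]
    target = [[0, 1], [0, 1]]
    print(dice_score(pred, target, label=1))  # 2 * 2 / (3 + 2) = 0.8
```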

Acknowledgements

Special thanks to the GitHub repository meetshah1995/pytorch-semseg for providing semantic segmentation algorithms in PyTorch.

Citation

Please cite this paper in your publications if it helps your research:

@inproceedings{wen2020segmenting,
  title={Segmenting Medical MRI via Recurrent Decoding Cell.},
  author={Wen, Ying and Xie, Kai and He, Lianghua},
  booktitle={AAAI},
  pages={12452--12459},
  year={2020}
}

For any problems, please contact kxie_shake@outlook.com