U-Net-TensorFlow

This repository is a TensorFlow implementation of "U-Net: Convolutional Networks for Biomedical Image Segmentation" (MICCAI 2015). It follows the original U-Net paper completely.

EM Segmentation Challenge Dataset

Requirements

  • tensorflow 1.13.1
  • python 3.5.3
  • numpy 1.15.2
  • scipy 1.1.0
  • tifffile 2019.3.18
  • opencv 3.3.1
  • matplotlib 2.2.2
  • elasticdeform 0.4.4
  • scikit-learn 0.20.0

Implementation Details

This implementation follows the original U-Net paper in the following aspects:

  • Input image size of 572 x 572 x 1 and output label map size of 388 x 388 x 2
  • Upsampling uses fractionally strided convolution (deconvolution)
  • Mirror (reflection) padding is applied to the input image
  • Data augmentation: random translation, random horizontal and vertical flip, random rotation, and random elastic deformation
  • Loss function includes weighted cross-entropy loss and regularization term
  • The weight map is calculated using Equation 2 of the original paper
  • At test time, the prediction is the average over 7 rotated versions of the input data
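The 572 x 572 input and 388 x 388 output sizes in the first bullet follow from the unpadded ("valid") 3 x 3 convolutions of the paper's architecture. The plain-Python sketch below recomputes the output size; it assumes the paper's layout of four 2 x 2 pooling steps with two valid convolutions per level, and `unet_output_size` is a hypothetical helper, not a function from this repository.

```python
def unet_output_size(input_size: int, depth: int = 4, convs_per_level: int = 2,
                     kernel: int = 3) -> int:
    """Spatial output size of a U-Net built only from unpadded convolutions.

    Assumed layout (as in the original paper): every level applies
    `convs_per_level` valid kernel x kernel convolutions, the contracting path
    halves the size with 2x2 max pooling, and the expansive path doubles it.
    """
    shrink = convs_per_level * (kernel - 1)  # pixels lost per level: 2 * 2 = 4
    size = input_size
    # Contracting path: valid convolutions, then 2x2 max pooling.
    for _ in range(depth):
        size -= shrink
        if size % 2:
            raise ValueError("feature map must have an even size before pooling")
        size //= 2
    # Bottleneck convolutions at the lowest resolution.
    size -= shrink
    # Expansive path: 2x upsampling, then valid convolutions.
    for _ in range(depth):
        size = size * 2 - shrink
    return size


print(unet_output_size(572))  # 388
```

The same arithmetic gives the 92-pixel border, (572 - 388) / 2, that mirror padding has to supply around each output tile.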

Examples of the Data Augmentation

  • Random Translation
  • Random Horizontal and Vertical Flip
  • Random Rotation
  • Random Elastic Deformation
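The key constraint in these augmentations is that the input image, label image, and weight map receive exactly the same geometric transform. The NumPy sketch below shows this for random translation and random flips; as simplifying assumptions, rotation is limited to multiples of 90 degrees, and elastic deformation (for which the requirements list the elasticdeform package) is omitted. The function names and the `max_shift=20` default are hypothetical, and `np.random.RandomState` is used because it also exists in the NumPy 1.15 listed above.

```python
import numpy as np


def random_flip_rotate(arrays, rng):
    """Apply one random flip/rotation draw identically to every array."""
    if rng.rand() < 0.5:  # horizontal flip
        arrays = [np.fliplr(a) for a in arrays]
    if rng.rand() < 0.5:  # vertical flip
        arrays = [np.flipud(a) for a in arrays]
    # Assumption: rotation by k * 90 degrees only (no arbitrary angles).
    k = rng.randint(4)
    return [np.rot90(a, k) for a in arrays]


def random_translate(arrays, max_shift, rng):
    """Shift every array by the same random (dy, dx); borders are filled by reflection."""
    dy, dx = rng.randint(-max_shift, max_shift + 1, size=2)
    h, w = arrays[0].shape[:2]
    shifted = []
    for a in arrays:
        # Pad only the two spatial axes; a trailing channel axis is left untouched.
        pad = [(max_shift, max_shift)] * 2 + [(0, 0)] * (a.ndim - 2)
        padded = np.pad(a, pad, mode="reflect")
        top, left = max_shift + dy, max_shift + dx
        shifted.append(padded[top:top + h, left:left + w])
    return shifted


def augment(image, label, wmap, rng, max_shift=20):
    """Hypothetical pipeline: input, label, and weight map stay pixel-aligned."""
    arrays = random_translate([image, label, wmap], max_shift, rng)
    arrays = random_flip_rotate(arrays, rng)
    return [np.ascontiguousarray(a) for a in arrays]
```

Drawing the random parameters once and applying them to all three arrays is what keeps every label pixel and its weight attached to the same input pixel after augmentation.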

Handling Different Input and Output Image Sizes During Training

  • Mirror (reflection) padding is applied first (white lines indicate the image boundaries)
  • The input image, label image, and weight map are then randomly cropped
  • During training, the blue rectangle of the input image and the red rectangle of the weight map are fed to the U-Net, and the red rectangle of the label image serves as the ground truth of the network.
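The cropping scheme above can be sketched with NumPy as follows. It is a minimal illustration under stated assumptions, not the repository's code: only the input image is mirror-padded (NumPy's `reflect` mode), the 388 x 388 label and weight-map windows are cut from the unpadded arrays, and `random_training_crop` is a hypothetical name. Because the padding margin is (572 - 388) / 2 = 92 pixels, a 572 x 572 window starting at the same corner in the padded image is centered on the 388 x 388 target window.

```python
import numpy as np

IN_SIZE, OUT_SIZE = 572, 388
MARGIN = (IN_SIZE - OUT_SIZE) // 2  # 92 pixels of context on each side


def random_training_crop(image, label, wmap, rng):
    """Cut one aligned (input, label, weight map) training sample.

    image, label, wmap: 2-D arrays of identical H x W (512 x 512 for the EM data).
    Returns a 572 x 572 input patch plus 388 x 388 label and weight-map patches
    that share the same center.
    """
    h, w = label.shape
    # Assumption: only the input is mirror-padded; targets are cropped directly.
    padded = np.pad(image, MARGIN, mode="reflect")
    # Top-left corner of the 388 x 388 target window inside the original image.
    top = rng.randint(0, h - OUT_SIZE + 1)
    left = rng.randint(0, w - OUT_SIZE + 1)
    # Padded row p maps to original row p - MARGIN, so this window is centered
    # on the target window below.
    x = padded[top:top + IN_SIZE, left:left + IN_SIZE]
    y = label[top:top + OUT_SIZE, left:left + OUT_SIZE]
    wm = wmap[top:top + OUT_SIZE, left:left + OUT_SIZE]
    return x, y, wm
```

Cropping after padding means target windows near the image border still receive 92 pixels of (mirrored) context, which the valid convolutions consume.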

Test Paradigm

  • At test time, each test image is rotated into 7 versions, and the final prediction is the average of the 7 predicted results.

  • For each rotated image, four regions (top left, top right, bottom left, and bottom right) are extracted and passed through the U-Net, and the prediction is obtained by averaging the scores where the four results overlap.

Note: White lines indicate boundaries of the image.
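A sketch of the four-region test procedure, assuming a callable `predict` that maps a 572 x 572 tile to 388 x 388 x C scores (a hypothetical stand-in for running the trained network). Scores from overlapping tiles are summed and divided by the number of tiles covering each pixel. The README does not state which 7 rotations are used, so `predict_with_rotations` averages over the four 90-degree rotations purely as an illustration; each prediction is rotated back before averaging.

```python
import numpy as np

IN_SIZE, OUT_SIZE = 572, 388
MARGIN = (IN_SIZE - OUT_SIZE) // 2  # 92


def predict_four_corners(image, predict):
    """Average scores of four overlapping corner tiles of one 2-D image.

    The four tiles cover the whole image when 388 <= H, W <= 776
    (e.g. 512 x 512 EM slices).
    """
    h, w = image.shape
    padded = np.pad(image, MARGIN, mode="reflect")
    score_sum, count = None, np.zeros((h, w, 1))
    for top in (0, h - OUT_SIZE):          # top row, bottom row
        for left in (0, w - OUT_SIZE):     # left column, right column
            tile = padded[top:top + IN_SIZE, left:left + IN_SIZE]
            scores = predict(tile)         # (388, 388, C)
            if score_sum is None:
                score_sum = np.zeros((h, w, scores.shape[-1]))
            score_sum[top:top + OUT_SIZE, left:left + OUT_SIZE] += scores
            count[top:top + OUT_SIZE, left:left + OUT_SIZE] += 1
    return score_sum / count  # overlapping regions are averaged


def predict_with_rotations(image, predict, ks=(0, 1, 2, 3)):
    """Rotate, predict, rotate back, and average (rotation set is an assumption)."""
    preds = []
    for k in ks:
        rotated = np.ascontiguousarray(np.rot90(image, k))
        scores = predict_four_corners(rotated, predict)
        preds.append(np.rot90(scores, -k, axes=(0, 1)))  # undo the rotation
    return np.mean(preds, axis=0)
```

Rotating each prediction back before averaging is what makes the ensemble valid: all score maps must be in the original image orientation when they are combined.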

Note: The prediction results on the EM Segmentation Challenge test dataset.

Download Dataset

Download the EM Segmentation Challenge dataset from the ISBI challenge homepage.

Documentation

Directory Hierarchy

.
│   U-Net
│   ├── src
│   │   ├── dataset.py
│   │   ├── main.py
│   │   ├── model.py
│   │   ├── preprocessing.py
│   │   ├── solver.py
│   │   ├── tensorflow_utils.py
│   │   └── utils.py
│   Data
│   └── EMSegmentation
│   │   ├── test-volume.tif
│   │   ├── train-labels.tif
│   │   ├── train-wmaps.npy (generated in preprocessing)
│   │   └── train-volume.tif

Preprocessing

The weight maps need to be calculated first from the segmentation labels of the training data. Computing the weight maps on the fly during training would slow down processing considerably. Therefore, the weight maps are calculated and saved in advance, and during training they are augmented together with the corresponding input and label images. Use preprocessing.py to calculate the weight maps. Example usage:

python preprocessing.py
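Equation 2 of the paper defines the weight map as w(x) = w_c(x) + w_0 * exp(-(d_1(x) + d_2(x))^2 / (2 sigma^2)), where w_c balances the class frequencies, d_1 and d_2 are the distances to the border of the nearest and second-nearest cell, and the paper uses w_0 = 10 and sigma of about 5 pixels. The NumPy sketch below is an illustration, not preprocessing.py itself: it assumes an instance map with one integer ID per cell, an inverse-frequency w_c, and that the separation term is added on background (membrane) pixels only. Its brute-force distance computation is only practical for small arrays; on full 512 x 512 images, `scipy.ndimage.distance_transform_edt` applied to `instances != k` yields the same per-cell distances far more cheaply.

```python
import numpy as np


def unet_weight_map(instances, w0=10.0, sigma=5.0):
    """Sketch of Eq. 2: w(x) = w_c(x) + w0 * exp(-(d1 + d2)^2 / (2 * sigma^2)).

    instances: 2-D int array, 0 = background (membrane), 1..K = individual cells.
    """
    height, width = instances.shape
    foreground = instances > 0

    # Class-balancing term w_c (assumption: inverse class frequency).
    n_fg = foreground.sum()
    n_bg = foreground.size - n_fg
    wc = np.where(foreground,
                  foreground.size / (2.0 * max(n_fg, 1)),
                  foreground.size / (2.0 * max(n_bg, 1)))

    cell_ids = [k for k in np.unique(instances) if k != 0]
    if len(cell_ids) < 2:
        return wc  # d2 is undefined with fewer than two cells

    # Brute-force Euclidean distance from every pixel to the nearest pixel of each cell.
    rows, cols = np.mgrid[0:height, 0:width]
    pixels = np.stack([rows.ravel(), cols.ravel()], axis=1).astype(float)
    dists = np.empty((len(cell_ids), height * width))
    for i, k in enumerate(cell_ids):
        cell = np.argwhere(instances == k).astype(float)
        diff = pixels[:, None, :] - cell[None, :, :]
        dists[i] = np.sqrt((diff ** 2).sum(axis=-1)).min(axis=1)

    dists.sort(axis=0)  # row 0 -> d1 (nearest cell), row 1 -> d2 (second nearest)
    d1 = dists[0].reshape(height, width)
    d2 = dists[1].reshape(height, width)
    separation = w0 * np.exp(-((d1 + d2) ** 2) / (2.0 * sigma ** 2))
    # Assumption: the separation term only applies to background pixels.
    return wc + np.where(foreground, 0.0, separation)
```

The separation term is largest on thin membranes squeezed between two cells, because d_1 + d_2 is small there; this is what pushes the network to learn the narrow borders between touching cells.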

Training U-Net

Use main.py to train the U-Net. Example usage:

python main.py
  • gpu_index: GPU index if you have multiple GPUs, default: 0
  • dataset: dataset name, default: EMSegmentation
  • batch_size: batch size for one iteration, default: 4
  • is_train: training or inference (test) mode, default: True (training mode)
  • learning_rate: initial learning rate for optimizer, default: 1e-3
  • weight_decay: weight decay for model to handle overfitting, default: 1e-4
  • iters: number of iterations, default: 20,000
  • print_freq: print frequency for loss information, default: 10
  • sample_freq: sample frequency for checking qualitative evaluation, default: 100
  • eval_freq: evaluation frequency for the batch accuracy, default: 200
  • load_model: folder of the saved model from which you wish to continue training (e.g., 20190524-1606), default: None

Test U-Net

Use main.py to test the models. Example usage:

python main.py --is_train=False --load_model=folder/you/wish/to/test/e.g./20190524-1606

Please refer to the arguments listed above.

TensorBoard Visualization

Note: The following figure shows the data loss, weighted data loss, regularization term, and total loss during training. The batch accuracy is also shown in TensorBoard.
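The four curves can be related with a small NumPy sketch. It assumes the data loss is the mean per-pixel softmax cross-entropy, the weighted data loss multiplies each pixel's cross-entropy by the weight map before averaging, the regularization term is `weight_decay` times half the sum of squared kernel weights (the form computed by `tf.nn.l2_loss`), and the total loss is the weighted data loss plus the regularization term. `unet_losses` and its arguments are hypothetical; the repository's solver.py and model.py may combine the terms differently.

```python
import numpy as np


def log_softmax(logits):
    """Numerically stable log-softmax over the last (class) axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def unet_losses(logits, labels, wmap, weights, weight_decay=1e-4):
    """Hypothetical helper returning the quantities logged to TensorBoard.

    logits: (N, H, W, C) raw outputs; labels: (N, H, W) integer classes;
    wmap: (N, H, W) weight map; weights: list of kernels to regularize.
    """
    num_classes = logits.shape[-1]
    onehot = np.eye(num_classes)[labels]                # (N, H, W, C)
    ce = -(onehot * log_softmax(logits)).sum(axis=-1)   # per-pixel cross-entropy
    data_loss = ce.mean()
    weighted_data_loss = (wmap * ce).mean()
    # Assumption: L2 penalty in the tf.nn.l2_loss form, sum(w ** 2) / 2.
    reg_term = weight_decay * sum(0.5 * np.sum(w ** 2) for w in weights)
    total_loss = weighted_data_loss + reg_term
    batch_accuracy = (logits.argmax(axis=-1) == labels).mean()
    return {
        "data_loss": data_loss,
        "weighted_data_loss": weighted_data_loss,
        "reg_term": reg_term,
        "total_loss": total_loss,
        "batch_accuracy": batch_accuracy,
    }
```

Logging the unweighted data loss next to the weighted one helps separate how well the network fits ordinary pixels from how much of the objective comes from the heavily weighted cell borders.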

Citation

  @misc{chengbinjin2019u-nettensorflow,
    author = {Cheng-Bin Jin},
    title = {U-Net Tensorflow},
    year = {2019},
    howpublished = {\url{https://github.com/ChengBinJin/U-Net-TensorFlow}},
    note = {commit xxxxxxx}
  }

Attributions/Thanks

License

Copyright (c) 2018 Cheng-Bin Jin. Contact me for commercial use (or any use that is not academic research) (email: sbkim0407@gmail.com). Free for research use, as long as proper attribution is given and this copyright notice is retained.