keras_UNET_segmentation

Keras implementation of a UNET for image segmentation

Primary language: Python. License: MIT.

This repository provides a Keras implementation of a UNET and tests it by performing segmentation on publicly available data.

Corresponding blog post:

The results presented in this repository are explained in detail in this blog post.

Data: we use the C. elegans dataset (BBBC010) available from the Broad Bioimage Benchmark Collection.

  • Input images: download and extract into ./data, then delete the '__MACOSX' subfolder within BBBC010_v2_images. Each input sample consists of two channels (brightfield and GFP), and each channel is stored as a separate grayscale image.
  • Target masks: download and extract into ./data
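As a minimal sketch of how the two grayscale channels can be combined into a single network input, the snippet below stacks a brightfield array and a GFP array into an (H, W, 2) array and scales each channel to [0, 1]. The helper name `make_input` and the per-channel min-max scaling are assumptions for illustration: the README only states that images are normalized to [0, 1], and reading the image files from ./data is left out.

```python
import numpy as np


def make_input(brightfield, gfp):
    """Stack two grayscale channels into an (H, W, 2) float32 array in [0, 1].

    Hypothetical helper: per-channel min-max scaling is an assumption,
    not necessarily the normalization used by load_data.py.
    """
    channels = []
    for img in (brightfield, gfp):
        img = np.asarray(img, dtype=np.float32)
        lo, hi = img.min(), img.max()
        # A constant image has no dynamic range, so map it to all zeros
        scaled = (img - lo) / (hi - lo) if hi > lo else np.zeros_like(img)
        channels.append(scaled)
    # Channels-last layout, the Keras default for 2D convolutions
    return np.stack(channels, axis=-1)


bf = np.array([[0, 100], [200, 400]], dtype=np.uint16)
gfp = np.array([[10, 10], [30, 50]], dtype=np.uint16)
x = make_input(bf, gfp)
print(x.shape)  # (2, 2, 2)
```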

The goal is to predict the segmentation mask of each image from its two input channels. An example image is shown below:

[image: demo image]

List of files:

  • load_data.py: loads the data, splits it into training, validation and test sets, and saves them to disk as NumPy arrays. The images are resized to 400x400 pixels and normalized to [0, 1].
  • unet_train.py: trains a UNET implemented in Keras on the training data and saves the best model according to its performance on the validation set.
  • unet_evaluate.py: evaluates the trained model on the held-out test set.
  • mcd_unet_train.py: trains a Monte Carlo Dropout (MCD) UNET implemented in Keras on the training data and saves the best model according to its performance on the validation set. Unlike the standard UNET, the MCD UNET produces uncertainty estimates alongside its segmentation maps.
  • mcd_unet_evaluate.py: evaluates the trained MCD UNET model on the held-out test set.
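The preprocessing listed for load_data.py (resizing to 400x400, then splitting into training, validation and test sets) can be sketched with NumPy alone. The nearest-neighbour resize, the 70/15/15 split fractions, the fixed seed and the helper names `resize_nearest` and `split_indices` are all assumptions for illustration; the actual script may use a different interpolation method, split ratio or library.

```python
import numpy as np


def resize_nearest(img, size=(400, 400)):
    """Nearest-neighbour resize of an (H, W) or (H, W, C) array (hypothetical helper)."""
    h, w = img.shape[:2]
    # Map each output row/column back to a source index
    rows = (np.arange(size[0]) * h / size[0]).astype(int)
    cols = (np.arange(size[1]) * w / size[1]).astype(int)
    return img[rows[:, None], cols[None, :]]


def split_indices(n, val_frac=0.15, test_frac=0.15, seed=0):
    """Shuffle n sample indices and split them into train/val/test (assumed fractions)."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(n)
    n_test = int(round(n * test_frac))
    n_val = int(round(n * val_frac))
    test, val, train = idx[:n_test], idx[n_test:n_test + n_val], idx[n_test + n_val:]
    return train, val, test


# Dummy stack of 20 two-channel images (not the real data)
images = np.random.default_rng(1).random((20, 60, 80, 2), dtype=np.float32)
x = np.stack([resize_nearest(im) for im in images])
train, val, test = split_indices(len(x))
print(x.shape, len(train), len(val), len(test))  # (20, 400, 400, 2) 14 3 3
```

The resulting arrays could then be written to disk with `np.save`. For the "save the best model on the validation set" step in the training scripts, Keras provides `keras.callbacks.ModelCheckpoint(filepath, monitor="val_loss", save_best_only=True)`; which metric the repository actually monitors is not stated here.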

Results:

The trained UNET achieves a median Dice score of 92.4% on the held-out test set, while the MCD UNET achieves a similar Dice score of 90.5% and additionally estimates the uncertainty of its predicted segmentation maps. All test set predictions for UNET can be found here, and all test set predictions for MCD UNET can be found here. The MCD UNET results are averaged over T=20 stochastic forward passes, with dropout (probability 50%) kept active at inference time (i.e., during model.predict()), similar to the approach presented in DeVries et al., 2018.
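Monte Carlo Dropout inference, as described above, amounts to keeping dropout active at prediction time, running T stochastic forward passes and averaging them; the spread across passes serves as the uncertainty estimate. The NumPy sketch below illustrates this with a hypothetical toy predictor (`toy_predict`) standing in for the network; it is not the repository's model. In Keras, one common way to keep dropout stochastic inside `model.predict()` is to call the dropout layer with `training=True` when building the model, e.g. `keras.layers.Dropout(0.5)(x, training=True)`.

```python
import numpy as np

rng = np.random.default_rng(0)


def toy_predict(x, p=0.5):
    """Hypothetical stand-in for a network with dropout kept ACTIVE at inference."""
    # Inverted dropout: zero each input with probability p, rescale survivors by 1/(1-p)
    keep = rng.random(x.shape) >= p
    h = x * keep / (1.0 - p)
    # Sigmoid turns the noisy activations into a per-pixel foreground probability
    return 1.0 / (1.0 + np.exp(-4.0 * (h - 0.5)))


def mc_dropout_predict(predict_fn, x, T=20):
    """Average T stochastic passes; return the mean map and the per-pixel std."""
    samples = np.stack([predict_fn(x) for _ in range(T)], axis=0)
    return samples.mean(axis=0), samples.std(axis=0)


x = rng.random((4, 4))
mean_map, std_map = mc_dropout_predict(toy_predict, x, T=20)
mask = mean_map > 0.5  # binary segmentation from the averaged probabilities
```

A deterministic predictor (no dropout) would give an all-zero standard deviation map, which is why dropout must stay switched on at inference for this uncertainty estimate to be meaningful.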

The Dice scores achieved by both models on the test set are summarized in the following figure:

[image: dice both]
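The Dice score of a predicted mask P against a ground-truth mask G is 2|P ∩ G| / (|P| + |G|), i.e. twice the overlap divided by the total foreground size of both masks; the results above report the median of this score over the test images. The NumPy sketch below computes it. The 0.5 binarization threshold, the small `eps` term that keeps the ratio defined when both masks are empty, and the function names `dice_score` and `median_dice` are assumptions; the evaluation scripts may differ in these details.

```python
import numpy as np


def dice_score(y_true, y_pred, threshold=0.5, eps=1e-7):
    """Dice coefficient between a binary ground-truth mask and a predicted probability map."""
    t = np.asarray(y_true) > 0.5
    p = np.asarray(y_pred) > threshold
    overlap = np.logical_and(t, p).sum()
    # eps keeps the ratio defined (and equal to 1) when both masks are empty
    return (2.0 * overlap + eps) / (t.sum() + p.sum() + eps)


def median_dice(masks_true, masks_pred, threshold=0.5):
    """Median Dice over a set of test images."""
    return float(np.median([dice_score(t, p, threshold) for t, p in zip(masks_true, masks_pred)]))


g = np.array([[1, 1], [0, 0]])
pred = np.array([[0.9, 0.2], [0.1, 0.0]])
print(round(dice_score(g, pred), 4))  # 0.6667
```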

Some example results for both models are shown below.

UNET

Here we can see the input channels, as well as the true and predicted segmentation masks.

[image: unet example]

MCD UNET

For the MCD UNET, we can see the predicted segmentation as well as its corresponding uncertainty:

[image: mcd unet example]
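Besides the per-pixel standard deviation across dropout samples, a common way to turn averaged MC Dropout probabilities into an uncertainty map is the binary predictive entropy H(p) = -p log p - (1 - p) log(1 - p), which is largest (log 2) where the mean probability is 0.5 and close to 0 where the model is confident. Which uncertainty measure the maps above use is not specified here, so treat this as an illustrative alternative; `predictive_entropy` is a hypothetical helper name and the probabilities are dummy values.

```python
import numpy as np


def predictive_entropy(mean_prob, eps=1e-7):
    """Per-pixel binary entropy (in nats) of an averaged foreground probability map."""
    p = np.clip(mean_prob, eps, 1.0 - eps)  # avoid log(0)
    return -(p * np.log(p) + (1.0 - p) * np.log(1.0 - p))


# Averaged probabilities from T dropout samples (dummy values)
mean_prob = np.array([[0.5, 0.99], [0.01, 0.7]])
uncertainty = predictive_entropy(mean_prob)
```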