
The tensorflow.keras implementation of U-net, V-net, U-net++, R2U-net, Attention U-net, ResUnet-a, U^2-Net, and UNET 3+ with optional ImageNet-trained backbones.


keras-unet-collection




keras_unet_collection.models contains functions that configure Keras models with hyper-parameter options.

  • Pre-trained ImageNet backbones are supported for U-net, U-net++, Attention U-net, and UNET 3+.
  • Deep supervision is supported for U-net++, UNET 3+, and U^2-Net.
  • See the User guide for other options and use cases.

| keras_unet_collection.models | Name | Reference |
|---|---|---|
| unet_2d | U-net | Ronneberger et al. (2015) |
| vnet_2d | V-net (modified for 2-d inputs) | Milletari et al. (2016) |
| unet_plus_2d | U-net++ | Zhou et al. (2018) |
| r2_unet_2d | R2U-Net | Alom et al. (2018) |
| att_unet_2d | Attention U-net | Oktay et al. (2018) |
| resunet_a_2d | ResUnet-a | Diakogiannis et al. (2020) |
| u2net_2d | U^2-Net | Qin et al. (2020) |
| unet_3plus_2d | UNET 3+ | Huang et al. (2020) |

keras_unet_collection.base contains functions that build the base architecture (i.e., without model heads) of Unet variants for model customization and debugging.

| keras_unet_collection.base | Notes |
|---|---|
| unet_2d_base, vnet_2d_base, unet_plus_2d_base, r2_unet_2d_base, att_unet_2d_base, resunet_a_2d_base, u2net_2d_base, unet_3plus_2d_base | Functions that accept an input tensor and the hyper-parameters of the corresponding model, and produce the output tensors of the base architecture. |

keras_unet_collection.activations and keras_unet_collection.losses provide additional activation layers and loss functions.

| keras_unet_collection.activations | Name | Reference |
|---|---|---|
| GELU | Gaussian Error Linear Units (GELU) | Hendrycks et al. (2016) |
| Snake | Snake activation | Liu et al. (2020) |
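Both activations are element-wise functions. As a minimal scalar sketch of their standard formulas, GELU(x) = x * Phi(x), where Phi is the standard normal CDF, and Snake_a(x) = x + sin^2(a x) / a with a > 0 setting the frequency. The function names `gelu` and `snake` and the fixed `a` are assumptions for illustration only; the package implements these as Keras layers, which may differ in details such as how `a` is initialized or trained.

```python
import math


def gelu(x):
    """Exact GELU: x times the standard normal CDF evaluated at x."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def snake(x, a=1.0):
    """Snake activation: x + sin^2(a * x) / a (a fixed here, assumed a > 0)."""
    if a <= 0:
        raise ValueError("a must be positive")
    return x + math.sin(a * x) ** 2 / a


for value in (-2.0, 0.0, 1.0, 2.0):
    print(f"x={value:+.1f}  gelu={gelu(value):+.4f}  snake={snake(value):+.4f}")
```

GELU behaves like the identity for large positive inputs and goes to zero for large negative ones, while Snake adds a non-negative periodic term to the identity, which is what lets it represent periodic structure.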

| keras_unet_collection.losses | Name | Reference |
|---|---|---|
| dice | Dice loss | Sudre et al. (2017) |
| tversky | Tversky loss | Hashemi et al. (2018) |
| focal_tversky | Focal Tversky loss | Abraham et al. (2019) |
| triplet_1d | Semi-hard triplet loss (experimental) | |
| crps2d_tf | CRPS loss (experimental) | |
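The Dice, Tversky, and focal Tversky losses can all be written in terms of soft true-positive (TP), false-negative (FN), and false-positive (FP) sums. The pure-Python sketch below works on one binary mask flattened to a list, with `y_pred` holding probabilities. It is illustrative only and is not the `keras_unet_collection.losses` API: the helper names, the small `smooth` constant, the convention that `alpha` weights FN and `beta` weights FP (papers differ on this), and the default `gamma = 4/3` with a `1 / gamma` exponent are all assumptions of this sketch.

```python
def _confusion_terms(y_true, y_pred):
    """Soft TP, FN, and FP sums for binary labels and predicted probabilities."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    tp = sum(t * p for t, p in zip(y_true, y_pred))
    fn = sum(t * (1.0 - p) for t, p in zip(y_true, y_pred))
    fp = sum((1.0 - t) * p for t, p in zip(y_true, y_pred))
    return tp, fn, fp


def dice_loss(y_true, y_pred, smooth=1e-6):
    """1 - Dice coefficient, with Dice = 2TP / (2TP + FN + FP)."""
    tp, fn, fp = _confusion_terms(y_true, y_pred)
    return 1.0 - (2.0 * tp + smooth) / (2.0 * tp + fn + fp + smooth)


def tversky_index(y_true, y_pred, alpha=0.5, beta=0.5, smooth=1e-6):
    """TP / (TP + alpha*FN + beta*FP); alpha = beta = 0.5 gives the Dice coefficient."""
    tp, fn, fp = _confusion_terms(y_true, y_pred)
    return (tp + smooth) / (tp + alpha * fn + beta * fp + smooth)


def tversky_loss(y_true, y_pred, alpha=0.5, beta=0.5):
    """1 - Tversky index; alpha > beta penalizes missed foreground more."""
    return 1.0 - tversky_index(y_true, y_pred, alpha, beta)


def focal_tversky_loss(y_true, y_pred, alpha=0.5, beta=0.5, gamma=4.0 / 3.0):
    """(1 - Tversky index) ** (1 / gamma); gamma = 1 gives the plain Tversky loss."""
    return tversky_loss(y_true, y_pred, alpha, beta) ** (1.0 / gamma)


y_true = [1, 1, 0, 0]
y_pred = [0.9, 0.2, 0.3, 0.1]
print(dice_loss(y_true, y_pred), tversky_loss(y_true, y_pred, alpha=0.7, beta=0.3))
```

The `smooth` term keeps the ratio defined when both the mask and the prediction are empty; in that case every loss above evaluates to 0.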

Installation and usage

pip install keras-unet-collection

from keras_unet_collection import models
# e.g. models.unet_2d(...)

  • Note: Currently supported backbone models are: VGG[16,19], ResNet[50,101,152], ResNet[50,101,152]V2, DenseNet[121,169,201], and EfficientNetB[0-7]. See Keras Applications for details.

  • Note: This package is under active development, and major updates are planned. Backward compatibility is not guaranteed for versions prior to 0.1.0.

  • Note: Neural networks produced by this package may not be compatible with other pre-trained models of the same name. Training from scratch is recommended.

  • Jupyter notebooks are provided as examples:

    • Attention U-net with VGG16 backbone [link].

    • UNET 3+ with deep supervision and classification guided module [link].

  • Changelog

Dependencies

  • TensorFlow 2.3.0, Keras 2.4.0, NumPy 1.18.2.

  • (Optional for examples) Pillow, matplotlib, etc.

Overview

U-net is a convolutional neural network with an encoder-decoder architecture and skip-connections, loosely belonging to the family of "fully convolutional networks." U-net was originally proposed for the semantic segmentation of medical images and has since been adapted to a wider range of gridded learning problems.

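U-net and many of its variants take three- or four-dimensional tensors as inputs (e.g., height, width, and channels for a 2-d image) and produce outputs with the same spatial dimensions. One technical highlight of these models is the skip-connections from the downsampling (encoder) layers to the upsampling (decoder) layers, which carry fine-grained spatial detail forward and benefit the reconstruction of high-resolution, gridded outputs.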
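To make the encoder-decoder and skip-connection description concrete, the pure-Python sketch below traces feature-map shapes through a generic U-net. `unet_shapes` is a hypothetical helper, not part of keras_unet_collection, and it assumes "same"-padded convolutions, 2x downsampling per level, 2x upsampling that maps channels to the filter count of the matching level, concatenation skip-connections, and a 1x1 output convolution. The filter list [64, 128, 256, 512, 1024] follows the channel counts of Ronneberger et al. (2015), although that paper used unpadded convolutions, so its output map is smaller than its input.

```python
def unet_shapes(height, width, in_channels, filter_num, n_labels):
    """Return [(stage, (H, W, C)), ...] for a U-net-style encoder-decoder."""
    depth = len(filter_num)
    factor = 2 ** (depth - 1)  # total downsampling at the deepest level
    if height % factor or width % factor:
        raise ValueError(f"height and width must be divisible by {factor}")

    trace = [("input", (height, width, in_channels))]
    skips = []  # encoder outputs reused by the skip-connections
    h, w = height, width

    # Encoder: a conv stack at each level, then downsample (except at the bottom).
    for level, filters in enumerate(filter_num):
        trace.append((f"down{level}", (h, w, filters)))
        if level < depth - 1:
            skips.append((h, w, filters))
            h, w = h // 2, w // 2

    # Decoder: upsample, concatenate the matching encoder output, conv stack.
    for level in reversed(range(depth - 1)):
        h, w = h * 2, w * 2
        skip_h, skip_w, skip_c = skips.pop()
        assert (skip_h, skip_w) == (h, w), "skip and upsampled sizes must match"
        trace.append((f"concat{level}", (h, w, filter_num[level] + skip_c)))
        trace.append((f"up{level}", (h, w, filter_num[level])))

    # Output head (assumed 1x1 conv): same spatial size, one channel per label.
    trace.append(("output", (h, w, n_labels)))
    return trace


for stage, shape in unet_shapes(128, 128, 3, [64, 128, 256, 512, 1024], n_labels=2):
    print(f"{stage:>8}: {shape}")
```

The trace shows why the output keeps the input's 128 x 128 spatial size while its channel count becomes `n_labels`, and why each concatenation doubles the channel count at its level. The divisibility check reflects that repeated halving and doubling only line up when the input size is a multiple of 2^(depth - 1).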

Contact

Yingkai (Kyle) Sha <yingkai@eoas.ubc.ca> <yingkaisha@gmail.com>

License

MIT License