keras-unet-collection

The tensorflow.keras implementation of U-net, V-net, U-net++, UNET 3+, Attention U-net, R2U-net, ResUnet-a, U^2-Net, TransUNET, and Swin-UNET with optional ImageNet-trained backbones.


keras_unet_collection.models contains functions that configure keras models with hyper-parameter options.

  • Pre-trained ImageNet backbones are supported for U-net, U-net++, UNET 3+, Attention U-net, and TransUNET.
  • Deep supervision is supported for U-net++, UNET 3+, and U^2-Net.
  • See the User guide for other options and use cases.

| keras_unet_collection.models | Name | Reference |
|---|---|---|
| unet_2d | U-net | Ronneberger et al. (2015) |
| vnet_2d | V-net (modified for 2-d inputs) | Milletari et al. (2016) |
| unet_plus_2d | U-net++ | Zhou et al. (2018) |
| r2_unet_2d | R2U-Net | Alom et al. (2018) |
| att_unet_2d | Attention U-net | Oktay et al. (2018) |
| resunet_a_2d | ResUnet-a | Diakogiannis et al. (2020) |
| u2net_2d | U^2-Net | Qin et al. (2020) |
| unet_3plus_2d | UNET 3+ | Huang et al. (2020) |
| transunet_2d | TransUNET | Chen et al. (2021) |
| swin_unet_2d | Swin-UNET | Hu et al. (2021) |

Note: the two Transformer models (transunet_2d and swin_unet_2d) are incompatible with NumPy 1.20; NumPy 1.19.5 is recommended.


keras_unet_collection.base contains functions that build the base architecture (i.e., without model heads) of Unet variants for model customization and debugging.

| keras_unet_collection.base | Notes |
|---|---|
| unet_2d_base, vnet_2d_base, unet_plus_2d_base, unet_3plus_2d_base, att_unet_2d_base, r2_unet_2d_base, resunet_a_2d_base, u2net_2d_base, transunet_2d_base, swin_unet_2d_base | Functions that accept an input tensor and the hyper-parameters of the corresponding model, and produce the output tensors of the base architecture. |

keras_unet_collection.activations and keras_unet_collection.losses provide additional activation layers and loss functions.

| keras_unet_collection.activations | Name | Reference |
|---|---|---|
| GELU | Gaussian Error Linear Units (GELU) | Hendrycks et al. (2016) |
| Snake | Snake activation | Liu et al. (2020) |
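
As a rough illustration of what the two activations compute, here is a scalar sketch in plain Python. It follows the formulas in the cited papers (GELU as x·Φ(x), Snake as x + sin²(ax)/a) and is not the package's layer code. The function names and the frequency parameter `a` are assumptions made for this example.

```python
import math


def gelu(x):
    """Exact GELU: x * Phi(x), with Phi the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def snake(x, a=1.0):
    """Snake activation: x + sin^2(a * x) / a.

    `a` is a hypothetical name for the frequency parameter (must be non-zero).
    """
    return x + (math.sin(a * x) ** 2) / a


if __name__ == "__main__":
    for value in (-2.0, -0.5, 0.0, 0.5, 2.0):
        print(f"x={value:+.1f}  gelu={gelu(value):+.4f}  snake={snake(value):+.4f}")
```

Both functions pass through the origin. GELU approaches the identity for large positive inputs, while Snake adds a periodic term on top of the identity.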
| keras_unet_collection.losses | Name | Reference |
|---|---|---|
| dice | Dice loss | Sudre et al. (2017) |
| tversky | Tversky loss | Hashemi et al. (2018) |
| focal_tversky | Focal Tversky loss | Abraham et al. (2019) |
| ms_ssim | Multi-scale Structural Similarity Index loss | Wang et al. (2003) |
| iou_seg | Intersection over Union (IoU) loss for segmentation | Rahman and Wang (2016) |
| iou_box | (Generalized) IoU loss for object detection | Rezatofighi et al. (2019) |
| triplet_1d | Semi-hard triplet loss (experimental) | |
| crps2d_tf | CRPS loss (experimental) | |
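
To give a feel for the overlap-based losses listed above, the following is a minimal plain-Python sketch of the Dice and Tversky losses on flattened binary targets. It is an illustration only, not the package's TensorFlow implementation. The helper names, the `alpha` weighting convention (weight on false negatives), and the smoothing constant `eps` are assumptions for this example.

```python
def tversky_coef(y_true, y_pred, alpha=0.5, eps=1e-7):
    """Tversky index for flat sequences: targets in {0, 1}, predictions in [0, 1].

    alpha weights false negatives and (1 - alpha) weights false positives.
    """
    true_pos = sum(t * p for t, p in zip(y_true, y_pred))
    false_neg = sum(t * (1 - p) for t, p in zip(y_true, y_pred))
    false_pos = sum((1 - t) * p for t, p in zip(y_true, y_pred))
    return true_pos / (true_pos + alpha * false_neg + (1 - alpha) * false_pos + eps)


def tversky_loss(y_true, y_pred, alpha=0.5):
    """Tversky loss: one minus the Tversky index."""
    return 1.0 - tversky_coef(y_true, y_pred, alpha=alpha)


def dice_loss(y_true, y_pred, eps=1e-7):
    """Dice loss: one minus 2|A.B| / (|A| + |B|)."""
    overlap = sum(t * p for t, p in zip(y_true, y_pred))
    return 1.0 - 2.0 * overlap / (sum(y_true) + sum(y_pred) + eps)


if __name__ == "__main__":
    target = [1, 1, 0, 0]
    prediction = [0.8, 0.4, 0.3, 0.1]
    print("dice loss:            ", dice_loss(target, prediction))
    print("tversky loss (a=0.5): ", tversky_loss(target, prediction, alpha=0.5))
    print("tversky loss (a=0.7): ", tversky_loss(target, prediction, alpha=0.7))
```

Under this convention, `alpha=0.5` makes the Tversky loss coincide with the Dice loss (up to the smoothing constant), and a larger `alpha` penalizes missed foreground pixels more heavily.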

Installation and usage

pip install keras-unet-collection

from keras_unet_collection import models
# e.g. models.unet_2d(...)
  • Note: Currently supported backbone models are: VGG[16,19], ResNet[50,101,152], ResNet[50,101,152]V2, DenseNet[121,169,201], and EfficientNetB[0-7]. See Keras Applications for details.

  • Note: Neural networks produced by this package may contain customized layers that are not part of TensorFlow. It is recommended to save and load model weights rather than entire models.

  • Changelog
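
The usage and weight-saving notes above can be sketched as follows. This is a hedged example rather than official documentation: the keyword arguments are recalled from the package's user guide and may differ between versions, and the helper name `build_example_unet`, the input size, and the filter numbers are assumptions chosen for illustration.

```python
def build_example_unet(weights_path=None):
    """Configure a small 2-d U-net and optionally restore previously saved weights.

    Hypothetical helper for illustration; argument names follow the user guide
    and should be checked against the installed version.
    """
    # Imported here so that defining this helper does not require
    # TensorFlow or keras-unet-collection to be installed.
    from keras_unet_collection import models

    model = models.unet_2d(
        (128, 128, 3),        # input size: height, width, channels
        [64, 128, 256, 512],  # number of filters per down-/upsampling level
        n_labels=2,
        stack_num_down=2,
        stack_num_up=2,
        activation='ReLU',
        output_activation='Softmax',
        batch_norm=True,
        pool=True,
        unpool=True,
        name='unet',
    )
    if weights_path is not None:
        # Restoring weights into a freshly configured model avoids
        # deserializing the package's customized layers.
        model.load_weights(weights_path)
    return model
```

A typical round trip would call `build_example_unet()`, train the model, store it with the standard Keras `model.save_weights(path)`, and later rebuild it with `build_example_unet(weights_path=path)` using the same hyper-parameters.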

Examples

  • Jupyter notebooks are provided as examples:

    • Attention U-net with VGG16 backbone [link].

    • UNET 3+ with deep supervision, classification-guided module, and hybrid loss [link].

    • Vision-Transformer-based examples are in progress and are available at keras-vision-transformer.

Dependencies

  • TensorFlow 2.5.0, Keras 2.5.0, NumPy 1.19.5.

  • (Optional for examples) Pillow, matplotlib, etc.

Overview

U-net is a convolutional neural network with an encoder-decoder architecture and skip-connections, loosely defined under the concept of "fully convolutional networks." U-net was originally proposed for the semantic segmentation of medical images and has since been adapted to a wider range of gridded learning problems.

U-net and many of its variants take three- or four-dimensional tensors as inputs and produce outputs of the same shape. One technical highlight of these models is the skip-connections from downsampling to upsampling layers, which benefit the reconstruction of high-resolution, gridded outputs.
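
To make the shape bookkeeping concrete, here is a small plain-Python sketch that traces (height, width, channels) through a generic U-net-style encoder and decoder. It assumes 2x downsampling and upsampling per level and channel concatenation at each skip-connection. The function name and the example filter numbers are hypothetical, and individual variants in this package may differ in detail.

```python
def trace_unet_shapes(height, width, in_channels, filter_num, n_labels):
    """Return a list of (stage, (H, W, C)) for a generic U-net-style network."""
    depth = len(filter_num)
    stride = 2 ** (depth - 1)
    if height % stride or width % stride:
        raise ValueError(f"height and width must be divisible by {stride}")

    trace = [("input", (height, width, in_channels))]
    skips = []
    h, w = height, width

    # Encoder: each level after the first halves the spatial size.
    for level, filters in enumerate(filter_num):
        if level > 0:
            h, w = h // 2, w // 2
        trace.append((f"encoder_{level}", (h, w, filters)))
        skips.append((h, w, filters))

    skips.pop()  # the deepest level is the bottleneck, not a skip-connection

    # Decoder: upsample, concatenate the matching encoder output, then convolve.
    for level in range(depth - 2, -1, -1):
        h, w = h * 2, w * 2
        skip_h, skip_w, skip_c = skips.pop()
        assert (h, w) == (skip_h, skip_w), "skip and upsampled sizes must match"
        trace.append((f"concat_{level}", (h, w, filter_num[level] + skip_c)))
        trace.append((f"decoder_{level}", (h, w, filter_num[level])))

    trace.append(("output", (h, w, n_labels)))
    return trace


if __name__ == "__main__":
    for stage, shape in trace_unet_shapes(128, 128, 3, [64, 128, 256], n_labels=2):
        print(f"{stage:<10} {shape}")
```

The trace shows that the output has the same height and width as the input, and that every decoder level receives full-resolution features from the encoder level of the same size.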

Contact

Yingkai (Kyle) Sha <yingkai@eoas.ubc.ca> <yingkaisha@gmail.com>

License

MIT License

Citation

@misc{keras-unet-collection,
  author = {Sha, Yingkai},
  title = {Keras-unet-collection},
  year = {2021},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/yingkaisha/keras-unet-collection}},
  doi = {10.5281/zenodo.5449801}
}