
A PyTorch implementation of U-Net using a DenseNet-121 backbone

Primary language: Python. License: GNU General Public License v3.0 (GPL-3.0).

DenseUNet-pytorch

A PyTorch implementation of U-Net using a DenseNet-121 backbone for the encoding and decoding paths.

The DenseNet blocks are based on the implementation available in torchvision.

The input is restricted to RGB images and has shape $(N, 3, H, W)$, where $N$ is the batch size. The output has shape $(N, C, H, W)$, where $C$ is the number of output classes.
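As a quick sanity check on these shapes, the sketch below computes the expected output shape from an input shape in plain Python, without building the model. The helper name expected_output_shape is hypothetical and not part of dense_unet. It only encodes the two rules stated above: three input channels, and C output channels at the same H and W.

```python
def expected_output_shape(input_shape, num_classes):
    """Hypothetical helper: map an (N, 3, H, W) input shape to (N, C, H, W)."""
    if len(input_shape) != 4:
        raise ValueError(f"expected a 4-D (N, C, H, W) shape, got {input_shape!r}")
    n, channels, h, w = input_shape
    # The model only accepts RGB images, i.e. exactly 3 input channels.
    if channels != 3:
        raise ValueError(f"expected RGB input with 3 channels, got {channels}")
    if num_classes < 1:
        raise ValueError("num_classes must be a positive integer")
    # Batch size and spatial size are kept; the channel axis becomes C.
    return (n, num_classes, h, w)


print(expected_output_shape((8, 3, 256, 256), num_classes=3))  # (8, 3, 256, 256)
```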

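If the downsample option is set to False, the stride of conv0 is set to 1 and pool0 is removed, so the encoder skips the initial 4× reduction in spatial resolution that these two layers apply in torchvision's DenseNet.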
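The effect of the downsample option on feature-map sizes can be worked out with the standard output-size formula for convolution and pooling layers. This sketch assumes the torchvision DenseNet-121 layer settings (conv0: 7×7 kernel, padding 3; pool0: 3×3 max pool, stride 2, padding 1; each transition: 2×2 average pool, stride 2) and that dense blocks preserve resolution. encoder_resolutions is a hypothetical helper, not part of dense_unet.

```python
def conv_out(size, kernel, stride, padding):
    """Output size of a convolution or pooling layer (dilation 1, floor rounding)."""
    return (size + 2 * padding - kernel) // stride + 1


def encoder_resolutions(size, downsample=True):
    """Hypothetical helper: spatial size seen by each dense block of DenseNet-121."""
    sizes = {}
    # conv0: 7x7, padding 3; stride 2 by default, stride 1 when downsample=False.
    size = conv_out(size, kernel=7, stride=2 if downsample else 1, padding=3)
    if downsample:
        # pool0: 3x3 max pool, stride 2, padding 1 (removed when downsample=False).
        size = conv_out(size, kernel=3, stride=2, padding=1)
    sizes["denseblock1"] = size
    for i in (1, 2, 3):
        # transition i ends with a 2x2 average pool, stride 2.
        size = conv_out(size, kernel=2, stride=2, padding=0)
        sizes[f"denseblock{i + 1}"] = size
    return sizes


print(encoder_resolutions(224))
print(encoder_resolutions(224, downsample=False))
```

For a 224×224 input this gives 56, 28, 14 and 7 by default (a total factor of 32 at the deepest block) and 224, 112, 56 and 28 with downsample=False (a factor of 8).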

Optionally, a pretrained model can be used to initialize the encoder.
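The checkpoint URL used in the Usage example, densenet121-a639ec97.pth, is one of torchvision's original DenseNet checkpoints, whose dense-layer keys use an older naming (norm.1, conv.2 and so on). torchvision renames such keys with the regular expression below when it loads these weights. Whether DenseUNet's own loader does the same is not stated in this README; remap_legacy_keys is a hypothetical helper, and plain strings stand in for tensors.

```python
import re

# Pattern used by torchvision's DenseNet loader: legacy checkpoints name layers
# "norm.1", "relu.2", "conv.2", while current modules use "norm1", "relu2", "conv2".
LEGACY_KEY = re.compile(
    r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
)


def remap_legacy_keys(state_dict):
    """Hypothetical helper: return a copy of state_dict with legacy keys renamed."""
    remapped = {}
    for key, value in state_dict.items():
        match = LEGACY_KEY.match(key)
        # "...denselayer1.norm" + "1.weight" -> "...denselayer1.norm1.weight"
        new_key = match.group(1) + match.group(2) if match else key
        remapped[new_key] = value
    return remapped


example = {"features.denseblock1.denselayer1.norm.1.weight": "tensor"}
print(remap_legacy_keys(example))  # {'features.denseblock1.denselayer1.norm1.weight': 'tensor'}
```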

Requirements

  • pytorch
  • torchvision

Usage

```python
from dense_unet import DenseUNet

pretrained_encoder_uri = 'https://download.pytorch.org/models/densenet121-a639ec97.pth'
#
# for a local file use
#
# from pathlib import Path
# pretrained_encoder_uri = Path('/path/to/local/model.pth').resolve().as_uri()
#

num_output_classes = 3
model = DenseUNet(num_output_classes, downsample=True, pretrained_encoder_uri=pretrained_encoder_uri)
```
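The comment in the example shows how to point pretrained_encoder_uri at a local file. A small helper can accept either a URL or a local path and return a URI in both cases. This is a minimal sketch: to_encoder_uri is hypothetical (not part of dense_unet), and it assumes http, https and file URIs should be passed through unchanged.

```python
from pathlib import Path
from urllib.parse import urlparse


def to_encoder_uri(location):
    """Hypothetical helper: pass URLs through, turn local paths into file:// URIs."""
    if isinstance(location, str) and urlparse(location).scheme in ("http", "https", "file"):
        return location
    # Path.as_uri() requires an absolute path, so resolve() the path first.
    return Path(location).resolve().as_uri()


print(to_encoder_uri("https://download.pytorch.org/models/densenet121-a639ec97.pth"))
print(to_encoder_uri(Path("weights") / "model.pth"))
```

The result can then be passed as pretrained_encoder_uri in the DenseUNet call above.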