
Simple PyTorch implementation of Semantic Segmentation CNNs

This repository provides from-scratch PyTorch implementations of key semantic segmentation architectures: UNet, DeepLabv3+, SegNet, FCN, and PSPNet. It is intended as a solid foundation for semantic segmentation tasks in PyTorch. The implementation incorporates valuable contributions from the broader GitHub community, as detailed in the references section.

Supported Architectures and Backbones

  • UNet: No backbone needed.
  • DeepLabv3+: Support for ResNet backbones (ResNet18, ResNet34, ResNet50, and ResNet101).
  • PSPNet: Support for ResNet backbones (ResNet18, ResNet34, ResNet50, and ResNet101).
  • FCN: Support for VGG backbones (VGG11, VGG13, VGG16, and VGG19).
  • SegNet: No backbone needed.

[Figures: UNet, DeepLabv3+, SegNet, FCN, and PSPNet architecture diagrams]

Optimizers and Learning Rate Schedulers

  • Optimizers: Adam, SGD, and RMSprop.
  • Learning Rate Schedulers: StepLR, PolyLR, and ReduceLROnPlateau.
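
For intuition, the step and poly policies reduce to closed-form learning-rate formulas. The sketch below uses hypothetical helper names (`poly_lr`, `step_lr`) and assumes the commonly used poly exponent of 0.9; the repository's own scheduler settings may differ. ReduceLROnPlateau has no closed form, since it lowers the rate only when a monitored validation metric stops improving.

```python
def poly_lr(base_lr: float, step: int, max_steps: int, power: float = 0.9) -> float:
    """Poly decay: base_lr * (1 - step / max_steps) ** power, reaching 0 at max_steps.

    The default power=0.9 is an assumption (a common choice), not a repository setting.
    """
    if not 0 <= step <= max_steps:
        raise ValueError("step must be in [0, max_steps]")
    return base_lr * (1.0 - step / max_steps) ** power


def step_lr(base_lr: float, epoch: int, step_size: int, gamma: float = 0.1) -> float:
    """Step decay: multiply base_lr by gamma once every step_size epochs."""
    return base_lr * gamma ** (epoch // step_size)
```

With `base_lr=0.1`, `step_size=10`, and `gamma=0.5`, epoch 25 has gone through two decays, so `step_lr` returns 0.1 × 0.5² = 0.025. This is the same formula that `torch.optim.lr_scheduler.StepLR` applies.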

Loss Functions

  • Cross Entropy (variations):
    • Standard CE
    • CE with class weights
    • Focal Loss
  • Dice Loss
  • Joint loss: Dice loss combined with one of the CE variations.
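
The loss options above can be sketched in NumPy for a single image, assuming logits of shape (C, H, W) and an integer target map of shape (H, W). This is not the repository's code. The function names are hypothetical, the focal term uses the common (1 − p)^γ modulation, and the joint loss is assumed to be an unweighted sum of the CE term and the Dice term. A PyTorch version would follow the same arithmetic on tensors.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    """Softmax over the class axis of (C, H, W) logits (max-subtracted for stability)."""
    z = logits - logits.max(axis=0, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=0, keepdims=True)


def one_hot(target: np.ndarray, num_classes: int) -> np.ndarray:
    """(H, W) integer labels -> (C, H, W) one-hot encoding."""
    return np.eye(num_classes)[target].transpose(2, 0, 1)


def cross_entropy(logits, target, class_weights=None, focal_gamma=None):
    """Pixel-wise CE; class_weights gives weighted CE, focal_gamma gives focal loss."""
    probs = softmax(logits)
    # Probability the model assigns to the true class at each pixel, shape (H, W).
    p_true = np.take_along_axis(probs, target[None], axis=0)[0]
    loss = -np.log(np.clip(p_true, 1e-12, 1.0))
    if focal_gamma is not None:
        loss = (1.0 - p_true) ** focal_gamma * loss  # down-weight easy pixels
    if class_weights is not None:
        w = np.asarray(class_weights, dtype=float)[target]  # weight of each pixel's class
        return float((w * loss).sum() / w.sum())  # weighted mean over pixels
    return float(loss.mean())


def dice_loss(logits, target, eps=1e-6):
    """Soft Dice loss: 1 minus the mean per-class Dice of probabilities vs one-hot targets."""
    probs = softmax(logits)
    onehot = one_hot(target, logits.shape[0])
    inter = (probs * onehot).sum(axis=(1, 2))
    denom = probs.sum(axis=(1, 2)) + onehot.sum(axis=(1, 2))
    return float(1.0 - ((2.0 * inter + eps) / (denom + eps)).mean())


def joint_loss(logits, target, **ce_kwargs):
    """Assumed joint loss: unweighted sum of a CE variant and the Dice loss."""
    return cross_entropy(logits, target, **ce_kwargs) + dice_loss(logits, target)
```

Passing `class_weights=[1.0, 1.0, 4.0]` corresponds to the weighted-CE variation, and `focal_gamma=2.0` to a focal loss. With `focal_gamma=0.0` the focal term reduces to plain CE.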

Evaluation Metrics

  • Models are evaluated using:
    • Dice Coefficient
    • Intersection over Union (IoU) score.
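
Both metrics compare a predicted label map with the ground truth, class by class: Dice = 2|P∩T| / (|P| + |T|) and IoU = |P∩T| / |P∪T|. A minimal NumPy sketch follows (the function name is hypothetical). A small `eps` keeps a class that is absent from both maps from dividing by zero, and by this convention it scores 1.

```python
import numpy as np


def per_class_scores(pred: np.ndarray, target: np.ndarray, num_classes: int, eps: float = 1e-6):
    """Return (dice, iou) arrays of shape (num_classes,) from integer label maps."""
    dice = np.zeros(num_classes)
    iou = np.zeros(num_classes)
    for c in range(num_classes):
        p = pred == c  # pixels predicted as class c
        t = target == c  # pixels labelled as class c
        inter = np.logical_and(p, t).sum()
        union = np.logical_or(p, t).sum()
        dice[c] = (2 * inter + eps) / (p.sum() + t.sum() + eps)
        iou[c] = (inter + eps) / (union + eps)
    return dice, iou
```

For any single class, the two scores are linked by Dice = 2·IoU / (1 + IoU), so they rank per-class results identically. They diverge only once you average across classes or images.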


Dataset

A mock dataset is included in the repository for demonstration and testing purposes. Note that this dataset is not intended for actual training or testing; it is a convenience for setting up and debugging your first run.

Replace the mock dataset with your own dataset as needed. The data loader accepts images of arbitrary dimensions and resizes them to the target size (target_size in config.py). Make sure your dataset follows the directory structure below, which the data loader expects:

├── train
│ ├── images
│ └── masks
├── val
│ ├── images
│ └── masks

This structure includes separate subfolders for training and validation data, with further subdivisions for images and their corresponding masks.
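
This section does not state how the loader matches images to masks. Assuming each mask shares its image's file stem (for example `0001.jpg` and `0001.png`), a quick standard-library check such as the hypothetical `pair_images_and_masks` below can catch missing masks before training:

```python
from pathlib import Path


def pair_images_and_masks(split_dir):
    """Pair split_dir/images/*.jpg with split_dir/masks/*.png by file stem.

    Stem-based matching is an assumption, not the repository's documented rule.
    Returns (pairs, unmatched_image_names).
    """
    split_dir = Path(split_dir)
    images_dir, masks_dir = split_dir / "images", split_dir / "masks"
    for d in (images_dir, masks_dir):
        if not d.is_dir():
            raise FileNotFoundError(f"missing directory: {d}")
    masks = {p.stem: p for p in masks_dir.glob("*.png")}
    pairs, unmatched = [], []
    for img in sorted(images_dir.glob("*.jpg")):
        mask = masks.get(img.stem)
        if mask is None:
            unmatched.append(img.name)  # image with no mask of the same stem
        else:
            pairs.append((img, mask))
    return pairs, unmatched
```

Call it once per split, e.g. `pair_images_and_masks("mock_dataset/train")`. Note that `Path.glob` is case-sensitive on typical Linux file systems, so files named `*.JPG` would be skipped by this sketch.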

When preparing your dataset, make sure the images are in .jpg format, the masks are in .png format, and each mask has the same size as its corresponding image. Each mask pixel stores its class label as an integer repeated in all three RGB channels. For instance, in a dataset with 3 classes, use [0,0,0], [1,1,1], and [2,2,2] to label these classes.
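
The mask convention above can be produced and checked with a few NumPy helpers. This is a sketch with hypothetical function names, not the repository's loader. It also includes a nearest-neighbour resize: label maps must not be interpolated, because bilinear resizing would create in-between values that are not valid class labels. Check utils/data_loading.py for the interpolation it actually uses.

```python
import numpy as np


def class_map_to_rgb_mask(class_map):
    """(H, W) integer labels -> (H, W, 3) uint8 mask with the label repeated in R, G and B."""
    class_map = np.asarray(class_map)
    if class_map.min() < 0 or class_map.max() > 255:
        raise ValueError("labels must fit in uint8 (0-255)")
    return np.repeat(class_map.astype(np.uint8)[..., None], 3, axis=2)


def rgb_mask_to_class_map(mask):
    """(H, W, 3) mask -> (H, W) labels; rejects masks whose channels disagree."""
    mask = np.asarray(mask)
    r, g, b = mask[..., 0], mask[..., 1], mask[..., 2]
    if not (np.array_equal(r, g) and np.array_equal(r, b)):
        raise ValueError("R, G and B channels must hold the same class label")
    return r.astype(np.int64)


def resize_nearest(class_map, size):
    """Nearest-neighbour resize of an (H, W) label map to size=(height, width)."""
    h, w = class_map.shape
    out_h, out_w = size
    rows = np.arange(out_h) * h // out_h  # source row for each output row
    cols = np.arange(out_w) * w // out_w  # source column for each output column
    return class_map[rows[:, None], cols[None, :]]
```

To write a mask to disk, Pillow's `Image.fromarray(class_map_to_rgb_mask(labels)).save("0001.png")` works on the uint8 array (the file name is only an example). PNG is lossless, whereas JPEG compression would alter the label values.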

For a concrete example of the expected directory layout and mask format, refer to the mock dataset included in this repository. Use ./mock_dataset/inspect_data.py to visualize its images and masks.

To change the data loading behavior, edit utils/data_loading.py.


Requirements

To run this project, you need to have the following packages installed:

  • torch
  • matplotlib
  • numpy
  • Pillow
  • tqdm
  • torchvision
  • opencv-python

You can install them by running the following command:

pip install -r requirements.txt

Alternatively, you can manually install each package using:

pip install torch matplotlib numpy Pillow tqdm torchvision opencv-python


Training

Modify the config.py file as needed, including the dataset paths:

        self.batch_size = 2  # Batch size for training
        self.lr = 1e-5  # Learning rate
        self.optimization = 'RMSprop'  # Optimization method ('RMSprop', 'SGD', 'Adam')
        self.lr_policy = 'plateau'  # Learning rate policy ('plateau', 'poly', 'step')
        self.lr_decay_step = 0.5  # LR decay step for 'step' policy

        self.load = False  # Flag to load model from a .pth file
        self.val_frequency = 20  # Validation frequency as a percentage
        self.extra_weight_frequency = 10  # Frequency for saving extra weights as a percentage

        self.amp = False  # Use mixed precision training
        self.classes = 3  # Number of output classes
        self.target_size = (512, 512)  # Target size for input images (height, width)
        self.number_of_in_channels = 3  # Number of input channels

        self.loss_type = 'joint'  # Loss type ('dice', 'ce', 'joint')
        self.CE_variation = 'ce'  # Cross-entropy variation ('CE', 'CEW', 'Focal')
        self.class_weights = None  # Class weights for CE loss or a list of weights for each class e.g., [1.0, 1.0, 4.0] for [class1, class2, class3]

        # Directories for dataset and checkpoints
        self.dir_root = Path('F:/projects/semantic_segmentaion_archs_repo/mock_dataset')  # Root directory for dataset
        self.train_images_dir = Path(os.path.join(self.dir_root, 'train/images'))
        self.train_mask_dir = Path(os.path.join(self.dir_root, 'train/masks'))
        self.val_images_dir = Path(os.path.join(self.dir_root, 'val/images'))
        self.val_mask_dir = Path(os.path.join(self.dir_root, 'val/masks'))
        self.dir_checkpoint = Path('./checkpoints/')  # Directory for saving checkpoints

        # CNN architecture and backbone
        self.cnn_arch = 'deeplabv3+'  # CNN architecture ('UNet', 'DeepLab', 'SegNet', 'PSPNet', 'FCN')
        self.backbone = 'resnet18'  # Backbone model ('resnet18', 'resnet34', 'resnet50', 'resnet101', 'vgg11', 'vgg13', 'vgg16', 'vgg19')
        self.use_pretrained_backbone = True  # Use a pretrained backbone
        self.bilinear = True  # Use bilinear upsampling

Then run train.py directly, or execute the training script with the following command:

python train.py
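
A misspelled option or a class_weights list of the wrong length may only surface after the training script has started. The sketch below is a hypothetical pre-flight check (`check_config` is not part of the repository) that compares a config object against the options listed in the config.py comments. It assumes values must match those spellings exactly, which may be stricter than the repository's own parsing.

```python
from pathlib import Path

# Allowed values copied from the comments in config.py;
# exact, case-sensitive matching is an assumption.
ALLOWED_VALUES = {
    "optimization": {"RMSprop", "SGD", "Adam"},
    "lr_policy": {"plateau", "poly", "step"},
    "loss_type": {"dice", "ce", "joint"},
}
DATA_DIRS = ("train_images_dir", "train_mask_dir", "val_images_dir", "val_mask_dir")


def check_config(cfg) -> list:
    """Return a list of problems found in cfg (an empty list means none were found)."""
    problems = []
    for attr, allowed in ALLOWED_VALUES.items():
        value = getattr(cfg, attr)
        if value not in allowed:
            problems.append(f"{attr}={value!r} is not one of {sorted(allowed)}")
    weights = cfg.class_weights
    if weights is not None and len(weights) != cfg.classes:
        problems.append(f"class_weights has {len(weights)} entries but classes={cfg.classes}")
    for attr in DATA_DIRS:
        if not Path(getattr(cfg, attr)).is_dir():
            problems.append(f"{attr} does not exist: {getattr(cfg, attr)}")
    return problems
```

In config.py itself, `self.dir_root / 'train' / 'images'` builds the same path as `Path(os.path.join(self.dir_root, 'train/images'))` and is the more idiomatic pathlib form.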

When creating segmentation models, note which backbones each architecture supports:

Architecture      Compatible Backbones
FCN variants:
  - FCN32s        VGG11, VGG13, VGG16, VGG19
  - FCN16s        VGG11, VGG13, VGG16, VGG19
  - FCN8s         VGG11, VGG13, VGG16, VGG19
  - FCNs          VGG11, VGG13, VGG16, VGG19
DeepLabv3+        ResNet18, ResNet34, ResNet50, ResNet101
PSPNet            ResNet18, ResNet34, ResNet50, ResNet101
UNet              None (no backbone required)
SegNet            None (no backbone required)
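
The compatibility table can be encoded as a lookup so that an incompatible combination is caught before a model is built. This is a hypothetical helper: `is_compatible` and the lowercase keys are assumptions, not the repository's API. For UNet and SegNet, it also assumes that any configured backbone is simply unused.

```python
VGG = {"vgg11", "vgg13", "vgg16", "vgg19"}
RESNET = {"resnet18", "resnet34", "resnet50", "resnet101"}

# Architecture -> supported backbones, transcribed from the compatibility table.
# An empty set means the architecture does not use a backbone.
COMPATIBLE_BACKBONES = {
    "fcn32s": VGG, "fcn16s": VGG, "fcn8s": VGG, "fcns": VGG,
    "deeplabv3+": RESNET,
    "pspnet": RESNET,
    "unet": set(),
    "segnet": set(),
}


def is_compatible(arch: str, backbone=None) -> bool:
    """Check an (architecture, backbone) pair against the table; names are case-insensitive."""
    key = arch.lower()
    if key not in COMPATIBLE_BACKBONES:
        raise ValueError(f"unknown architecture: {arch!r}")
    allowed = COMPATIBLE_BACKBONES[key]
    if not allowed:
        return True  # assumption: UNet/SegNet ignore any configured backbone
    return backbone is not None and backbone.lower() in allowed
```

For example, `is_compatible("DeepLabv3+", "resnet50")` is true, while `is_compatible("deeplabv3+", "vgg16")` is false because DeepLabv3+ only pairs with ResNet backbones.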


Prediction

Modify the attributes of the Args class in predict.py as needed (including the path to the test images, the model weights, and the network parameters used for prediction), then run predict.py directly or execute the prediction script with the following command:

python predict.py
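
The exact output format of predict.py depends on its Args settings. A common post-processing step, sketched below under that caveat, turns per-class scores of shape (C, H, W) into a label map with argmax and colorizes it for viewing. The function names and the palette are hypothetical.

```python
import numpy as np


def logits_to_class_map(logits: np.ndarray) -> np.ndarray:
    """(C, H, W) logits or probabilities -> (H, W) predicted labels via argmax."""
    return np.argmax(logits, axis=0)


def colorize(class_map: np.ndarray, palette) -> np.ndarray:
    """Map each label to an RGB color; palette has shape (num_classes, 3)."""
    palette = np.asarray(palette, dtype=np.uint8)
    return palette[class_map]  # fancy indexing gives an (H, W, 3) image
```

For a 3-class model, a palette such as `[[0, 0, 0], [255, 0, 0], [0, 255, 0]]` renders class 0 black, class 1 red, and class 2 green. To store a prediction in the same format as the training masks instead, repeat the label map across the three channels rather than colorizing it.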



Contributing

Contributions to this repository are welcome. Feel free to submit a pull request or open an issue for any bugs or feature requests.


License

This repository is a collection of various Semantic Segmentation CNN implementations, each potentially under its own license. We respect the original licenses of all included works. Users are advised to refer to the respective original sources for licensing information specific to each implementation.

Original contributions made in this repository are provided under the terms of the MIT license.