/pytorch-unet-1

Simple PyTorch implementations of U-Net/FullyConvNet (FCN) for image segmentation

Primary language: Jupyter Notebook. License: MIT.

UNet/FCN PyTorch

This repository contains simple PyTorch implementations of U-Net and FCN, which are deep learning segmentation methods proposed by Ronneberger et al. and Long et al.

Synthetic images/masks for training

First clone the repository and cd into the project directory.

import matplotlib.pyplot as plt
import numpy as np
import helper
import simulation

# Generate some random images
# (three samples, with 192 and 192 passed as the size arguments; the exact
# array layout is decided by simulation.generate_random_data)
input_images, target_masks = simulation.generate_random_data(192, 192, count=3)

# Sanity check: print the shape and the value range of both arrays
for x in [input_images, target_masks]:
    print(x.shape)
    print(x.min(), x.max())

# Cast every input image to uint8 so it can be displayed
# NOTE(review): no channel reordering happens on this line; presumably the
# generator already returns images in a plottable layout - confirm
input_images_rgb = [x.astype(np.uint8) for x in input_images]

# Map each channel (i.e. class) to each color
target_masks_rgb = [helper.masks_to_colorimg(x) for x in target_masks]

# Left: Input image (black and white), Right: Target mask (6ch)
helper.plot_side_by_side([input_images_rgb, target_masks_rgb])

Left: Input image (black and white), Right: Target mask (6ch)

png

Prepare Dataset and DataLoader

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets, models

class SimDataset(Dataset):
    """Dataset backed by images and masks from ``simulation.generate_random_data``.

    Args:
        count: number of samples requested from the generator.
        transform: optional callable applied to the image only; the mask is
            always returned exactly as generated.
    """

    def __init__(self, count, transform=None):
        images, masks = simulation.generate_random_data(192, 192, count=count)
        self.input_images = images
        self.target_masks = masks
        self.transform = transform

    def __len__(self):
        """Return the number of stored input images."""
        return len(self.input_images)

    def __getitem__(self, idx):
        """Return ``[image, mask]`` for ``idx``, transforming the image if configured."""
        image, mask = self.input_images[idx], self.target_masks[idx]
        if not self.transform:
            return [image, mask]
        return [self.transform(image), mask]

# use the same transformations for train/val in this example
# ToTensor converts the sample to a tensor, then Normalize applies the
# per-channel mean/std below (the same constants reverse_transform undoes)
trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # imagenet
])

# Two independently generated synthetic datasets: 2000 training samples
# and 200 validation samples
train_set = SimDataset(2000, transform = trans)
val_set = SimDataset(200, transform = trans)

# Datasets keyed by phase name
image_datasets = {
    'train': train_set, 'val': val_set
}

batch_size = 25

# Loaders keyed by phase name; train_model reads dataloaders[phase].
# num_workers=0 keeps data loading in the main process.
dataloaders = {
    'train': DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0),
    'val': DataLoader(val_set, batch_size=batch_size, shuffle=True, num_workers=0)
}

Check the outputs from DataLoader

import torchvision.utils

def reverse_transform(inp):
    """Turn a normalized channels-first tensor back into a uint8 image array.

    The tensor is moved to channels-last order, the mean/std normalization
    used by the input transform is undone, values are clipped to [0, 1] and
    finally scaled to the 0-255 range.

    Args:
        inp: tensor with the channel axis first (3 channels, matching the
            three mean/std entries).

    Returns:
        numpy uint8 array with the channel axis last.
    """
    channel_mean = np.array([0.485, 0.456, 0.406])
    channel_std = np.array([0.229, 0.224, 0.225])

    # channels-first -> channels-last so the per-channel constants broadcast
    channels_last = inp.numpy().transpose((1, 2, 0))
    restored = np.clip(channel_std * channels_last + channel_mean, 0, 1)
    return (restored * 255).astype(np.uint8)

# Get a batch of training data
inputs, masks = next(iter(dataloaders['train']))

# Print the batch shapes of the images and of the masks
print(inputs.shape, masks.shape)

# Show the sample at index 3 of the batch after undoing the normalization
plt.imshow(reverse_transform(inputs[3]))
torch.Size([25, 3, 192, 192]) torch.Size([25, 6, 192, 192])

png

Create the UNet module

import torch
import torch.nn as nn
from torchvision import models

def convrelu(in_channels, out_channels, kernel, padding):
    """Build a 2-D convolution followed by an in-place ReLU.

    Args:
        in_channels: number of input channels of the convolution.
        out_channels: number of output channels of the convolution.
        kernel: kernel size forwarded to ``nn.Conv2d``.
        padding: padding forwarded to ``nn.Conv2d``.

    Returns:
        ``nn.Sequential`` holding the convolution and the activation.
    """
    conv = nn.Conv2d(in_channels, out_channels, kernel, padding=padding)
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(conv, activation)


class ResNetUNet(nn.Module):
    """U-Net style segmentation network with a torchvision ResNet-18 encoder.

    The encoder stages are slices of the ResNet-18 child modules. The decoder
    repeatedly upsamples by a factor of 2, concatenates the matching encoder
    feature map (after a 1x1 conv + ReLU) along the channel axis and fuses
    the result with a 3x3 conv + ReLU. A separate full-resolution branch
    built directly on the input is merged in just before the final 1x1
    convolution.
    """

    def __init__(self, n_class):
        """Create the encoder, decoder and output layers.

        Args:
            n_class: number of output channels of the final 1x1 convolution.
        """
        super().__init__()

        # The backbone is created with pretrained=True; its children are sliced below
        self.base_model = models.resnet18(pretrained=True)
        self.base_layers = list(self.base_model.children())

        # Encoder stages, each paired with a 1x1 conv + ReLU that is applied to
        # the stage output before it is used as a skip connection
        self.layer0 = nn.Sequential(*self.base_layers[:3]) # size=(N, 64, x.H/2, x.W/2)
        self.layer0_1x1 = convrelu(64, 64, 1, 0)
        self.layer1 = nn.Sequential(*self.base_layers[3:5]) # size=(N, 64, x.H/4, x.W/4)
        self.layer1_1x1 = convrelu(64, 64, 1, 0)
        self.layer2 = self.base_layers[5]  # size=(N, 128, x.H/8, x.W/8)
        self.layer2_1x1 = convrelu(128, 128, 1, 0)
        self.layer3 = self.base_layers[6]  # size=(N, 256, x.H/16, x.W/16)
        self.layer3_1x1 = convrelu(256, 256, 1, 0)
        self.layer4 = self.base_layers[7]  # size=(N, 512, x.H/32, x.W/32)
        self.layer4_1x1 = convrelu(512, 512, 1, 0)

        # One shared 2x bilinear upsampling module reused at every decoder step
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Decoder fusion convs; input channels = skip channels + upsampled channels
        self.conv_up3 = convrelu(256 + 512, 512, 3, 1)
        self.conv_up2 = convrelu(128 + 512, 256, 3, 1)
        self.conv_up1 = convrelu(64 + 256, 256, 3, 1)
        self.conv_up0 = convrelu(64 + 256, 128, 3, 1)

        # Full-resolution branch applied directly to the 3-channel input, and
        # the conv that fuses it with the last decoder output (64 + 128 channels)
        self.conv_original_size0 = convrelu(3, 64, 3, 1)
        self.conv_original_size1 = convrelu(64, 64, 3, 1)
        self.conv_original_size2 = convrelu(64 + 128, 64, 3, 1)

        # Final 1x1 convolution mapping 64 channels to n_class outputs
        self.conv_last = nn.Conv2d(64, n_class, 1)

    def forward(self, input):
        """Run the encoder-decoder and return the output of ``conv_last``.

        No activation is applied to the result; calc_loss passes it to
        ``binary_cross_entropy_with_logits``.

        Args:
            input: batch of 3-channel images (``conv_original_size0`` expects
                3 input channels).
                NOTE(review): the five 2x down/upsampling steps presumably
                require H and W divisible by 32 for the concatenations to
                line up - confirm.
        """
        # Full-resolution features computed straight from the input
        x_original = self.conv_original_size0(input)
        x_original = self.conv_original_size1(x_original)

        # Encoder: each stage consumes the previous stage's output
        layer0 = self.layer0(input)
        layer1 = self.layer1(layer0)
        layer2 = self.layer2(layer1)
        layer3 = self.layer3(layer2)
        layer4 = self.layer4(layer3)

        # Decoder step 3: upsample the deepest features and fuse with layer3
        layer4 = self.layer4_1x1(layer4)
        x = self.upsample(layer4)
        layer3 = self.layer3_1x1(layer3)
        x = torch.cat([x, layer3], dim=1)
        x = self.conv_up3(x)

        # Decoder step 2: fuse with layer2
        x = self.upsample(x)
        layer2 = self.layer2_1x1(layer2)
        x = torch.cat([x, layer2], dim=1)
        x = self.conv_up2(x)

        # Decoder step 1: fuse with layer1
        x = self.upsample(x)
        layer1 = self.layer1_1x1(layer1)
        x = torch.cat([x, layer1], dim=1)
        x = self.conv_up1(x)

        # Decoder step 0: fuse with layer0
        x = self.upsample(x)
        layer0 = self.layer0_1x1(layer0)
        x = torch.cat([x, layer0], dim=1)
        x = self.conv_up0(x)

        # Last upsampling, then fuse with the full-resolution branch
        x = self.upsample(x)
        x = torch.cat([x, x_original], dim=1)
        x = self.conv_original_size2(x)

        out = self.conv_last(x)

        return out

Model summary

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Build the network with 6 output channels and move it to that device
model = ResNetUNet(n_class=6)
model = model.to(device)

# check keras-like model summary using torchsummary
# The summary is requested for a 3x224x224 input, while the data cells
# above generate samples with 192, 192 as the size arguments.
# NOTE(review): the table printed below shows a 2048-channel final encoder
# stage and a 2048->1024 1x1 conv (Conv2d-177), whereas ResNetUNet defines
# layer4_1x1 as convrelu(512, 512, 1, 0); the table looks like it came from
# a different backbone - confirm and regenerate.
from torchsummary import summary
summary(model, input_size=(3, 224, 224))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 224, 224]           1,792
              ReLU-2         [-1, 64, 224, 224]               0
            Conv2d-3         [-1, 64, 224, 224]          36,928
              ReLU-4         [-1, 64, 224, 224]               0
            Conv2d-5         [-1, 64, 112, 112]           9,408
       BatchNorm2d-6         [-1, 64, 112, 112]             128
              ReLU-7         [-1, 64, 112, 112]               0
         MaxPool2d-8           [-1, 64, 56, 56]               0
            Conv2d-9           [-1, 64, 56, 56]           4,096
      BatchNorm2d-10           [-1, 64, 56, 56]             128
             ReLU-11           [-1, 64, 56, 56]               0
           Conv2d-12           [-1, 64, 56, 56]          36,864
      BatchNorm2d-13           [-1, 64, 56, 56]             128
             ReLU-14           [-1, 64, 56, 56]               0
           Conv2d-15          [-1, 256, 56, 56]          16,384
      BatchNorm2d-16          [-1, 256, 56, 56]             512
           Conv2d-17          [-1, 256, 56, 56]          16,384
      BatchNorm2d-18          [-1, 256, 56, 56]             512
             ReLU-19          [-1, 256, 56, 56]               0
       Bottleneck-20          [-1, 256, 56, 56]               0
           Conv2d-21           [-1, 64, 56, 56]          16,384
      BatchNorm2d-22           [-1, 64, 56, 56]             128
             ReLU-23           [-1, 64, 56, 56]               0
           Conv2d-24           [-1, 64, 56, 56]          36,864
      BatchNorm2d-25           [-1, 64, 56, 56]             128
             ReLU-26           [-1, 64, 56, 56]               0
           Conv2d-27          [-1, 256, 56, 56]          16,384
      BatchNorm2d-28          [-1, 256, 56, 56]             512
             ReLU-29          [-1, 256, 56, 56]               0
       Bottleneck-30          [-1, 256, 56, 56]               0
           Conv2d-31           [-1, 64, 56, 56]          16,384
      BatchNorm2d-32           [-1, 64, 56, 56]             128
             ReLU-33           [-1, 64, 56, 56]               0
           Conv2d-34           [-1, 64, 56, 56]          36,864
      BatchNorm2d-35           [-1, 64, 56, 56]             128
             ReLU-36           [-1, 64, 56, 56]               0
           Conv2d-37          [-1, 256, 56, 56]          16,384
      BatchNorm2d-38          [-1, 256, 56, 56]             512
             ReLU-39          [-1, 256, 56, 56]               0
       Bottleneck-40          [-1, 256, 56, 56]               0
           Conv2d-41          [-1, 128, 56, 56]          32,768
      BatchNorm2d-42          [-1, 128, 56, 56]             256
             ReLU-43          [-1, 128, 56, 56]               0
           Conv2d-44          [-1, 128, 28, 28]         147,456
      BatchNorm2d-45          [-1, 128, 28, 28]             256
             ReLU-46          [-1, 128, 28, 28]               0
           Conv2d-47          [-1, 512, 28, 28]          65,536
      BatchNorm2d-48          [-1, 512, 28, 28]           1,024
           Conv2d-49          [-1, 512, 28, 28]         131,072
      BatchNorm2d-50          [-1, 512, 28, 28]           1,024
             ReLU-51          [-1, 512, 28, 28]               0
       Bottleneck-52          [-1, 512, 28, 28]               0
           Conv2d-53          [-1, 128, 28, 28]          65,536
      BatchNorm2d-54          [-1, 128, 28, 28]             256
             ReLU-55          [-1, 128, 28, 28]               0
           Conv2d-56          [-1, 128, 28, 28]         147,456
      BatchNorm2d-57          [-1, 128, 28, 28]             256
             ReLU-58          [-1, 128, 28, 28]               0
           Conv2d-59          [-1, 512, 28, 28]          65,536
      BatchNorm2d-60          [-1, 512, 28, 28]           1,024
             ReLU-61          [-1, 512, 28, 28]               0
       Bottleneck-62          [-1, 512, 28, 28]               0
           Conv2d-63          [-1, 128, 28, 28]          65,536
      BatchNorm2d-64          [-1, 128, 28, 28]             256
             ReLU-65          [-1, 128, 28, 28]               0
           Conv2d-66          [-1, 128, 28, 28]         147,456
      BatchNorm2d-67          [-1, 128, 28, 28]             256
             ReLU-68          [-1, 128, 28, 28]               0
           Conv2d-69          [-1, 512, 28, 28]          65,536
      BatchNorm2d-70          [-1, 512, 28, 28]           1,024
             ReLU-71          [-1, 512, 28, 28]               0
       Bottleneck-72          [-1, 512, 28, 28]               0
           Conv2d-73          [-1, 128, 28, 28]          65,536
      BatchNorm2d-74          [-1, 128, 28, 28]             256
             ReLU-75          [-1, 128, 28, 28]               0
           Conv2d-76          [-1, 128, 28, 28]         147,456
      BatchNorm2d-77          [-1, 128, 28, 28]             256
             ReLU-78          [-1, 128, 28, 28]               0
           Conv2d-79          [-1, 512, 28, 28]          65,536
      BatchNorm2d-80          [-1, 512, 28, 28]           1,024
             ReLU-81          [-1, 512, 28, 28]               0
       Bottleneck-82          [-1, 512, 28, 28]               0
           Conv2d-83          [-1, 256, 28, 28]         131,072
      BatchNorm2d-84          [-1, 256, 28, 28]             512
             ReLU-85          [-1, 256, 28, 28]               0
           Conv2d-86          [-1, 256, 14, 14]         589,824
      BatchNorm2d-87          [-1, 256, 14, 14]             512
             ReLU-88          [-1, 256, 14, 14]               0
           Conv2d-89         [-1, 1024, 14, 14]         262,144
      BatchNorm2d-90         [-1, 1024, 14, 14]           2,048
           Conv2d-91         [-1, 1024, 14, 14]         524,288
      BatchNorm2d-92         [-1, 1024, 14, 14]           2,048
             ReLU-93         [-1, 1024, 14, 14]               0
       Bottleneck-94         [-1, 1024, 14, 14]               0
           Conv2d-95          [-1, 256, 14, 14]         262,144
      BatchNorm2d-96          [-1, 256, 14, 14]             512
             ReLU-97          [-1, 256, 14, 14]               0
           Conv2d-98          [-1, 256, 14, 14]         589,824
      BatchNorm2d-99          [-1, 256, 14, 14]             512
            ReLU-100          [-1, 256, 14, 14]               0
          Conv2d-101         [-1, 1024, 14, 14]         262,144
     BatchNorm2d-102         [-1, 1024, 14, 14]           2,048
            ReLU-103         [-1, 1024, 14, 14]               0
      Bottleneck-104         [-1, 1024, 14, 14]               0
          Conv2d-105          [-1, 256, 14, 14]         262,144
     BatchNorm2d-106          [-1, 256, 14, 14]             512
            ReLU-107          [-1, 256, 14, 14]               0
          Conv2d-108          [-1, 256, 14, 14]         589,824
     BatchNorm2d-109          [-1, 256, 14, 14]             512
            ReLU-110          [-1, 256, 14, 14]               0
          Conv2d-111         [-1, 1024, 14, 14]         262,144
     BatchNorm2d-112         [-1, 1024, 14, 14]           2,048
            ReLU-113         [-1, 1024, 14, 14]               0
      Bottleneck-114         [-1, 1024, 14, 14]               0
          Conv2d-115          [-1, 256, 14, 14]         262,144
     BatchNorm2d-116          [-1, 256, 14, 14]             512
            ReLU-117          [-1, 256, 14, 14]               0
          Conv2d-118          [-1, 256, 14, 14]         589,824
     BatchNorm2d-119          [-1, 256, 14, 14]             512
            ReLU-120          [-1, 256, 14, 14]               0
          Conv2d-121         [-1, 1024, 14, 14]         262,144
     BatchNorm2d-122         [-1, 1024, 14, 14]           2,048
            ReLU-123         [-1, 1024, 14, 14]               0
      Bottleneck-124         [-1, 1024, 14, 14]               0
          Conv2d-125          [-1, 256, 14, 14]         262,144
     BatchNorm2d-126          [-1, 256, 14, 14]             512
            ReLU-127          [-1, 256, 14, 14]               0
          Conv2d-128          [-1, 256, 14, 14]         589,824
     BatchNorm2d-129          [-1, 256, 14, 14]             512
            ReLU-130          [-1, 256, 14, 14]               0
          Conv2d-131         [-1, 1024, 14, 14]         262,144
     BatchNorm2d-132         [-1, 1024, 14, 14]           2,048
            ReLU-133         [-1, 1024, 14, 14]               0
      Bottleneck-134         [-1, 1024, 14, 14]               0
          Conv2d-135          [-1, 256, 14, 14]         262,144
     BatchNorm2d-136          [-1, 256, 14, 14]             512
            ReLU-137          [-1, 256, 14, 14]               0
          Conv2d-138          [-1, 256, 14, 14]         589,824
     BatchNorm2d-139          [-1, 256, 14, 14]             512
            ReLU-140          [-1, 256, 14, 14]               0
          Conv2d-141         [-1, 1024, 14, 14]         262,144
     BatchNorm2d-142         [-1, 1024, 14, 14]           2,048
            ReLU-143         [-1, 1024, 14, 14]               0
      Bottleneck-144         [-1, 1024, 14, 14]               0
          Conv2d-145          [-1, 512, 14, 14]         524,288
     BatchNorm2d-146          [-1, 512, 14, 14]           1,024
            ReLU-147          [-1, 512, 14, 14]               0
          Conv2d-148            [-1, 512, 7, 7]       2,359,296
     BatchNorm2d-149            [-1, 512, 7, 7]           1,024
            ReLU-150            [-1, 512, 7, 7]               0
          Conv2d-151           [-1, 2048, 7, 7]       1,048,576
     BatchNorm2d-152           [-1, 2048, 7, 7]           4,096
          Conv2d-153           [-1, 2048, 7, 7]       2,097,152
     BatchNorm2d-154           [-1, 2048, 7, 7]           4,096
            ReLU-155           [-1, 2048, 7, 7]               0
      Bottleneck-156           [-1, 2048, 7, 7]               0
          Conv2d-157            [-1, 512, 7, 7]       1,048,576
     BatchNorm2d-158            [-1, 512, 7, 7]           1,024
            ReLU-159            [-1, 512, 7, 7]               0
          Conv2d-160            [-1, 512, 7, 7]       2,359,296
     BatchNorm2d-161            [-1, 512, 7, 7]           1,024
            ReLU-162            [-1, 512, 7, 7]               0
          Conv2d-163           [-1, 2048, 7, 7]       1,048,576
     BatchNorm2d-164           [-1, 2048, 7, 7]           4,096
            ReLU-165           [-1, 2048, 7, 7]               0
      Bottleneck-166           [-1, 2048, 7, 7]               0
          Conv2d-167            [-1, 512, 7, 7]       1,048,576
     BatchNorm2d-168            [-1, 512, 7, 7]           1,024
            ReLU-169            [-1, 512, 7, 7]               0
          Conv2d-170            [-1, 512, 7, 7]       2,359,296
     BatchNorm2d-171            [-1, 512, 7, 7]           1,024
            ReLU-172            [-1, 512, 7, 7]               0
          Conv2d-173           [-1, 2048, 7, 7]       1,048,576
     BatchNorm2d-174           [-1, 2048, 7, 7]           4,096
            ReLU-175           [-1, 2048, 7, 7]               0
      Bottleneck-176           [-1, 2048, 7, 7]               0
          Conv2d-177           [-1, 1024, 7, 7]       2,098,176
            ReLU-178           [-1, 1024, 7, 7]               0
        Upsample-179         [-1, 1024, 14, 14]               0
          Conv2d-180          [-1, 512, 14, 14]         524,800
            ReLU-181          [-1, 512, 14, 14]               0
          Conv2d-182          [-1, 512, 14, 14]       7,078,400
            ReLU-183          [-1, 512, 14, 14]               0
        Upsample-184          [-1, 512, 28, 28]               0
          Conv2d-185          [-1, 512, 28, 28]         262,656
            ReLU-186          [-1, 512, 28, 28]               0
          Conv2d-187          [-1, 512, 28, 28]       4,719,104
            ReLU-188          [-1, 512, 28, 28]               0
        Upsample-189          [-1, 512, 56, 56]               0
          Conv2d-190          [-1, 256, 56, 56]          65,792
            ReLU-191          [-1, 256, 56, 56]               0
          Conv2d-192          [-1, 256, 56, 56]       1,769,728
            ReLU-193          [-1, 256, 56, 56]               0
        Upsample-194        [-1, 256, 112, 112]               0
          Conv2d-195         [-1, 64, 112, 112]           4,160
            ReLU-196         [-1, 64, 112, 112]               0
          Conv2d-197        [-1, 128, 112, 112]         368,768
            ReLU-198        [-1, 128, 112, 112]               0
        Upsample-199        [-1, 128, 224, 224]               0
          Conv2d-200         [-1, 64, 224, 224]         110,656
            ReLU-201         [-1, 64, 224, 224]               0
          Conv2d-202          [-1, 6, 224, 224]             390
================================================================
Total params: 40,549,382
Trainable params: 40,549,382
Non-trainable params: 0
----------------------------------------------------------------

Define the main training loop

from collections import defaultdict
import torch.nn.functional as F
from loss import dice_loss

def calc_loss(pred, target, metrics, bce_weight=0.5):
    """Compute the weighted BCE + Dice loss and accumulate running metrics.

    Args:
        pred: raw model outputs (logits); the sigmoid is applied here, not in
            the model.
        target: target masks, passed unchanged to both loss terms.
        metrics: mutable mapping updated in place. The keys 'bce', 'dice' and
            'loss' receive the batch value multiplied by ``target.size(0)``,
            so dividing by the total sample count later gives a per-sample
            average.
        bce_weight: weight of the BCE term; the Dice term gets
            ``1 - bce_weight``.

    Returns:
        The combined loss tensor, still attached to the autograd graph so the
        caller can call ``backward()`` on it.
    """
    bce = F.binary_cross_entropy_with_logits(pred, target)

    # torch.sigmoid replaces the deprecated torch.nn.functional.sigmoid
    pred = torch.sigmoid(pred)
    dice = dice_loss(pred, target)

    loss = bce * bce_weight + dice * (1 - bce_weight)

    # .item() extracts plain Python floats without touching the graph,
    # replacing the legacy `.data.cpu().numpy()` idiom
    n_samples = target.size(0)
    metrics['bce'] += bce.item() * n_samples
    metrics['dice'] += dice.item() * n_samples
    metrics['loss'] += loss.item() * n_samples

    return loss

def print_metrics(metrics, epoch_samples, phase):
    """Print the per-sample average of every accumulated metric on one line.

    Args:
        metrics: mapping from metric name to its running sum.
        epoch_samples: number of samples the sums were accumulated over.
        phase: label printed in front of the metrics (e.g. 'train' or 'val').
    """
    summary_text = ", ".join(
        "{}: {:4f}".format(name, total / epoch_samples)
        for name, total in metrics.items()
    )
    print("{}: {}".format(phase, summary_text))

def train_model(model, optimizer, scheduler, num_epochs=25):
    """Train ``model`` and return it loaded with the best validation weights.

    Every epoch runs a 'train' phase followed by a 'val' phase over the
    module-level ``dataloaders`` dict, moving batches to the module-level
    ``device``. The weights with the lowest average validation loss are kept
    and restored at the end.

    Args:
        model: network to optimize; modified in place.
        optimizer: optimizer that updates the model's parameters.
        scheduler: learning-rate scheduler, stepped once per epoch after the
            training phase.
        num_epochs: number of epochs to run.

    Returns:
        The same ``model`` object with the best-scoring state dict loaded.
    """
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = 1e10

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        since = time.time()

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                # Report the learning rate that this epoch trains with
                for param_group in optimizer.param_groups:
                    print("LR", param_group['lr'])

                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            metrics = defaultdict(float)
            epoch_samples = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = calc_loss(outputs, labels, metrics)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                epoch_samples += inputs.size(0)

            if phase == 'train':
                # Step the schedule only after this epoch's optimizer updates.
                # Stepping before optimizer.step() (as at the top of the
                # epoch) makes PyTorch >= 1.1 warn and shifts the decay one
                # epoch early.
                scheduler.step()

            print_metrics(metrics, epoch_samples, phase)
            epoch_loss = metrics['loss'] / epoch_samples

            # deep copy the model
            if phase == 'val' and epoch_loss < best_loss:
                print("saving best model")
                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(model.state_dict())

        time_elapsed = time.time() - since
        print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

    print('Best val loss: {:4f}'.format(best_loss))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

Training

import torch
import torch.optim as optim
from torch.optim import lr_scheduler
import time
import copy

# Train on the first CUDA device when available, otherwise on the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# Fresh network with 6 output channels, placed on the chosen device
num_class = 6
model = ResNetUNet(num_class).to(device)

# freeze backbone layers
# (optional: uncomment to stop gradient updates for the backbone modules;
# the optimizer below already skips parameters with requires_grad=False)
#for l in model.base_layers:
#    for param in l.parameters():
#        param.requires_grad = False

# Adam over the trainable parameters only, starting at lr=1e-4
optimizer_ft = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)

# StepLR configured with step_size=30 and gamma=0.1; the log below shows the
# learning rate going from 0.0001 to 1e-05 late in the run
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=30, gamma=0.1)

# Run 60 epochs; train_model returns the model with the best validation weights
model = train_model(model, optimizer_ft, exp_lr_scheduler, num_epochs=60)
cuda:0
Epoch 0/59
----------
LR 0.0001
train: bce: 0.070256, dice: 0.856320, loss: 0.463288
val: bce: 0.014897, dice: 0.515814, loss: 0.265356
saving best model
0m 51s
Epoch 1/59
----------
LR 0.0001
train: bce: 0.011369, dice: 0.309445, loss: 0.160407
val: bce: 0.003790, dice: 0.113682, loss: 0.058736
saving best model
0m 51s
Epoch 2/59
----------
LR 0.0001
train: bce: 0.003480, dice: 0.089928, loss: 0.046704
val: bce: 0.002525, dice: 0.067604, loss: 0.035064
saving best model
0m 51s

(Omitted)

Epoch 57/59
----------
LR 1e-05
train: bce: 0.000523, dice: 0.010289, loss: 0.005406
val: bce: 0.001558, dice: 0.030965, loss: 0.016261
0m 51s
Epoch 58/59
----------
LR 1e-05
train: bce: 0.000518, dice: 0.010209, loss: 0.005364
val: bce: 0.001548, dice: 0.031034, loss: 0.016291
0m 51s
Epoch 59/59
----------
LR 1e-05
train: bce: 0.000518, dice: 0.010168, loss: 0.005343
val: bce: 0.001566, dice: 0.030785, loss: 0.016176
0m 50s
Best val loss: 0.016171

Use the trained model

import math

model.eval()   # Set model to the evaluation mode

# Create another simulation dataset for test
test_dataset = SimDataset(3, transform = trans)
test_loader = DataLoader(test_dataset, batch_size=3, shuffle=False, num_workers=0)

# Get the first batch
inputs, labels = next(iter(test_loader))
inputs = inputs.to(device)
labels = labels.to(device)

# Predict. Gradients are not needed for inference, so skip building the
# autograd graph.
with torch.no_grad():
    pred = model(inputs)
    # The model returns raw logits (the sigmoid is applied inside the loss
    # computation during training), so apply it here to get probabilities.
    # torch.sigmoid replaces the deprecated F.sigmoid.
    pred = torch.sigmoid(pred)
pred = pred.cpu().numpy()
print(pred.shape)

# Undo the input normalization so the images can be displayed
input_images_rgb = [reverse_transform(x) for x in inputs.cpu()]

# Map each channel (i.e. class) to each color
target_masks_rgb = [helper.masks_to_colorimg(x) for x in labels.cpu().numpy()]
pred_rgb = [helper.masks_to_colorimg(x) for x in pred]

# Columns: input image, ground-truth mask, predicted mask
helper.plot_side_by_side([input_images_rgb, target_masks_rgb, pred_rgb])
(3, 6, 192, 192)

Left: Input image, Middle: Correct mask (Ground-truth), Right: Predicted mask

png