
Primary language: Python. License: Apache License 2.0 (Apache-2.0).

Probabilistic UNet in PyTorch

A Probabilistic U-Net for segmentation of ambiguous images implemented in PyTorch. This is a PyTorch implementation of the paper https://arxiv.org/abs/1806.05034; the original authors' code can be found at https://github.com/SimonKohl/probabilistic_unet.

Adding a KL divergence for the Independent distribution

To implement a Gaussian distribution with an axis-aligned (diagonal) covariance matrix in PyTorch, I needed to wrap a Normal distribution in an Independent distribution. Because kl_divergence had no registered implementation for a pair of Independent distributions (see pytorch/pytorch#13545), you need to add the following to the PyTorch source code in torch/distributions/kl.py:

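# register_kl, kl_divergence and _sum_rightmost are available inside kl.py;
# make sure `from .independent import Independent` is imported there as well
@register_kl(Independent, Independent)
def _kl_independent_independent(p, q):
    # Only defined when both distributions reinterpret the same number of batch dims
    if p.reinterpreted_batch_ndims != q.reinterpreted_batch_ndims:
        raise NotImplementedError
    # Element-wise KL of the base distributions (e.g. Normal) ...
    result = kl_divergence(p.base_dist, q.base_dist)
    # ... summed over the reinterpreted (event) dimensions
    return _sum_rightmost(result, p.reinterpreted_batch_ndims)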
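To make concrete what this function computes for the latent Gaussians, here is a dependency-free sketch of the closed-form KL divergence between two Gaussians with axis-aligned covariance. The names kl_normal and kl_diag_gaussians are illustrative only and are not part of PyTorch or this repository: the per-dimension KL plays the role of kl_divergence(p.base_dist, q.base_dist), and the final sum plays the role of _sum_rightmost for one reinterpreted dimension.

```python
import math
from typing import Sequence


def kl_normal(mu_p: float, sigma_p: float, mu_q: float, sigma_q: float) -> float:
    """Closed-form KL(N(mu_p, sigma_p^2) || N(mu_q, sigma_q^2)) in one dimension.

    Both standard deviations must be positive.
    """
    return (math.log(sigma_q / sigma_p)
            + (sigma_p ** 2 + (mu_p - mu_q) ** 2) / (2.0 * sigma_q ** 2)
            - 0.5)


def kl_diag_gaussians(mu_p: Sequence[float], sigma_p: Sequence[float],
                      mu_q: Sequence[float], sigma_q: Sequence[float]) -> float:
    """KL between two axis-aligned Gaussians over a single latent vector.

    Step 1: per-dimension KL of the base Normals
            (the role of kl_divergence(p.base_dist, q.base_dist)).
    Step 2: sum over the latent dimension
            (the role of _sum_rightmost with one reinterpreted dimension).
    """
    if not (len(mu_p) == len(sigma_p) == len(mu_q) == len(sigma_q)):
        raise ValueError("all parameter vectors must have the same length")
    per_dim = [kl_normal(a, b, c, d) for a, b, c, d in zip(mu_p, sigma_p, mu_q, sigma_q)]
    return sum(per_dim)


if __name__ == "__main__":
    # Made-up posterior (p) and prior (q) parameters for a 2-D latent space.
    print(kl_diag_gaussians([0.5, -1.0], [0.8, 1.2], [0.0, 0.0], [1.0, 1.0]))
```

Because the covariance is diagonal, the KL of the joint distribution is exactly the sum of one-dimensional KLs, which is why the Independent wrapper can delegate to the base Normal and then sum.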

Training

To train your own Probabilistic UNet in PyTorch, first write your own data loader. You can then use the following code snippet to train the network:

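import torch
from probabilistic_unet import ProbabilisticUnet  # assumed module name in this repo
from utils import l2_regularisation               # assumed module name in this repo

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_loader = ...  # define this yourself: must yield (patch, mask) batches
# no_channels, no_classes, filter_list, latent_dim, no_fcomb_convs, beta and epochs are values you choose
net = ProbabilisticUnet(no_channels, no_classes, filter_list, latent_dim, no_fcomb_convs, beta)
net.to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4, weight_decay=0)
for epoch in range(epochs):
    for step, (patch, mask) in enumerate(train_loader):
        patch = patch.to(device)
        mask = mask.to(device)
        mask = torch.unsqueeze(mask, 1)  # add a channel axis: (B, H, W) -> (B, 1, H, W)
        net(patch, mask, training=True)  # forward pass; the model caches what elbo() needs
        elbo = net.elbo(mask)
        # L2 penalty on the posterior, prior and fcomb weights
        reg_loss = l2_regularisation(net.posterior) + l2_regularisation(net.prior) + l2_regularisation(net.fcomb.layers)
        loss = -elbo + 1e-5 * reg_loss  # minimising -ELBO maximises the ELBO
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()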
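Assuming net.elbo follows the objective of the Probabilistic U-Net paper (a reconstruction term plus a β-weighted KL divergence between the posterior and the prior), the quantity minimised by the loop above can be written as follows. Here X is the input patch, Y the mask, Q the posterior net, P the prior net, S(X, z) the segmentation obtained by combining the U-Net features with a latent sample z, β the beta constructor argument, and R is shorthand for the sum of the three l2_regularisation terms.

```latex
\mathcal{L} = -\mathrm{ELBO} + 10^{-5}\, R,
\qquad
-\mathrm{ELBO} = \mathbb{E}_{z \sim Q(z \mid X, Y)}\!\left[-\log P_c\big(Y \mid S(X, z)\big)\right]
               + \beta \, D_{\mathrm{KL}}\!\big(Q(z \mid X, Y) \,\|\, P(z \mid X)\big)
```

The KL term is where the Independent/Normal KL registered above is used: both Q and P are axis-aligned Gaussians over the latent space.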

Train on LIDC Dataset

One of the datasets used in the original paper is the LIDC dataset. I've preprocessed this data and stored it in a pickle file, which you can download here. After downloading the files, place them in a folder called 'data'. You can then train your own Probabilistic UNet on the LIDC dataset with the simple training script provided in train_model.py.
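The internal structure of the pickle file is not documented here, so the following is only a sketch of a data loader under an assumed layout: a dict mapping a case id to {"image": 2-D array, "masks": [2-D array, ...]}, with one mask per annotator. The class name PickledSegmentationDataset and the keys "image" and "masks" are hypothetical; adapt them to whatever the downloaded file actually contains. Returning one randomly chosen annotator mask per access is one common way to expose the model to the disagreement between annotators.

```python
import pickle
import random
from typing import Optional

import numpy as np


class PickledSegmentationDataset:
    """Map-style dataset over a pickle file with a HYPOTHETICAL layout.

    Assumed layout (check against the actual file):
        {case_id: {"image": 2-D array, "masks": [2-D array, ...]}, ...}
    Only __len__ and __getitem__ are needed for a map-style dataset.
    """

    def __init__(self, path: str, seed: Optional[int] = None):
        with open(path, "rb") as f:
            self.data = pickle.load(f)
        self.keys = sorted(self.data)
        # With a seed: private, reproducible RNG. Without one: the module-level
        # RNG, which the PyTorch DataLoader reseeds in each worker process.
        self.rng = random.Random(seed) if seed is not None else random

    def __len__(self) -> int:
        return len(self.keys)

    def __getitem__(self, index: int):
        sample = self.data[self.keys[index]]
        # Add a channel axis to the image: (H, W) -> (1, H, W).
        image = np.asarray(sample["image"], dtype=np.float32)[None]
        # Pick one annotator's mask; the training loop adds the channel
        # axis itself with torch.unsqueeze(mask, 1).
        mask = np.asarray(self.rng.choice(sample["masks"]), dtype=np.float32)
        return image, mask
```

An instance can be passed to torch.utils.data.DataLoader (for example with shuffle=True); its default collate function stacks the NumPy arrays into (B, 1, H, W) image and (B, H, W) mask tensors, which matches the torch.unsqueeze(mask, 1) call in the training snippet.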

Modified by Saumya using Structural DMT code

Training commands

DRIVE

CUDA_VISIBLE_DEVICES=7 python3 train_model.py --dataset DRIVE --params params/DRIVE_train.json --train_batch 8

ROSE

CUDA_VISIBLE_DEVICES=7 python3 train_model.py --dataset ROSE --params params/ROSE_train.json --train_batch 6

PARSE2D

CUDA_VISIBLE_DEVICES=3 python3 train_model.py --dataset PARSE2D --params params/PARSE2D_train.json --train_batch 8

Test commands

DRIVE

CUDA_VISIBLE_DEVICES=7 python3 infer.py --dataset DRIVE --params params/DRIVE_validation.json

ROSE

CUDA_VISIBLE_DEVICES=7 python3 infer.py --dataset ROSE --params params/ROSE_validation.json

PARSE2D

CUDA_VISIBLE_DEVICES=6 python3 infer_parse.py --dataset PARSE2D --params params/PARSE2D_validation.json