A Probabilistic U-Net for segmentation of ambiguous images, implemented in PyTorch. This is a PyTorch implementation of the paper "A Probabilistic U-Net for Segmentation of Ambiguous Images" (https://arxiv.org/abs/1806.05034); the authors' original code can be found at https://github.com/SimonKohl/probabilistic_unet.
To implement a Gaussian distribution with an axis-aligned covariance matrix in PyTorch, I wrapped a Normal distribution in an Independent distribution. Computing the KL divergence between two such distributions requires a KL rule for Independent, so you need to add the following to the PyTorch source code at torch/distributions/kl.py (source: pytorch/pytorch#13545):
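# Independent must be imported in kl.py (from .independent import Independent).
# Without the register_kl decorator, kl_divergence never finds this rule.
@register_kl(Independent, Independent)
def _kl_independent_independent(p, q):
    # Both distributions must treat the same number of batch dims as event dims.
    if p.reinterpreted_batch_ndims != q.reinterpreted_batch_ndims:
        raise NotImplementedError
    # KL of the base distributions, summed over the reinterpreted dims.
    result = kl_divergence(p.base_dist, q.base_dist)
    return _sum_rightmost(result, p.reinterpreted_batch_ndims)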
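If editing the installed PyTorch source is undesirable, the same rule can probably be registered from your own code instead. The sketch below assumes a PyTorch version that exposes `torch.distributions.kl.register_kl` and the private helper `torch.distributions.utils._sum_rightmost` (private helpers may change between releases). The tensor shapes and values are made up for illustration, and the result is checked against the per-dimension KL summed by hand.

```python
import torch
from torch.distributions import Independent, Normal, kl_divergence
from torch.distributions.kl import register_kl
from torch.distributions.utils import _sum_rightmost


# Register the rule at import time; if a rule for this pair already exists,
# this simply replaces it with an equivalent one.
@register_kl(Independent, Independent)
def _kl_independent_independent(p, q):
    if p.reinterpreted_batch_ndims != q.reinterpreted_batch_ndims:
        raise NotImplementedError
    result = kl_divergence(p.base_dist, q.base_dist)
    return _sum_rightmost(result, p.reinterpreted_batch_ndims)


# Two axis-aligned Gaussians over a 6-dimensional latent space, batch of 4.
# (Illustrative values only.)
p = Independent(Normal(torch.zeros(4, 6), torch.ones(4, 6)), 1)
q = Independent(Normal(torch.ones(4, 6), 2.0 * torch.ones(4, 6)), 1)

kl = kl_divergence(p, q)  # one KL value per batch element

# Reference: KL of the factorised Gaussians is the sum of per-dimension KLs.
expected = kl_divergence(p.base_dist, q.base_dist).sum(dim=-1)
```

Registering the rule in user code keeps the PyTorch installation untouched, at the cost of depending on a private helper.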
To train your own Probabilistic U-Net in PyTorch, first write your own data loader. You can then use the following code snippet to train the network:
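import torch

# ProbabilisticUnet and l2_regularisation are provided by this repository.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_loader = ...  # define this yourself; it must yield (patch, mask) batches

# no_channels, no_classes, filter_list, latent_dim, no_fcomb_convs, beta
# and epochs are hyperparameters you choose.
net = ProbabilisticUnet(no_channels, no_classes, filter_list, latent_dim, no_fcomb_convs, beta)
net.to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4, weight_decay=0)
for epoch in range(epochs):
    for step, (patch, mask) in enumerate(train_loader):
        patch = patch.to(device)
        mask = mask.to(device)
        mask = torch.unsqueeze(mask, 1)  # insert a channel dimension at index 1
        net.forward(patch, mask, training=True)  # elbo() below relies on this pass
        elbo = net.elbo(mask)
        reg_loss = l2_regularisation(net.posterior) + l2_regularisation(net.prior) + l2_regularisation(net.fcomb.layers)
        loss = -elbo + 1e-5 * reg_loss  # maximise the ELBO with a small L2 penalty
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()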
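The training loop above expects a loader that yields `(patch, mask)` pairs, with the mask lacking a channel dimension. As a minimal sketch of what such a loader might look like, the following uses a hypothetical `RandomPatchDataset` filled with random tensors; the class name, tensor sizes, and batch size are assumptions for illustration and should be replaced by code that reads your real images and annotations.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class RandomPatchDataset(Dataset):
    """Hypothetical stand-in dataset: random image patches and binary masks."""

    def __init__(self, n_samples=16, channels=1, size=32, seed=0):
        gen = torch.Generator().manual_seed(seed)
        # Patches have shape (channels, size, size); masks have shape (size, size).
        self.patches = torch.rand(n_samples, channels, size, size, generator=gen)
        self.masks = (torch.rand(n_samples, size, size, generator=gen) > 0.5).float()

    def __len__(self):
        return len(self.patches)

    def __getitem__(self, idx):
        return self.patches[idx], self.masks[idx]


train_loader = DataLoader(RandomPatchDataset(), batch_size=4, shuffle=True)

# One batch, processed the same way as in the training loop.
patch, mask = next(iter(train_loader))
mask = torch.unsqueeze(mask, 1)  # (4, 32, 32) -> (4, 1, 32, 32)
```

Returning masks without a channel dimension matches the `torch.unsqueeze(mask, 1)` call in the training snippet; if your dataset already stores masks with a channel axis, drop that call instead.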
One of the datasets used in the original paper is the LIDC dataset. I've preprocessed this data and stored it in a pickle file, which you can download here. After downloading, place the data in a folder called 'data'. You can then train your own Probabilistic U-Net on the LIDC dataset using the simple training script provided in train_model.py.
Train code
CUDA_VISIBLE_DEVICES=7 python3 train_model.py --dataset DRIVE --params params/DRIVE_train.json --train_batch 8
CUDA_VISIBLE_DEVICES=7 python3 train_model.py --dataset ROSE --params params/ROSE_train.json --train_batch 6
CUDA_VISIBLE_DEVICES=3 python3 train_model.py --dataset PARSE2D --params params/PARSE2D_train.json --train_batch 8
Test code
CUDA_VISIBLE_DEVICES=7 python3 infer.py --dataset DRIVE --params params/DRIVE_validation.json
CUDA_VISIBLE_DEVICES=7 python3 infer.py --dataset ROSE --params params/ROSE_validation.json
CUDA_VISIBLE_DEVICES=6 python3 infer_parse.py --dataset PARSE2D --params params/PARSE2D_validation.json