Mutual information for a volumetric image
Hello!
The implementation looks very interesting. I was looking for a fully differentiable implementation of the mutual information calculation.
I am using 3D images, basically MRI volumes, and I want to calculate the mutual information between two of them. The dimensions I have are (1, 256, 256, 256), where the first dimension is the batch size and the remaining three are width, depth and height.
I tried several ways to modify the current implementation, but without success. Below is my latest modification, which doesn't yield correct results.
Any help would be highly appreciated!
Best,
Dimitris
import torch
import torch.nn as nn


class MutualInformation(nn.Module):
    def __init__(self, sigma=0.1, num_bins=256, normalize=True):
        super(MutualInformation, self).__init__()
        self.sigma = sigma
        self.num_bins = num_bins
        self.normalize = normalize
        self.epsilon = 1e-6
        self.bins = nn.Parameter(
            torch.linspace(0, 255, num_bins).float(), requires_grad=True
        )

    def marginalPdf(self, values):
        residuals = values - self.bins.unsqueeze(0).unsqueeze(0)
        kernel_values = torch.exp(-0.5 * (residuals / self.sigma).pow(2))
        pdf = torch.mean(kernel_values, dim=1)
        normalization = torch.sum(pdf, dim=1).unsqueeze(1) + self.epsilon
        pdf = pdf / normalization
        return pdf, kernel_values

    def jointPdf(self, kernel_values1, kernel_values2):
        joint_kernel_values = kernel_values1 * kernel_values2
        normalization = (
            torch.sum(joint_kernel_values, dim=(1, 2, 3, 4)).unsqueeze(1).unsqueeze(1)
            + self.epsilon
        )
        pdf = joint_kernel_values / normalization
        return pdf

    def getMutualInformation(self, input1, input2):
        """
        input1: (1, 1, 256, 256, 256) tensor
        input2: (1, 1, 256, 256, 256) tensor
        return: scalar
        """
        # Torch tensors for images between (0, 1)
        input1 = input1 * 255
        input2 = input2 * 255
        assert input1.shape == input2.shape
        pdf_x1, kernel_values1 = self.marginalPdf(input1)
        pdf_x2, kernel_values2 = self.marginalPdf(input2)
        pdf_x1x2 = self.jointPdf(kernel_values1, kernel_values2)
        H_x1 = -torch.sum(pdf_x1 * torch.log2(pdf_x1 + self.epsilon), dim=(1, 2, 3))
        H_x2 = -torch.sum(pdf_x2 * torch.log2(pdf_x2 + self.epsilon), dim=(1, 2, 3))
        H_x1x2 = -torch.sum(
            pdf_x1x2 * torch.log2(pdf_x1x2 + self.epsilon), dim=(1, 2, 3, 4)
        )
        mutual_information = H_x1 + H_x2 - H_x1x2
        if self.normalize:
            mutual_information = 2 * mutual_information / (H_x1 + H_x2)
        return mutual_information

    def forward(self, input1, input2):
        """
        input1: B, C, H, W
        input2: B, C, H, W
        return: scalar
        """
        return self.getMutualInformation(input1, input2)
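One way to check the modification above is to print the intermediate shapes. The minimal sketch below uses a small random stand-in volume (an assumption, not real MRI data) whose last axis has 256 entries, like the real data. Subtracting `self.bins.unsqueeze(0).unsqueeze(0)`, which has shape (1, 1, 256), from a 5D volume broadcasts the bins against the last spatial axis instead of creating a separate bin axis. Each voxel is therefore compared with a single bin, and no error is raised. Flattening the voxels first gives the (batch, voxels, bins) layout that per-voxel Gaussian kernels need.

```python
import torch

num_bins = 256
bins = torch.linspace(0, 255, num_bins)

# Small stand-in volume (B, C, D, H, W); the last axis equals num_bins,
# so the faulty broadcast succeeds silently, just as it does for 256^3 data.
vol = torch.rand(1, 1, 4, 4, 256) * 255

# As in marginalPdf above: bins of shape (1, 1, 256) line up with the W axis.
residuals_5d = vol - bins.unsqueeze(0).unsqueeze(0)
print(residuals_5d.shape)  # torch.Size([1, 1, 4, 4, 256]): no separate bin axis

# Flattening first: (B, N, 1) - (1, 1, num_bins) -> (B, N, num_bins)
flat = vol.reshape(vol.shape[0], -1, 1)
residuals_flat = flat - bins.view(1, 1, -1)
print(residuals_flat.shape)  # torch.Size([1, 4096, 256])
```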
I'm not too familiar with mutual information on volumetric data. If you have a link to a basic tutorial or paper, I can try to see what the difference is.
Naively speaking, would it suffice to just reshape the 3D tensor (B, C, H, W) -> (B, CHW), or is the result different?
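Because this estimate is built from intensity histograms, it ignores where each voxel sits. Flattening every non-batch axis should therefore behave like the 2D case, provided both volumes are flattened in the same order so that voxel pairs stay aligned. Below is a minimal sketch under that assumption; the class name `VolumeMutualInformation` and the `num_samples` option are hypothetical and not part of this repository. The joint histogram is computed with a batched matrix product. One practical caveat: for a full 256 × 256 × 256 volume with 256 bins, each (B, N, bins) kernel tensor holds 256^4 ≈ 4.3 × 10^9 float32 values, roughly 16 GiB. For that reason, the sketch can optionally estimate from a random subset of voxels, using the same subset for both volumes.

```python
import torch
import torch.nn as nn


class VolumeMutualInformation(nn.Module):
    """Hypothetical sketch: Gaussian-kernel (Parzen) mutual information for volumes.

    Accepts (B, ...) tensors with intensities in [0, 1], e.g. (B, D, H, W) or
    (B, C, D, H, W). Everything after the batch axis is flattened into one
    voxel axis, so spatial layout plays no role in the estimate.
    """

    def __init__(self, sigma=0.1, num_bins=256, normalize=True,
                 epsilon=1e-10, num_samples=None):
        super().__init__()
        self.sigma = sigma  # kernel width in 0..255 intensity units
        self.normalize = normalize
        self.epsilon = epsilon
        self.num_samples = num_samples  # optional voxel subsampling (assumption)
        # Fixed bin centres on the 0..255 intensity scale (a buffer, not trained).
        self.register_buffer("bins", torch.linspace(0, 255, num_bins))

    def marginal_pdf(self, values):
        # values: (B, N, 1) -> residuals and kernels: (B, N, num_bins)
        residuals = values - self.bins.view(1, 1, -1)
        kernel_values = torch.exp(-0.5 * (residuals / self.sigma).pow(2))
        pdf = kernel_values.mean(dim=1)  # (B, num_bins)
        pdf = pdf / (pdf.sum(dim=1, keepdim=True) + self.epsilon)
        return pdf, kernel_values

    def joint_pdf(self, k1, k2):
        # (B, num_bins, N) @ (B, N, num_bins) -> (B, num_bins, num_bins)
        joint = torch.bmm(k1.transpose(1, 2), k2)
        return joint / (joint.sum(dim=(1, 2), keepdim=True) + self.epsilon)

    def forward(self, input1, input2):
        assert input1.shape == input2.shape
        batch = input1.shape[0]
        # Flatten all non-batch axes in the same order for both volumes.
        x1 = (input1 * 255).reshape(batch, -1, 1)
        x2 = (input2 * 255).reshape(batch, -1, 1)
        if self.num_samples is not None and x1.shape[1] > self.num_samples:
            # Same random voxel indices for both volumes keep intensity pairs aligned.
            idx = torch.randperm(x1.shape[1], device=x1.device)[: self.num_samples]
            x1, x2 = x1[:, idx], x2[:, idx]
        pdf1, k1 = self.marginal_pdf(x1)
        pdf2, k2 = self.marginal_pdf(x2)
        pdf12 = self.joint_pdf(k1, k2)
        h1 = -(pdf1 * torch.log2(pdf1 + self.epsilon)).sum(dim=1)
        h2 = -(pdf2 * torch.log2(pdf2 + self.epsilon)).sum(dim=1)
        h12 = -(pdf12 * torch.log2(pdf12 + self.epsilon)).sum(dim=(1, 2))
        mi = h1 + h2 - h12
        if self.normalize:
            mi = 2 * mi / (h1 + h2)
        return mi  # shape (B,)


# Usage on small random stand-in volumes shaped like the data above: (B, D, H, W).
vol_a = torch.rand(1, 32, 32, 32)
vol_b = torch.rand(1, 32, 32, 32)
mi_sub = VolumeMutualInformation(num_samples=10_000)
print(mi_sub(vol_a, vol_b))
```

The default `sigma=0.1` is kept from the post. It is narrow relative to the bin spacing of 1 intensity unit, so continuous-valued MRI intensities may call for a larger value; that choice is worth tuning rather than taking from this sketch.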