connorlee77/pytorch-mutual-information

Mutual information for a volumetric image


Hello!

The implementation looks very interesting; I have been looking for a fully differentiable implementation of the mutual information calculation.

I am working with 3D images (MRI volumes) and want to calculate the mutual information between two of them. The dimensions I have are (1, 256, 256, 256), where the first dimension is the batch size and the remaining three are the width, depth, and height.

I tried several ways of modifying the current implementation, without success. My latest modification, which doesn't yield correct results, is attached below.

Any help would be highly appreciated!

Best,
Dimitris

class MutualInformation(nn.Module):

    def __init__(self, sigma=0.1, num_bins=256, normalize=True):
        super(MutualInformation, self).__init__()

        self.sigma = sigma
        self.num_bins = num_bins
        self.normalize = normalize
        self.epsilon = 1e-6

        self.bins = nn.Parameter(
            torch.linspace(0, 255, num_bins).float(), requires_grad=True
        )

    def marginalPdf(self, values):
        residuals = values - self.bins.unsqueeze(0).unsqueeze(0)
        kernel_values = torch.exp(-0.5 * (residuals / self.sigma).pow(2))

        pdf = torch.mean(kernel_values, dim=1)
        normalization = torch.sum(pdf, dim=1).unsqueeze(1) + self.epsilon
        pdf = pdf / normalization

        return pdf, kernel_values

    def jointPdf(self, kernel_values1, kernel_values2):
        joint_kernel_values = kernel_values1 * kernel_values2
        normalization = (
            torch.sum(joint_kernel_values, dim=(1, 2, 3, 4)).unsqueeze(1).unsqueeze(1)
            + self.epsilon
        )
        pdf = joint_kernel_values / normalization

        return pdf

    def getMutualInformation(self, input1, input2):
        """
        input1: (1, 1, 256, 256, 256) tensor
        input2: (1, 1, 256, 256, 256) tensor

        return: scalar
        """

        # Torch tensors for images between (0, 1)
        input1 = input1 * 255
        input2 = input2 * 255

        assert input1.shape == input2.shape

        pdf_x1, kernel_values1 = self.marginalPdf(input1)
        pdf_x2, kernel_values2 = self.marginalPdf(input2)
        pdf_x1x2 = self.jointPdf(kernel_values1, kernel_values2)

        H_x1 = -torch.sum(pdf_x1 * torch.log2(pdf_x1 + self.epsilon), dim=(1, 2, 3))
        H_x2 = -torch.sum(pdf_x2 * torch.log2(pdf_x2 + self.epsilon), dim=(1, 2, 3))
        H_x1x2 = -torch.sum(
            pdf_x1x2 * torch.log2(pdf_x1x2 + self.epsilon), dim=(1, 2, 3, 4)
        )

        mutual_information = H_x1 + H_x2 - H_x1x2

        if self.normalize:
            mutual_information = 2 * mutual_information / (H_x1 + H_x2)

        return mutual_information

    def forward(self, input1, input2):
        """
        input1: B, C, D, H, W
        input2: B, C, D, H, W

        return: scalar
        """
        return self.getMutualInformation(input1, input2)

I'm not too familiar with mutual information on volumetric data. If you have a link to a basic tutorial or paper, I can take a look and see what the difference is.

Naively speaking, would it suffice to just flatten the volume, e.g. reshape (B, C, D, H, W) -> (B, C*D*H*W), or is the result different?
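For reference, here is a minimal sketch of what the flattening idea could look like, assuming the volumes are already scaled to [0, 1]. Mutual information depends only on the joint distribution of paired intensities, not on where the voxels sit in space, so flattening both volumes the same way should give the same value as any other layout. The function name `soft_histogram_mi`, the `chunk` parameter, and the default `num_bins` and `sigma` are illustrative assumptions, not part of the repository. Unlike the code above, the marginals here are taken from the joint histogram rather than computed separately. Note the trailing singleton axis after flattening: without it, `values - bins` broadcasts the bin centers against the last spatial axis instead of creating a new bins axis.

```python
import torch


def soft_histogram_mi(vol1, vol2, num_bins=32, sigma=0.05, chunk=65536, eps=1e-6):
    """Mutual information (in bits) between two volumes with values in [0, 1].

    vol1, vol2: tensors of identical shape, e.g. (B, C, D, H, W).
    Returns a tensor of shape (B,).
    Hypothetical helper for illustration; not the repository's API.
    """
    assert vol1.shape == vol2.shape
    batch = vol1.shape[0]

    # Flatten every voxel onto one axis and add a singleton axis for the bins: (B, N, 1).
    x1 = vol1.reshape(batch, -1, 1)
    x2 = vol2.reshape(batch, -1, 1)
    bins = torch.linspace(0.0, 1.0, num_bins, device=vol1.device, dtype=vol1.dtype)

    # Accumulate the soft joint histogram chunk by chunk, so the full
    # (N, num_bins) kernel matrix is never built in one piece.
    joint = torch.zeros(batch, num_bins, num_bins, device=vol1.device, dtype=vol1.dtype)
    for start in range(0, x1.shape[1], chunk):
        k1 = torch.exp(-0.5 * ((x1[:, start:start + chunk] - bins) / sigma) ** 2)
        k2 = torch.exp(-0.5 * ((x2[:, start:start + chunk] - bins) / sigma) ** 2)
        # (B, bins, n) @ (B, n, bins) -> (B, bins, bins)
        joint = joint + torch.matmul(k1.transpose(1, 2), k2)

    # Normalize to a joint pdf, then read the marginals off its rows and columns.
    joint = joint / (joint.sum(dim=(1, 2), keepdim=True) + eps)
    p1 = joint.sum(dim=2)
    p2 = joint.sum(dim=1)

    h1 = -(p1 * torch.log2(p1 + eps)).sum(dim=1)
    h2 = -(p2 * torch.log2(p2 + eps)).sum(dim=1)
    h12 = -(joint * torch.log2(joint + eps)).sum(dim=(1, 2))
    return h1 + h2 - h12


# Small usage example on random volumes (real MRI data would be loaded instead).
torch.manual_seed(0)
vol_a = torch.rand(1, 1, 8, 8, 8)
vol_b = torch.rand(1, 1, 8, 8, 8)
print(soft_histogram_mi(vol_a, vol_a))  # a volume compared with itself
print(soft_histogram_mi(vol_a, vol_b))  # two unrelated volumes
```

A caveat on memory: for a 256 x 256 x 256 volume and 256 bins, a single unchunked kernel matrix would take roughly 16 GiB in float32. Chunking keeps the forward pass small under `torch.no_grad()`, but with autograd enabled the per-chunk kernel values are still retained for the backward pass, so subsampling voxels or using fewer bins may be needed when this is used as a loss.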