PythonOT/POT

Defining a Sinkhorn Based Loss in Torch


Hi all,

I am trying to define a loss in PyTorch that compares an input batch and a target batch of PxP images. I want to treat each PxP image as a 2D distribution and define an OT loss between them. Below is my attempt:

```python
import ot
import torch
import torch.nn as nn


class SinkhornDist(nn.Module):
    def __init__(self, batch_size, pixels, device):
        super(SinkhornDist, self).__init__()
        self.bs     = batch_size
        self.device = device

        # Coordinates of the grid points in [0, 1] x [0, 1]
        x, y = torch.meshgrid(torch.linspace(0, 1, pixels),
                              torch.linspace(0, 1, pixels), indexing="ij")

        # Compute pairwise distances between points on the 2D grid; this
        # Euclidean cost matrix defines how the Wasserstein distance is scored
        coords    = torch.cat([x.T.reshape(1, -1), y.T.reshape(1, -1)]).T
        coordsSqr = torch.sum(coords**2, dim=1)
        C         = coordsSqr[:, None] + coordsSqr[None, :] - 2 * coords @ coords.T
        C[C < 0]  = 0  # clip tiny negative values caused by round-off
        self.C    = torch.sqrt(C).to(device)
        self.C    = self.C / self.C.max()

        # Sinkhorn regularization parameter
        # (larger values lead to faster but less accurate computations)
        self.epsilon = 1.0

    def forward(self, P_batch, Q_batch):
        # Accumulate the per-sample Sinkhorn result over the batch
        loss = 0
        for i in range(self.bs):
            loss = loss + ot.bregman.sinkhorn(P_batch[i].view(-1, 1), Q_batch[i].view(-1, 1), self.C, self.epsilon)

        return loss
```

Unfortunately, I get the following error when I call backward() on this loss:

```
in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    195     retain_graph = create_graph
    197 # The reason we repeat same the comment below is that
    198 # some Python versions print out the first line of a multi-line function
    199 # calls in the traceback and some print out the last line
--> 200 Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    201     tensors, grad_tensors_, retain_graph, create_graph, inputs,
    202     allow_unreachable=True, accumulate_grad=True)

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
```
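This RuntimeError is raised by PyTorch whenever backward() is called on a tensor that is not connected to the autograd graph, i.e. none of the tensors it was computed from required grad (or it was produced under torch.no_grad() or after .detach()). A minimal reproduction in plain PyTorch, independent of POT:

```python
import torch

# requires_grad defaults to False, so the loss has no grad_fn
x = torch.ones(3)
loss = (x ** 2).sum()
print(loss.requires_grad, loss.grad_fn)  # False None

message = ""
try:
    loss.backward()
except RuntimeError as err:
    message = str(err)  # same error as in the traceback
print(message)

# When the input requires grad, the same computation is differentiable
w = torch.ones(3, requires_grad=True)
(w ** 2).sum().backward()
print(w.grad)  # tensor([2., 2., 2.])
```

So when the OT loss ends up with requires_grad == False, the thing to inspect is which input (or intermediate step) lost its connection to the graph.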

Basically, it seems as if ot.sinkhorn2 doesn't allow differentiation. Does anybody know how I can resolve this error?

Thanks in advance!
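Sinkhorn solvers compare probability vectors, so each image has to become a nonnegative histogram with unit mass before it is passed in; nothing in the forward method above enforces this. A minimal sketch of that preprocessing, plus a shorter construction of the same normalized grid cost via torch.cdist. image_to_histogram and grid_cost are hypothetical helper names, and clipping negatives and adding a small eps is only one possible normalization (a softmax over pixels is another).

```python
import torch


def image_to_histogram(img, eps=1e-12):
    """Hypothetical helper: flatten a (P, P) image into a length P*P histogram.

    Negative pixels are clipped to zero and a small eps keeps every bin
    positive; the result is rescaled to sum to 1.
    """
    h = img.reshape(-1).clamp(min=0) + eps
    return h / h.sum()


def grid_cost(pixels):
    """Hypothetical helper: normalized Euclidean cost between grid points."""
    lin = torch.linspace(0, 1, pixels)
    x, y = torch.meshgrid(lin, lin, indexing="ij")
    coords = torch.stack([x.reshape(-1), y.reshape(-1)], dim=1)  # (P*P, 2)
    C = torch.cdist(coords, coords)  # pairwise Euclidean distances
    return C / C.max()


img = torch.tensor([[1.0, -2.0], [3.0, 0.0]])
print(image_to_histogram(img))  # approximately [0.25, 0, 0.75, 0]
print(grid_cost(4).shape)       # torch.Size([16, 16])
```

The ordering of the flattened grid does not matter for the cost: swapping the roles of the x and y coordinates leaves every pairwise distance unchanged.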

Hello, ot.sinkhorn2 allows differentiation, but you use ot.sinkhorn in your example, so that might be the problem?

Thanks for the response @rflamary.

I checked, and unfortunately this is not the case.
I ran a sanity check to see whether ot.sinkhorn2 gives me a differentiable output.

Here I generate random data with the same dimensions as mine to test differentiability:

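```python
import ot
import torch

criterion = SinkhornDist(batch_size=1, pixels=32, device="cuda")  # loss module defined above

# Random values only test whether gradients flow; they are not valid histograms
P = torch.randn((1, 32, 32), requires_grad=True).cuda()
Q = torch.randn((1, 32, 32), requires_grad=False).cuda()
L = ot.sinkhorn2(P.view(-1, 1), Q.view(-1, 1), criterion.C, criterion.epsilon)

print(L.requires_grad)
```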
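A related autograd detail for this kind of check: L.requires_grad being True only means L is connected to the graph, while gradients are stored in .grad of leaf tensors only. Calling .cuda() (or any other copy) after requires_grad=True yields a non-leaf tensor, so its .grad stays None after backward(). The sketch below uses .double() as the copy so that it runs on CPU; the CUDA-friendly form would be torch.randn(..., device="cuda", requires_grad=True).

```python
import torch

# requires_grad is set on the original tensor, then a copy is made:
# P is NOT a leaf (.cuda() behaves the same way)
P = torch.randn(1, 4, 4, requires_grad=True).double()
# Creating the tensor directly with the final dtype/device keeps it a leaf
P_leaf = torch.randn(1, 4, 4, dtype=torch.float64, requires_grad=True)

for t in (P, P_leaf):
    (t ** 2).sum().backward()

print(P.requires_grad, P.is_leaf, P.grad)           # True False None
print(P_leaf.is_leaf, P_leaf.grad is not None)      # True True
```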

This prints True. With my real data, however, it does not:

```python
L = ot.sinkhorn2(data[0][0].view(-1, 1), data[1][0].view(-1, 1), criterion.C, criterion.epsilon)

print(L.requires_grad)
```

It prints False, even though, just as before, my input requires grad:

```python
print(data[0][0].requires_grad)  # True
print(data[1][0].requires_grad)  # False
```

What do you think causes this?

My bad: it seems that after backpropagation my model produces tensors with NaN elements, and this causes the Sinkhorn function to return a non-differentiable loss.
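Since the root cause turned out to be NaN values coming out of the model, a cheap guard is to check the inputs of the loss before calling Sinkhorn. A minimal sketch; assert_finite is a hypothetical helper name, not part of POT or PyTorch.

```python
import torch


def assert_finite(name, t):
    """Hypothetical helper: fail fast if a tensor contains NaN or Inf."""
    if not torch.isfinite(t).all():
        bad = int((~torch.isfinite(t)).sum().item())
        raise ValueError(f"{name} has {bad} non-finite values")
    return t


# Typical use inside a training step, before the OT loss is computed:
#     assert_finite("model output", P_batch)
#     loss = criterion(P_batch, Q_batch)
# torch.autograd.set_detect_anomaly(True) additionally makes backward() raise
# an error pointing at the forward operation whose gradient first became NaN
# (at a noticeable cost in speed, so it is meant for debugging only).
```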

Good that you found your bug (and that it was not in POT). Closing the issue.