Defining a Sinkhorn Based Loss in Torch
Closed this issue · 4 comments
Hi all,
I am trying to define a loss in torch where I can compare an input and a target batch of PxP images. I want to treat these PxP images as 2D distributions and define an optimal transport (OT) loss over them. Below is a snapshot of my attempt:
import torch
import torch.nn as nn
import ot

class SinkhornDist(nn.Module):
    def __init__(self, batch_size, pixels, device):
        super(SinkhornDist, self).__init__()
        self.bs = batch_size
        self.device = device
        # Define the cost matrix (Euclidean distance) between the grid points
        x, y = torch.meshgrid(torch.linspace(0, 1, pixels), torch.linspace(0, 1, pixels))
        # Compute pairwise distances between points on the 2D grid so we know
        # how to score the Wasserstein distance
        coords = torch.cat([x.T.reshape(1, -1), y.T.reshape(1, -1)]).T
        coordsSqr = torch.sum(coords**2, dim=1)
        C = coordsSqr[:, None] + coordsSqr[None, :] - 2 * coords @ coords.T
        C[C < 0] = 0
        self.C = torch.sqrt(C).to(device)
        self.C = self.C / self.C.max()
        # Define the Sinkhorn regularization parameter
        # (larger values lead to faster but less accurate computations)
        self.epsilon = 1.0

    def forward(self, P_batch, Q_batch):
        for i in range(self.bs):
            if i == 0:
                loss = ot.bregman.sinkhorn(P_batch[i].view(-1, 1), Q_batch[i].view(-1, 1), self.C, self.epsilon)
            else:
                loss += ot.bregman.sinkhorn(P_batch[i].view(-1, 1), Q_batch[i].view(-1, 1), self.C, self.epsilon)
        return loss
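For context, here is roughly how I invoke the loss (names like model, inputs, and targets are placeholders standing in for my actual pipeline):

criterion = SinkhornDist(batch_size=32, pixels=32, device='cuda')
outputs = model(inputs)             # batch of PxP images produced by the model
loss = criterion(outputs, targets)  # compare against the target batch
loss.backward()                     # the error below is raised here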
Unfortunately, I get the following error when I run this loss:
in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
195 retain_graph = create_graph
197 # The reason we repeat same the comment below is that
198 # some Python versions print out the first line of a multi-line function
199 # calls in the traceback and some print out the last line
--> 200 Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
201 tensors, grad_tensors_, retain_graph, create_graph, inputs,
202 allow_unreachable=True, accumulate_grad=True)
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
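(For reference, this is the generic error autograd raises whenever .backward() reaches a tensor with no graph attached; a minimal standalone repro, independent of POT:)

import torch

loss = torch.tensor(1.0)  # no requires_grad, no grad_fn
loss.backward()           # RuntimeError: element 0 of tensors does not require grad ...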
Basically, it seems as if ot.sinkhorn2 doesn't allow differentiation. Does anybody know how I can resolve this error?
Thanks in advance!
Hello, ot.sinkhorn2 allows differentiation, but you use ot.sinkhorn in your example, so that might be the problem?
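To illustrate the difference with a minimal sketch (toy uniform histograms and a random cost matrix, assuming a recent POT version with the torch backend): ot.sinkhorn returns the full transport plan, while ot.sinkhorn2 returns the loss value, which is what you want to backpropagate through.

import torch
import ot

a = torch.ones(4, requires_grad=True) / 4  # source histogram
b = torch.ones(4) / 4                      # target histogram
M = torch.rand(4, 4)                       # toy cost matrix

G = ot.sinkhorn(a, b, M, reg=1.0)   # transport plan, shape (4, 4)
w = ot.sinkhorn2(a, b, M, reg=1.0)  # scalar loss, differentiable w.r.t. a
print(w.requires_grad)              # True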
Thanks for the response @rflamary.
I checked, and unfortunately this is not the case.
I ran a sanity check to see whether ot.sinkhorn2 gives me a differentiable object. Here I generate random data with the same dimensions as mine to test differentiability:
P = torch.randn((1, 32, 32), requires_grad=True).cuda()   # stands in for the model output
Q = torch.randn((1, 32, 32), requires_grad=False).cuda()  # stands in for the target batch
L = ot.sinkhorn2(P.view(-1, 1), Q.view(-1, 1), criterion.C, criterion.epsilon)
print(L.requires_grad)
I get a True answer. Now, with my data I don't get a True answer:
L = ot.sinkhorn2(data[0][0].view(-1, 1), data[1][0].view(-1, 1), criterion.C, criterion.epsilon)
print(L.requires_grad)
The answer is False, even though, just like before, my data requires grad:
print(data[0][0].requires_grad) # True
print(data[1][0].requires_grad) # False
What do you think causes this issue?
My bad, it seems that after backpropagation my model produces tensors with NaN elements, and this causes the sinkhorn function to return a non-differentiable loss.
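In case anyone else runs into this, a simple guard caught it for me (a minimal sketch; outputs stands in for the model output):

# Check for NaNs before computing the Sinkhorn loss
assert not torch.isnan(outputs).any(), "model produced NaN values"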
Good that you found your bug (and that it was not in POT). Closing the issue.