Code required to compute the JUMBOT loss
KevinMusgrave commented
I'm considering adding JUMBOT to pytorch-adapt.
Does this snippet of code contain all the logic for computing the JUMBOT loss?
Lines 87 to 108 in 913624e
kilianFatras commented
Hi Kevin,
Thank you for your interest. I would love to see JUMBOT in the toolbox. Yes, it does! That is basically all of it :)
KevinMusgrave commented
Thanks Kilian!
kilianFatras commented
One last thing: my code is based on Python OT (POT) 0.7 and works well with it. However, some changes made in POT 0.8 break the current code. The issue is that POT 0.8 uses Float32 tensors while this code uses Float64 tensors. The following change fixes the issue under POT 0.8:
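```python
import ot
# Cast the marginals and the cost matrix to float64 (double) before calling POT 0.8
pi = ot.unbalanced.sinkhorn_knopp_unbalanced(a.double(), b.double(),
                                             total_cost.double().detach().cpu().numpy(),
                                             self.epsilon, self.tau)
```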
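For readers without POT at hand, the sketch below re-implements in plain NumPy the KL-relaxed Sinkhorn-Knopp scaling that `ot.unbalanced.sinkhorn_knopp_unbalanced` performs, with `reg` playing the role of `self.epsilon` and `reg_m` the role of `self.tau`. It is an illustrative stand-in, not POT's code or the JUMBOT repository's code: the function name, the stopping rule, and the random toy cost matrix are assumptions. It casts every input to float64 up front so all operands share one dtype, then forms the transport-weighted cost `sum(pi * total_cost)`, which is assumed here to be how the plan enters the loss.

```python
import numpy as np


def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, num_iter=1000, tol=1e-9):
    """Hypothetical re-implementation of KL-relaxed (unbalanced) entropic OT.

    a: (n,) source weights, b: (m,) target weights, M: (n, m) cost matrix,
    reg: entropic regularization (epsilon), reg_m: marginal relaxation (tau).
    Illustration only; not POT's implementation.
    """
    # Cast everything to float64 so all operands share a single dtype.
    a, b, M = (np.asarray(x, dtype=np.float64) for x in (a, b, M))
    K = np.exp(-M / reg)          # Gibbs kernel
    fi = reg_m / (reg_m + reg)    # exponent induced by the KL marginal penalty
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(num_iter):
        u_prev = u
        u = (a / (K @ v)) ** fi       # rescale rows
        v = (b / (K.T @ u)) ** fi     # rescale columns
        if np.max(np.abs(u - u_prev)) < tol:
            break
    # Transport plan: diag(u) K diag(v)
    return u[:, None] * K * v[None, :]


# Toy minibatch: float32 inputs (assumed), 4 source and 5 target samples.
rng = np.random.default_rng(0)
total_cost = rng.random((4, 5)).astype(np.float32)
a = np.full(4, 1 / 4, dtype=np.float32)
b = np.full(5, 1 / 5, dtype=np.float32)

pi = sinkhorn_knopp_unbalanced(a, b, total_cost, reg=0.1, reg_m=1.0)
loss = np.sum(pi * total_cost)  # transport-weighted cost
print(pi.dtype, loss)
```

In the PyTorch version, the `.detach()` in the snippet suggests the plan is computed on a detached copy of the cost and treated as a constant weight; presumably the loss then multiplies that plan with the still-differentiable cost tensor, so gradients flow only through the cost.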