kilianFatras/JUMBOT

Code required to compute the JUMBOT loss


I'm considering adding JUMBOT to pytorch-adapt.

Does this snippet of code contain all the logic for computing the JUMBOT loss?

JUMBOT/digits/jumbot.py, lines 87 to 108 at commit 913624e:

### Ground cost
# Squared Euclidean distances between source and target embeddings
embed_cost = torch.cdist(g_xs_mb, g_xt_mb) ** 2
# Cross-entropy between one-hot source labels and target predictions
ys = F.one_hot(ys, num_classes=self.n_class).float()
t_cost = -torch.mm(ys, torch.transpose(torch.log(pred_xt), 0, 1))
total_cost = self.eta1 * embed_cost + self.eta2 * t_cost

# OT computation: unbalanced entropic plan between uniform marginals
a, b = ot.unif(g_xs_mb.size()[0]), ot.unif(g_xt_mb.size()[0])
pi = ot.unbalanced.sinkhorn_knopp_unbalanced(a, b, total_cost.detach().cpu().numpy(),
                                             self.epsilon, self.tau)
# To get DeepJDOT (https://arxiv.org/abs/1803.10081) comment the line above
# and uncomment the following line:
# pi = ot.emd(a, b, total_cost.detach().cpu().numpy())
pi = torch.from_numpy(pi).float().cuda()

# train the model
optimizer_g.zero_grad()
optimizer_f.zero_grad()
da_loss = torch.sum(pi * total_cost)
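
For context, here is how I'd guess the step is completed after da_loss. This is my own sketch, not code from the repo: pred_xs (softmax outputs on the source mini-batch) and the unweighted sum of the two losses are assumptions.

clf_loss = -torch.mean(torch.sum(ys * torch.log(pred_xs), dim=1))  # source CE; pred_xs is hypothetical
tot_loss = clf_loss + da_loss
tot_loss.backward()  # pi was built from detached costs, so gradients only flow through total_cost
optimizer_g.step()
optimizer_f.step()

Since pi carries no gradient, the network is updated with the transport plan held fixed, and the plan is recomputed on the next mini-batch.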

Hi Kevin,

Thank you for your interest; I would love to see JUMBOT in the toolbox. Yes, it does! That snippet is basically all of it :)

Thanks Kilian!

One last thing: my code is based on Python Optimal Transport (POT) 0.7 and works well there. However, some changes were made in POT 0.8 that break the current code. The issue is a dtype mismatch: POT 0.8 works with float32 tensors while this code uses float64 tensors. The following change fixes the issue under POT 0.8:

pi = ot.unbalanced.sinkhorn_knopp_unbalanced(a.double(), b.double(),
                                             total_cost.double().detach().cpu().numpy(),
                                             self.epsilon, self.tau)
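
A note for anyone applying this under POT 0.8: ot.unif returns NumPy float64 arrays, which have no .double() method, so in the repo's code as quoted above only the cost matrix should need casting. Here is a sketch of a self-contained helper I'd use; the function name and structure are mine, not from the repo:

import ot
import torch

def unbalanced_ot_plan(cost, epsilon, tau):
    """Hypothetical helper (not from the repo): compute the unbalanced
    entropic OT plan for a torch cost matrix, with dtypes that work
    under both POT 0.7 and POT 0.8."""
    n_s, n_t = cost.shape
    a, b = ot.unif(n_s), ot.unif(n_t)               # uniform weights, float64 NumPy
    cost_np = cost.double().detach().cpu().numpy()  # float64 NumPy copy of the cost
    pi = ot.unbalanced.sinkhorn_knopp_unbalanced(a, b, cost_np, epsilon, tau)
    # back to torch on the cost's device, in the float32 training dtype
    return torch.from_numpy(pi).float().to(cost.device)

With that, pi = unbalanced_ot_plan(total_cost, self.epsilon, self.tau) would replace the marginal and plan computation in the snippet above.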