PythonOT/POT

How does ot.dist work with sequences?


Hello,

I treat OT as a black box, so this may be a naive question.

I'm following the Wasserstein 2 Minibatch GAN with PyTorch example to train my own model, but I get an error. My inputs and outputs are sequences. Here is my code:

    import torch
    import torch.nn as nn
    import ot

    # model, phonetic, linguistic, transcript, output, batch_size and device
    # are defined earlier in my script
    ab = (torch.ones(batch_size) / batch_size).to(device)  # uniform sample weights
    sgd = torch.optim.Adam(model.parameters(), lr=0.001)
    CE_loss = nn.CrossEntropyLoss(ignore_index = 41)
    for epoch in range(1000):
      logits, c_emb, t_emb = model(phonetic, linguistic, transcript)
      # print(logits.shape) #batch x classes x time
      # print(c_emb.shape) #batch x time x features
      # print(t_emb.shape) #batch x time x features
      M = ot.dist(c_emb, t_emb)  # cost matrix between the two batches
      loss_W = ot.emd2(ab, ab, M).to(device)  # exact OT (Wasserstein) loss
      loss_CE = CE_loss(logits, output)
      loss = loss_W + loss_CE
      loss.backward()
      sgd.step()
      sgd.zero_grad()

The error:

        M = ot.dist(c_emb, t_emb)
      File "/opt/conda/lib/python3.8/site-packages/ot/utils.py", line 307, in dist
        return euclidean_distances(x1, x2, squared=True)
      File "/opt/conda/lib/python3.8/site-packages/ot/utils.py", line 253, in euclidean_distances
        a2 = nx.einsum('ij,ij->i', X, X)
      File "/opt/conda/lib/python3.8/site-packages/ot/backend.py", line 1897, in einsum
        return torch.einsum(subscripts, *operands)
      File "/opt/conda/lib/python3.8/site-packages/torch/functional.py", line 378, in einsum
        return _VF.einsum(equation, operands) # type: ignore[attr-defined]
    RuntimeError: einsum(): the number of subscripts in the equation (2) does not match the number of dimensions (3) for operand 0 and no ellipsis was given
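The failure can be reproduced with plain torch.einsum, without POT. The subscript string 'ij,ij->i' is copied from the traceback. It labels exactly two axes, so it computes one squared norm per row of a 2D matrix. It therefore rejects a 3D tensor shaped like c_emb. This is a minimal sketch in which random tensors stand in for the real embeddings.

```python
import torch

X2d = torch.randn(5, 3)     # (samples, features): the layout ot.dist expects
X3d = torch.randn(5, 7, 3)  # (batch, time, features): the layout of c_emb

# Row-wise squared norms, using the same subscripts as in the traceback.
a2 = torch.einsum('ij,ij->i', X2d, X2d)
print(a2.shape)  # torch.Size([5])

# The same equation on a 3D tensor raises the RuntimeError shown above.
rejected = False
try:
    torch.einsum('ij,ij->i', X3d, X3d)
except RuntimeError as err:
    rejected = True
    print("3D input rejected:", err)
```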

How can I use ot.dist with sequences correctly?

You seem to have a shape problem here: c_emb and t_emb are 3D (batch x time x features). On sequences, you should write your own distance function in full PyTorch so that you keep full backprop. The function should return an n x m matrix, where n and m are the number of sequences in c_emb and t_emb respectively. ot.dist works only between samples in vector format, i.e. 2D inputs of shape (n, d) and (m, d), following the scipy cdist API.
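Below is a minimal sketch of such a function. It assumes that all sequences in c_emb and t_emb have the same length and feature size. The name sequence_sq_dist is hypothetical and not part of POT. The choice of distance is also only an illustrative assumption: squared Euclidean distance between aligned time steps, averaged over time. Random tensors stand in for the real embeddings.

```python
import torch


def sequence_sq_dist(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Hypothetical pairwise distance between two batches of sequences.

    x: (n, time, features), y: (m, time, features).
    Returns an (n, m) matrix whose entry (i, j) is the squared Euclidean
    distance between x[i, t] and y[j, t], averaged over time steps t.
    """
    if x.dim() != 3 or y.dim() != 3:
        raise ValueError("expected 3D tensors (batch, time, features)")
    if x.shape[1:] != y.shape[1:]:
        raise ValueError("sequences must share time length and feature size")
    # Broadcast to (n, m, time, features), sum over features, average over time.
    diff = x[:, None, :, :] - y[None, :, :, :]
    return diff.pow(2).sum(dim=-1).mean(dim=-1)


torch.manual_seed(0)
c_emb = torch.randn(4, 10, 8, requires_grad=True)  # stand-in for the model output
t_emb = torch.randn(4, 10, 8)
M = sequence_sq_dist(c_emb, t_emb)
print(M.shape)  # torch.Size([4, 4])

M.sum().backward()  # pure PyTorch ops, so gradients reach c_emb
print(c_emb.grad.shape)  # torch.Size([4, 10, 8])
```

The resulting M can be passed to ot.emd2(ab, ab, M) as in the training loop of the question. For equal-length sequences, flattening each one and calling ot.dist(c_emb.flatten(1), t_emb.flatten(1)) should give the same matrix up to the 1/time factor, because the default squared Euclidean metric then sums over time and features. The broadcasting version materializes an n x m x time x features tensor, which may be costly for long sequences. Sequences of different lengths would need a different distance, which this sketch does not cover.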

I'm converting this to a discussion since it does not seem to be a bug in POT.