How does ot.dist work with sequences?
Hello,
I consider OT a black box, so I may ask something stupid.
I'm following the "Wasserstein 2 Minibatch GAN with PyTorch" example to train my own model, but I get an error. My input and output are sequences. Here is my code:
ab = (torch.ones(batch_size) / batch_size).to(device)
sgd = torch.optim.Adam(model.parameters(), lr=0.001)
CE_loss = nn.CrossEntropyLoss(ignore_index = 41)
for epoch in range(1000):
    logits, c_emb, t_emb = model(phonetic, linguistic, transcript)
    # print(logits.shape)  # batch x classes x time
    # print(c_emb.shape)   # batch x time x features
    # print(t_emb.shape)   # batch x time x features
    M = ot.dist(c_emb, t_emb)
    loss_W = ot.emd2(ab, ab, M).to(device)
    loss_CE = CE_loss(logits, output)
    loss = loss_W + loss_CE
    loss.backward()
    sgd.step()
    sgd.zero_grad()
The error:
    M = ot.dist(c_emb, t_emb)
  File "/opt/conda/lib/python3.8/site-packages/ot/utils.py", line 307, in dist
    return euclidean_distances(x1, x2, squared=True)
  File "/opt/conda/lib/python3.8/site-packages/ot/utils.py", line 253, in euclidean_distances
    a2 = nx.einsum('ij,ij->i', X, X)
  File "/opt/conda/lib/python3.8/site-packages/ot/backend.py", line 1897, in einsum
    return torch.einsum(subscripts, *operands)
  File "/opt/conda/lib/python3.8/site-packages/torch/functional.py", line 378, in einsum
    return _VF.einsum(equation, operands) # type: ignore[attr-defined]
RuntimeError: einsum(): the number of subscripts in the equation (2) does not match the number of dimensions (3) for operand 0 and no ellipsis was given
How can I use ot.dist with sequences correctly?
You seem to have a size problem here: ot.dist expects 2D inputs (samples x features), but c_emb and t_emb are 3D (batch x time x features). For sequences, you should write your own distance function in pure PyTorch so that you get full backprop. The function should return an n x m matrix, where n and m are the numbers of sequences in c_emb and t_emb respectively. ot.dist works only between samples in vector format, following the SciPy cdist API.
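For illustration, here is a minimal sketch of such a distance function. It is only one possible choice, not something POT provides: the name sequence_sq_dist is hypothetical, and it assumes every sequence in c_emb and t_emb has the same number of time steps, so each sequence can be flattened into a single vector. For variable-length sequences you would need something else, such as pooling over time or an alignment-based cost.

```python
import torch


def sequence_sq_dist(x, y):
    """Pairwise squared Euclidean distances between flattened sequences.

    x: (n, time, features), y: (m, time, features) -> cost matrix (n, m).
    Hypothetical helper; assumes x and y share the same time and feature sizes.
    """
    x_flat = x.reshape(x.shape[0], -1)                 # (n, time * features)
    y_flat = y.reshape(y.shape[0], -1)                 # (m, time * features)
    x_sq = (x_flat ** 2).sum(dim=1, keepdim=True)      # (n, 1)
    y_sq = (y_flat ** 2).sum(dim=1, keepdim=True).T    # (1, m)
    # ||x - y||^2 = ||x||^2 + ||y||^2 - 2 <x, y>, clamped against round-off
    M = x_sq + y_sq - 2.0 * x_flat @ y_flat.T
    return M.clamp_min(0.0)


torch.manual_seed(0)
c_emb = torch.randn(4, 10, 8, requires_grad=True)  # batch x time x features
t_emb = torch.randn(4, 10, 8)

M = sequence_sq_dist(c_emb, t_emb)
print(M.shape)  # torch.Size([4, 4])

# M can then replace the ot.dist result in the training loop, e.g.:
# ab = torch.ones(4) / 4
# loss_W = ot.emd2(ab, ab, M)
```

Because every operation is a PyTorch tensor operation, gradients flow from the cost matrix back to the embeddings.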
I'm converting this to a discussion since it does not seem to be a bug in POT.