Calculate similarity for 3-dim input
dangvansam opened this issue
dangvansam commented
My inputs are:
anchors: torch.Size([8, 20, 128])
positives: torch.Size([8, 20, 128])
negatives: torch.Size([100, 8, 20, 128])
where 8 is the batch size, 20 is the number of pairs, 128 is the embedding dimension, and 100 is the number of negative samples (obtained by shuffling the positive samples).
So, can I calculate similarity for 3-D inputs with your code?
Thanks!
RElbers commented
Hi. This kind of input won't work, because the shapes don't match what the function expects.
The shapes should be: anchors: (N, D), positives: (N, D), negatives: (N, M, D), where N is the number of samples, M the number of negatives per sample, and D the embedding dimension.
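To confirm that tensors have these shapes before calling the loss, a small check like the following can help. This is a minimal sketch: the function name `check_paired_shapes` is hypothetical and is not part of the library.

```python
import torch


def check_paired_shapes(anchors: torch.Tensor,
                        positives: torch.Tensor,
                        negatives: torch.Tensor) -> None:
    """Hypothetical helper: raise ValueError unless the shapes are
    anchors (N, D), positives (N, D) and negatives (N, M, D)."""
    if anchors.dim() != 2 or positives.dim() != 2 or negatives.dim() != 3:
        raise ValueError("expected 2-D anchors/positives and 3-D negatives")
    n, d = anchors.shape
    if positives.shape != (n, d):
        raise ValueError(f"positives must be ({n}, {d}), got {tuple(positives.shape)}")
    if negatives.shape[0] != n or negatives.shape[2] != d:
        raise ValueError(f"negatives must be ({n}, M, {d}), got {tuple(negatives.shape)}")


# The tensors from the question fail the check as-is (anchors are 3-D):
try:
    check_paired_shapes(torch.randn(8, 20, 128),
                        torch.randn(8, 20, 128),
                        torch.randn(100, 8, 20, 128))
except ValueError as err:
    print(err)
```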
If you merge the dimensions of size 8 and 20 into one (N = 8 × 20 = 160), then it should work with InfoNCE(negative_mode='paired'), but I'm not sure that is what you need. Let me know if it works for you.
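import torch
from info_nce import InfoNCE  # from the info-nce-pytorch package

anchors = torch.randn(8, 20, 128)
positives = torch.randn(8, 20, 128)
negatives = torch.randn(100, 8, 20, 128)
loss = InfoNCE(negative_mode='paired')

# Merge the batch (8) and pair (20) dimensions: (8, 20, 128) -> (160, 128)
anchors = anchors.reshape(-1, 128)
positives = positives.reshape(-1, 128)
# (100, 8, 20, 128) -> (100, 160, 128) -> (160, 100, 128), i.e. (N, M, D)
negatives = negatives.reshape(100, -1, 128).transpose(0, 1)
output = loss(anchors, positives, negatives)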
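For reference, here is a minimal plain-PyTorch sketch of what an InfoNCE loss with paired negatives computes on the flattened tensors: each anchor is compared with its own positive and its own M negatives, and the positive is treated as the correct class. It assumes cosine similarity (L2-normalized embeddings), a temperature of 0.1 and mean reduction; these are assumptions for illustration, and `paired_info_nce` is a hypothetical name, not the library's own code.

```python
import torch
import torch.nn.functional as F


def paired_info_nce(anchors, positives, negatives, temperature=0.1):
    """Hypothetical sketch: anchors (N, D), positives (N, D), negatives (N, M, D)."""
    # L2-normalize so that dot products are cosine similarities
    anchors = F.normalize(anchors, dim=-1)
    positives = F.normalize(positives, dim=-1)
    negatives = F.normalize(negatives, dim=-1)

    # (N, 1): similarity of each anchor with its own positive
    pos_logits = (anchors * positives).sum(dim=-1, keepdim=True)
    # (N, M): similarity of each anchor with its own M negatives
    neg_logits = torch.einsum("nd,nmd->nm", anchors, negatives)

    # The positive sits in column 0, so the target class is 0 for every row
    logits = torch.cat([pos_logits, neg_logits], dim=1) / temperature
    labels = torch.zeros(logits.size(0), dtype=torch.long, device=logits.device)
    return F.cross_entropy(logits, labels)  # mean over the N rows


# Same reshaping as in the answer: merge the batch (8) and pair (20) dimensions
anchors = torch.randn(8, 20, 128).reshape(-1, 128)    # (160, 128)
positives = torch.randn(8, 20, 128).reshape(-1, 128)  # (160, 128)
negatives = torch.randn(100, 8, 20, 128).reshape(100, -1, 128).transpose(0, 1)  # (160, 100, 128)
print(paired_info_nce(anchors, positives, negatives))
```

Because each row of `negatives` only holds the negatives for the matching anchor, the (N, M, D) layout is what lets the 100 shuffled negatives stay paired with their original (batch, pair) position after flattening.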