RElbers/info-nce-pytorch

Calculate similarity for 3-dim input


My inputs are:
anchors: torch.Size([8, 20, 128])
positives: torch.Size([8, 20, 128])
negatives: torch.Size([100, 8, 20, 128])

Here 8 is the batch size, 20 the number of pairs, 128 the embedding dimension, and 100 the number of negative samples (shuffled from the positive samples).

So, can I calculate similarity for these 3-D inputs with your code?
Thanks!

Hi. This kind of input won't work because the shapes don't match what the function expects.
The shapes should be: anchors: (N, D), positives: (N, D), negatives: (N, M, D), where N is the number of samples, M the number of negatives per sample, and D the embedding dimension.
If you merge the dimensions of size 8 and 20 into one (N = 8 * 20 = 160) and move the negatives dimension to the middle, it should work with InfoNCE(negative_mode='paired'), but I'm not sure that is what you need. Let me know if it works for you.

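    import torch
    from info_nce import InfoNCE

    # 8 = batch size, 20 = pairs, 128 = embedding dim, 100 = negatives
    anchors = torch.randn(8, 20, 128)
    positives = torch.randn(8, 20, 128)
    negatives = torch.randn(100, 8, 20, 128)

    loss = InfoNCE(negative_mode='paired')
    # Merge batch and pair dims: (8, 20, 128) -> (160, 128), i.e. (N, D)
    anchors = anchors.reshape(-1, 128)
    positives = positives.reshape(-1, 128)
    # (100, 8, 20, 128) -> (100, 160, 128) -> (160, 100, 128), i.e. (N, M, D)
    negatives = negatives.reshape(100, -1, 128).transpose(0, 1)
    output = loss(anchors, positives, negatives)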
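To answer the original question more directly: the same loss can, in principle, be computed on the 3-D tensors without reshaping, by batching the similarity over both leading dimensions with torch.einsum. This is a from-scratch sketch, not the package's code. The helpers paired_info_nce and info_nce_3d are hypothetical names, and the sketch assumes cosine similarity with a temperature of 0.1, which I believe matches InfoNCE's defaults but may differ in details. It checks that the 3-D version gives the same value as the reshape approach above.

```python
import torch
import torch.nn.functional as F


def paired_info_nce(query, positive, negatives, temperature=0.1):
    """Hypothetical reference (not from info_nce): query, positive (N, D); negatives (N, M, D)."""
    # Cosine similarity: L2-normalise every embedding first
    query = F.normalize(query, dim=-1)
    positive = F.normalize(positive, dim=-1)
    negatives = F.normalize(negatives, dim=-1)
    pos_logit = (query * positive).sum(dim=-1, keepdim=True)           # (N, 1)
    neg_logits = torch.einsum('nd,nmd->nm', query, negatives)          # (N, M)
    logits = torch.cat([pos_logit, neg_logits], dim=1) / temperature   # (N, 1 + M)
    labels = torch.zeros(logits.size(0), dtype=torch.long)             # positive is class 0
    return F.cross_entropy(logits, labels)


def info_nce_3d(anchors, positives, negatives, temperature=0.1):
    """Hypothetical 3-D variant: anchors, positives (B, P, D); negatives (M, B, P, D)."""
    anchors = F.normalize(anchors, dim=-1)
    positives = F.normalize(positives, dim=-1)
    negatives = F.normalize(negatives, dim=-1)
    pos_logit = (anchors * positives).sum(dim=-1, keepdim=True)        # (B, P, 1)
    neg_logits = torch.einsum('bpd,mbpd->bpm', anchors, negatives)     # (B, P, M)
    logits = torch.cat([pos_logit, neg_logits], dim=-1) / temperature  # (B, P, 1 + M)
    labels = torch.zeros(logits.shape[:-1], dtype=torch.long)          # (B, P), positive is class 0
    # cross_entropy expects the class dimension at index 1: (B, 1 + M, P)
    return F.cross_entropy(logits.permute(0, 2, 1), labels)


torch.manual_seed(0)
anchors = torch.randn(8, 20, 128)
positives = torch.randn(8, 20, 128)
negatives = torch.randn(100, 8, 20, 128)

# 3-D computation, no reshaping
loss_3d = info_nce_3d(anchors, positives, negatives)
# Flattened computation, same layout as the reshape example
loss_flat = paired_info_nce(
    anchors.reshape(-1, 128),
    positives.reshape(-1, 128),
    negatives.reshape(100, -1, 128).transpose(0, 1),
)
print(torch.allclose(loss_3d, loss_flat, atol=1e-5))
```

Both forms should agree because cross_entropy with the default 'mean' reduction averages over all 8 * 20 anchor positions either way; flattening simply lets the existing (N, D) / (N, M, D) interface handle the extra dimension.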