RElbers/info-nce-pytorch

I have a question about this part of the code


```python
if negative_keys is not None:
    # Explicit negative keys

    # Cosine between positive pairs
    positive_logit = torch.sum(query * positive_key, dim=1, keepdim=True)

    # Cosine between all query-negative combinations
    negative_logits = query @ transpose(negative_keys)

    # First index in last dimension are the positive samples
    logits = torch.cat([positive_logit, negative_logits], dim=1)
    labels = torch.zeros(len(logits), dtype=torch.long, device=query.device)
```

1) Why are all the labels zero? Shouldn't the positive sample pairs be labeled 1?
2) Is this cosine similarity? Isn't it just an inner product?

Hi.

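  1. The positive sample is at index 0 of each row of `logits`, so `labels` is just a list of 0s. The labels are class indices for the cross-entropy loss (they say which column holds the positive), not binary 0/1 targets for each pair.
  2. The vectors are first normalized and then the dot product is taken, which gives the cosine of the angle between them, i.e. the cosine similarity.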
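A minimal sketch of both answers, assuming random 2-D embeddings (the batch size, number of negatives, and embedding dimension are arbitrary illustrative choices, and `F.normalize` stands in for whatever normalization step precedes the quoted code). It checks that the dot product of normalized vectors matches `F.cosine_similarity`, and that `F.cross_entropy` reads `labels` as class indices, so a label of 0 means "the correct column is column 0".

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, num_neg, dim = 4, 6, 8  # illustrative sizes

# Random embeddings, L2-normalized along the feature dimension
query = F.normalize(torch.randn(batch, dim), dim=-1)
positive_key = F.normalize(torch.randn(batch, dim), dim=-1)
negative_keys = F.normalize(torch.randn(num_neg, dim), dim=-1)

# For unit vectors, the dot product is the cosine similarity
positive_logit = torch.sum(query * positive_key, dim=1, keepdim=True)  # (batch, 1)
cosine = F.cosine_similarity(query, positive_key, dim=1)               # (batch,)

negative_logits = query @ negative_keys.transpose(-2, -1)  # (batch, num_neg)

# Column 0 holds the positive, so the target class for every row is 0
logits = torch.cat([positive_logit, negative_logits], dim=1)  # (batch, 1 + num_neg)
labels = torch.zeros(len(logits), dtype=torch.long)

loss = F.cross_entropy(logits, labels)

# With class-index targets, cross-entropy is -log softmax at the target column
manual = -torch.log_softmax(logits, dim=1)[:, 0].mean()
print(loss.item(), manual.item())
```

Putting the positive in a fixed column is just a convenient layout: any column would work as long as `labels` points at it.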

In theory that should be possible. You just need a measure that gives low values for positive pairs and high values for negative pairs.

Sorry, what I said in my previous comment was wrong. We want high values (similarity) for positive pairs and low values for negative pairs. To optimize this, we can simply use categorical cross-entropy.
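A small sketch of why categorical cross-entropy does this, assuming hand-picked similarity rows of the form `[positive, negative_1, negative_2]`; the helper `info_nce_from_logits` is hypothetical and not part of the repository. The loss is lower when the positive column is the most similar, and its gradient is `softmax(logits) - one_hot(0)`, which is negative for the positive column and positive for every negative column, so gradient descent raises the positive similarity and lowers the negative ones.

```python
import torch
import torch.nn.functional as F

def info_nce_from_logits(logits: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: cross-entropy where column 0 of each row is the positive pair."""
    labels = torch.zeros(logits.shape[0], dtype=torch.long)
    return F.cross_entropy(logits, labels)

# Hand-picked similarity rows: [positive, negative_1, negative_2]
good = torch.tensor([[0.9, -0.2, 0.1]])  # positive pair is the most similar
bad = torch.tensor([[-0.2, 0.9, 0.1]])   # a negative pair is the most similar

loss_good = info_nce_from_logits(good)
loss_bad = info_nce_from_logits(bad)

# Gradient of the loss with respect to the similarities
logits = torch.tensor([[0.5, 0.3, 0.1]], requires_grad=True)
info_nce_from_logits(logits).backward()
grad = logits.grad  # softmax(logits) - one_hot(0), divided by batch size (1 here)
print(loss_good.item(), loss_bad.item(), grad)
```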