facebookresearch/dino

normalize knn probabilities

clemsgrs opened this issue

In the following part of the code, the class scores computed from the retrieved neighbours are not normalised, so they are not really probabilities:

dino/eval_knn.py, lines 166 to 172 (commit cb71140):

probs = torch.sum(
    torch.mul(
        retrieval_one_hot.view(batch_size, -1, num_classes),
        distances_transform.view(batch_size, -1, 1),
    ),
    1,
)
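To see the issue concretely, here is a minimal, self-contained sketch that reproduces the weighted vote above on dummy data. The sizes, the random similarities and labels, and the temperature `T = 0.07` are illustrative assumptions. The one-hot and exponential-weighting steps are modelled on the surrounding `eval_knn.py` logic, not copied from it.

```python
import torch

# Hypothetical sizes, chosen only for illustration.
batch_size, k, num_classes = 2, 4, 3
T = 0.07  # assumed temperature for the exponential weighting

torch.manual_seed(0)
# Dummy similarities between each test feature and its k nearest training features.
distances = torch.rand(batch_size, k)
# Dummy labels of the k retrieved neighbours.
retrieved_labels = torch.randint(0, num_classes, (batch_size, k))

# One-hot encode the neighbour labels: shape (batch_size * k, num_classes).
retrieval_one_hot = torch.zeros(batch_size * k, num_classes)
retrieval_one_hot.scatter_(1, retrieved_labels.view(-1, 1), 1)

# Temperature-scaled exponential weight for each neighbour (always > 0).
distances_transform = distances.div(T).exp()

# Weighted vote: sum the weighted one-hot vectors over the k neighbours.
probs = torch.sum(
    torch.mul(
        retrieval_one_hot.view(batch_size, -1, num_classes),
        distances_transform.view(batch_size, -1, 1),
    ),
    1,
)

# Each row sums to the total neighbour weight, not to 1.
print(probs.sum(dim=-1))
```

Each row of `probs` sums to the total weight of that sample's k neighbours, which is why the values cannot be read directly as probabilities.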

I'd suggest adding a quick normalising step right after it, so that they are actual probabilities (each value between 0 and 1, and each row summing to 1):

probs = probs / probs.sum(dim=-1).unsqueeze(-1)

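It's more of a "nice to have" than a necessary change: dividing each row by its (positive) sum does not change the ranking of the classes, so the predicted labels would be unaffected.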
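A quick way to check both properties of the suggested normalisation is sketched below: rows sum to 1, and sorting the classes by score gives the same order before and after. The score tensor is dummy data (assumed strictly positive, as the exponential weights above are), and sorting with `sort(1, True)` is used here as an assumed stand-in for however predictions are ranked downstream.

```python
import torch

torch.manual_seed(0)
# Dummy un-normalised vote scores: (batch_size=2, num_classes=5), strictly positive.
probs = torch.rand(2, 5) + 0.1

# The normalisation step suggested in this issue.
normalised = probs / probs.sum(dim=-1).unsqueeze(-1)

# Each row now sums to 1, and every entry lies in [0, 1].
print(normalised.sum(dim=-1))

# Dividing a row by a positive constant preserves the order of its entries,
# so the class ranking used for top-k predictions is unchanged.
_, pred_before = probs.sort(1, True)
_, pred_after = normalised.sort(1, True)
print(torch.equal(pred_before, pred_after))
```

`probs.sum(dim=-1, keepdim=True)` is an equivalent way to get the same broadcast without the explicit `unsqueeze`. The division is only safe because every row sum is strictly positive.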