facebookresearch/swav

Surprising but interesting duplicated clusters

Jiawei-Yang opened this issue · 1 comment

Hi, thanks for your brilliant work!

I found something interesting: many of the learned prototypes are duplicates of one another.

I started from the question "how well are the learned prototypes spread out?", so I downloaded the best pre-trained model from this link.

Interestingly, when I computed the pairwise cosine similarity among all prototypes, multiple pairs had a similarity of 1 and turned out to be the same vector.

Here is the code.

import torch
import torch.nn.functional as F


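model = torch.load('swav_800ep_pretrain.pth.tar', map_location='cpu')
# L2-normalize each row so the dot products below are true cosine similarities
protos = F.normalize(model['module.prototypes.weight'], dim=1)  # (3000, 128)
similarity = protos @ protos.T  # (3000, 3000) pairwise cosine similarity
# Subtract the identity to zero out each prototype's similarity with itself,
# leaving only the similarities between distinct prototypes.
non_diag = similarity - torch.eye(3000)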

# Take the first prototype as an example
values, indices = non_diag[0].sort()

>>> print(values[-10:])
tensor([0.5267, 0.5267, 0.5267, 0.6343, 0.6343, 0.6343, 0.6344, 1.0000, 1.0000,
        1.0000])
>>> print(indices[10:])
tensor([  91,  932, 1244,  ...,  937, 1819, 2363])
>>> print(protos[0])
tensor([-0.0873, -0.0289,  0.1113,  0.1079,  0.0845, -0.0683,  0.0359, -0.0891,
         0.1160,  0.0086,  0.0602, -0.0444, -0.0620, -0.0612, -0.1079, -0.0714,
        -0.1299,  0.0790, -0.0428,  0.0628,  0.0202,  0.0361,  0.0414,  0.1667,
        -0.1552, -0.0179,  0.1873,  0.1460,  0.1022,  0.0320,  0.0937, -0.0669,
        -0.0611, -0.0737,  0.0175, -0.1226, -0.1596, -0.0358, -0.1683,  0.0984,
        -0.0613, -0.1481, -0.0617,  0.0235, -0.0650,  0.0165,  0.0387, -0.0046,
         0.0452,  0.1192,  0.0245, -0.0401,  0.0379, -0.1750,  0.0459, -0.0631,
        -0.0551, -0.1220, -0.0352,  0.0178,  0.1306, -0.1511,  0.0077,  0.0274,
         0.0032, -0.0396, -0.1273,  0.0903, -0.1012,  0.0024,  0.1248,  0.1845,
         0.0089, -0.0679, -0.0545,  0.0511,  0.0709,  0.1274, -0.0679, -0.0050,
         0.0006, -0.1155,  0.1040,  0.0527, -0.1587,  0.1085,  0.0560,  0.0032,
        -0.0046, -0.0338, -0.0009, -0.0129, -0.0033,  0.0171,  0.0436, -0.0369,
         0.0274,  0.0577, -0.0145, -0.0775, -0.0514,  0.0060, -0.0040,  0.0495,
        -0.0725,  0.0900, -0.0259,  0.1121, -0.0870, -0.0796,  0.2087, -0.0425,
        -0.0596, -0.0466,  0.0146,  0.0791, -0.0211,  0.1539, -0.1551, -0.0358,
        -0.1216, -0.1700,  0.0100,  0.1275,  0.0419,  0.2577, -0.0983,  0.0249])
>>> print(protos[2363])
tensor([-0.0873, -0.0289,  0.1113,  0.1079,  0.0845, -0.0683,  0.0359, -0.0891,
         0.1160,  0.0086,  0.0602, -0.0444, -0.0620, -0.0612, -0.1079, -0.0714,
        -0.1299,  0.0790, -0.0428,  0.0628,  0.0202,  0.0361,  0.0414,  0.1667,
        -0.1552, -0.0179,  0.1873,  0.1460,  0.1022,  0.0320,  0.0937, -0.0669,
        -0.0611, -0.0737,  0.0175, -0.1226, -0.1596, -0.0358, -0.1683,  0.0984,
        -0.0613, -0.1481, -0.0617,  0.0235, -0.0650,  0.0165,  0.0387, -0.0046,
         0.0452,  0.1192,  0.0245, -0.0401,  0.0379, -0.1750,  0.0459, -0.0631,
        -0.0551, -0.1220, -0.0352,  0.0178,  0.1306, -0.1511,  0.0077,  0.0274,
         0.0032, -0.0396, -0.1273,  0.0902, -0.1012,  0.0024,  0.1248,  0.1845,
         0.0089, -0.0679, -0.0545,  0.0511,  0.0709,  0.1274, -0.0679, -0.0050,
         0.0006, -0.1155,  0.1040,  0.0527, -0.1587,  0.1085,  0.0560,  0.0032,
        -0.0046, -0.0338, -0.0009, -0.0129, -0.0033,  0.0171,  0.0436, -0.0369,
         0.0274,  0.0577, -0.0145, -0.0775, -0.0514,  0.0060, -0.0040,  0.0495,
        -0.0725,  0.0900, -0.0259,  0.1121, -0.0870, -0.0796,  0.2087, -0.0425,
        -0.0596, -0.0466,  0.0146,  0.0791, -0.0211,  0.1539, -0.1551, -0.0358,
        -0.1216, -0.1700,  0.0100,  0.1275,  0.0419,  0.2577, -0.0983,  0.0249])

So prototypes #0, #937, #1819, and #2363 are identical up to numerical precision.

Had you noticed this, and do you have any idea why it happens?
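
To see how widespread this is beyond the first prototype, the same check can be run over every pair at once. The following is only a minimal sketch: the helper name `find_duplicate_groups` and the 0.9999 threshold are assumptions chosen for illustration, and the demo uses a small synthetic tensor rather than the actual checkpoint.

```python
import torch
import torch.nn.functional as F


def find_duplicate_groups(protos: torch.Tensor, threshold: float = 0.9999):
    """Group row indices whose pairwise cosine similarity is >= threshold.

    Hypothetical helper: returns a list of ascending index lists,
    one per group that has at least two members.
    """
    normed = F.normalize(protos.float(), dim=1)
    similarity = normed @ normed.T
    # Keep the strict upper triangle so each pair (i, j) with i < j appears once.
    # The zeroed entries fall below any positive threshold.
    upper = torch.triu(similarity, diagonal=1)
    pairs = (upper >= threshold).nonzero().tolist()

    # Union-find merges matched pairs into groups.
    parent = list(range(protos.shape[0]))

    def find(i):
        while parent[i] != i:
            parent[i] = parent[parent[i]]  # path halving
            i = parent[i]
        return i

    for i, j in pairs:
        ri, rj = find(i), find(j)
        if ri != rj:
            # The smaller index becomes the root of the merged group.
            parent[max(ri, rj)] = min(ri, rj)

    groups = {}
    for idx in range(protos.shape[0]):
        groups.setdefault(find(idx), []).append(idx)
    return [members for members in groups.values() if len(members) > 1]


if __name__ == "__main__":
    # Synthetic stand-in for the prototypes: rows 2 and 3 are scaled copies
    # of rows 0 and 1, so they have cosine similarity 1 with them.
    demo = torch.tensor([[1.0, 0.0], [0.0, 1.0], [2.0, 0.0], [0.0, 3.0]])
    print(find_duplicate_groups(demo))
```

On the real checkpoint, passing the `(3000, 128)` prototype matrix to this function would list every group of duplicated prototypes, and the number of groups and their sizes would show how many distinct prototypes remain.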

Best,
Jiawei

Hi @Jiawei-Yang, thanks for your kind words and for sharing this finding! I had not noticed that. I think it might explain why using more clusters does not have a big impact on performance. Feel free to post more if you have other findings or analysis :).