facebookresearch/swav

Question on KMeans clustering in deepclusterv2

cheneeheng opened this issue · 2 comments

Hi there,

swav/main_deepclusterv2.py

Lines 362 to 364 in 5e073db

# E step
dot_products = torch.mm(local_memory_embeddings[j], centroids.t())
_, local_assignments = dot_products.max(dim=1)

The dot product is computed to obtain the cosine similarity between the embeddings and the centroids on the unit sphere.
I am wondering why .max() is used instead of .min(). Or did I misunderstand something?

Thanks!

DC95 commented

Hi @cheneeheng and

I thought I might try answering:

The shape of dot_products is [number of data points, number of clusters].
.max(dim=1) returns two tensors, each of shape [number of data points]:
_ holds, for each data point, its largest dot product, i.e. its similarity to the cluster it is assigned to.
local_assignments holds the index of that cluster.

So max() is used to pick, for each data point, the cluster with the largest value across all clusters.
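The shapes described above can be checked with a small standalone sketch. It does not reuse the repository's data: the sizes (num_points, num_clusters, dim) and the random, L2-normalized embeddings and centroids are hypothetical stand-ins for local_memory_embeddings[j] and centroids. The last lines also compare the max-dot-product assignment with the nearest centroid by Euclidean distance, which is the point the thread settles on below.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical sizes, for illustration only
num_points, num_clusters, dim = 5, 3, 4

# Assumption: embeddings and centroids are L2-normalized (unit sphere),
# so each dot product is a cosine similarity in [-1, 1]
embeddings = F.normalize(torch.randn(num_points, dim), dim=1)
centroids = F.normalize(torch.randn(num_clusters, dim), dim=1)

# E step: similarity of every point to every centroid
dot_products = torch.mm(embeddings, centroids.t())
print(dot_products.shape)  # torch.Size([5, 3])

# max over the cluster dimension returns (values, indices)
best_sim, assignments = dot_products.max(dim=1)
print(best_sim.shape, assignments.shape)  # torch.Size([5]) torch.Size([5])

# On the unit sphere, the most similar centroid is also the closest one
distances = torch.cdist(embeddings, centroids)
print(torch.equal(assignments, distances.argmin(dim=1)))
```

In other words, .max() over similarities plays the role that .min() over distances plays in ordinary k-means.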

@mathildecaron31, kindly correct me if I am wrong.

cheneeheng commented

Hi there,

I finally had some time to read up on spherical clustering today, and I believe I had misunderstood cosine similarity.
For some reason, I kept picturing the dot product as the angle between two vectors, which is not the case.
For unit vectors, the dot product is the cosine of the angle between them, so two unit vectors pointing in (nearly) the same direction have a dot product of (nearly) 1.
Since 1 is the maximum value of the cosine, each embedding should be assigned to the centroid with the largest dot product, so the code is correct.
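The same conclusion can be written as a short derivation. Assuming both an embedding x and every centroid c_k are L2-normalized (the unit-sphere premise of the question), the squared Euclidean distance is a decreasing function of the dot product, so maximizing one is the same as minimizing the other:

```latex
\|\mathbf{x}\| = \|\mathbf{c}_k\| = 1
\;\Rightarrow\;
\mathbf{x}^\top \mathbf{c}_k = \cos\theta_k ,
\qquad
\|\mathbf{x} - \mathbf{c}_k\|^2
  = \|\mathbf{x}\|^2 + \|\mathbf{c}_k\|^2 - 2\,\mathbf{x}^\top \mathbf{c}_k
  = 2 - 2\cos\theta_k

\arg\max_k \; \mathbf{x}^\top \mathbf{c}_k \;=\; \arg\min_k \; \|\mathbf{x} - \mathbf{c}_k\|^2
```

Taking .max() over the dot products therefore selects the same cluster that .min() over Euclidean distances would.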

For anyone with the same question, here are the sources I was reading:
Efficient online spherical k-means clustering
Spherical k-Means Clustering

@DC95, thanks for having a look.

Thanks!

Closing the issue.