facebookresearch/swav

Why matrix multiplication when using queue?

karray opened this issue · 0 comments

I'm trying to get my head around the multiplication of the embeddings by model.module.prototypes.weight.t() before passing them into distributed_sinkhorn. It is not clear why we do this multiplication only when using a queue (Line 298). Moreover, why do we mix the output prototype scores with the multiplied queue embeddings (Line 301)? What's the motivation behind this?

swav/main_swav.py

Lines 294 to 307 in 06b1b7c

# time to use the queue
if queue is not None:
    if use_the_queue or not torch.all(queue[i, -1, :] == 0):
        use_the_queue = True
        out = torch.cat((torch.mm(
            queue[i],
            model.module.prototypes.weight.t()
        ), out))
    # fill the queue
    queue[i, bs:] = queue[i, :-bs].clone()
    queue[i, :bs] = embedding[crop_id * bs: (crop_id + 1) * bs]

# get assignments
q = distributed_sinkhorn(out)[-bs:]
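For context, here is a minimal shape sketch of the mechanics being asked about (the dimension values are hypothetical, not the repo's defaults): the queue stores raw embeddings of dimension D, while out already holds prototype scores of dimension K, so the queued embeddings have to be projected onto the prototypes before the two tensors can be concatenated for Sinkhorn.

```python
import torch

# Hypothetical dimensions (not the repo's actual hyperparameters):
# D = embedding dim, K = number of prototypes, Q = queue length, B = batch size
D, K, Q, B = 128, 3000, 3840, 64

queue_i = torch.randn(Q, D)      # queue[i]: stored raw embeddings
prototypes = torch.randn(K, D)   # prototypes.weight: shape (K, D), as in nn.Linear
out = torch.randn(B, K)          # current batch's prototype scores

# Map queued embeddings into prototype-score space so shapes match `out`
queue_scores = torch.mm(queue_i, prototypes.t())  # (Q, D) @ (D, K) -> (Q, K)
combined = torch.cat((queue_scores, out))         # (Q + B, K)

# Sinkhorn would run over queue + batch together; the snippet keeps only
# the last B rows (the current batch) via [-bs:].
current_batch_rows = combined[-B:]
assert current_batch_rows.shape == (B, K)
```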