facebookresearch/msn

What is the use of "AllReduce"?

yyk-wew opened this issue · 0 comments

Hello. Thank you for your great work!

I have some questions about the `AllReduce` class defined here:

`msn/src/utils.py`, lines 226 to 241 in commit 4388dc1:

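```python
import torch
import torch.distributed as dist


class AllReduce(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        if (
            dist.is_available()
            and dist.is_initialized()
            and (dist.get_world_size() > 1)
        ):
            # Divide by world_size, then SUM-reduce: every rank ends up with the mean across ranks.
            x = x.contiguous() / dist.get_world_size()
            dist.all_reduce(x)
        return x

    @staticmethod
    def backward(ctx, grads):
        # The incoming gradient is passed through unchanged (no division by world_size).
        return grads
```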
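To make the forward semantics concrete, here is a minimal single-process simulation in plain Python (no `torch.distributed`; the function name `simulated_all_reduce_mean` is hypothetical). It assumes `dist.all_reduce` uses its default `SUM` reduction, so dividing by `world_size` before the reduction leaves every rank holding the element-wise mean of the per-rank tensors.

```python
from typing import List


def simulated_all_reduce_mean(per_rank: List[List[float]]) -> List[List[float]]:
    """Hypothetical stand-in for AllReduce.forward, simulated in one process.

    per_rank[r] is the tensor held by rank r. Each rank divides its tensor by
    world_size, then a SUM all-reduce leaves every rank with the element-wise
    sum of the scaled tensors, i.e. the element-wise mean across ranks.
    """
    world_size = len(per_rank)
    scaled = [[v / world_size for v in x] for x in per_rank]
    reduced = [sum(col) for col in zip(*scaled)]
    # After an all-reduce every rank holds an identical copy of the result.
    return [list(reduced) for _ in range(world_size)]


ranks = [[0.25, 0.75], [0.75, 0.25]]
out = simulated_all_reduce_mean(ranks)
print(out)  # [[0.5, 0.5], [0.5, 0.5]]
```

The simulation only models the values; it says nothing about gradients, which are handled separately by `AllReduce.backward`.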

It is used to aggregate the averaged `probs` across processes when computing the ME-MAX regularization term:

`msn/src/losses.py`, lines 70 to 72 in commit 4388dc1:

```python
if me_max:
    avg_probs = AllReduce.apply(torch.mean(probs, dim=0))
    rloss = - torch.sum(torch.log(avg_probs**(-avg_probs))) + math.log(float(len(avg_probs)))
```
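For reference, the `rloss` expression simplifies algebraically: since `log(p ** (-p)) = -p * log(p)`, it equals `sum(p * log(p)) + log(K)` with `K = len(avg_probs)`, which is the KL divergence from `avg_probs` to the uniform distribution over its `K` entries (equivalently, `log(K)` minus the entropy of `avg_probs`). The plain-Python sketch below uses floats instead of tensors and hypothetical function names; it only checks that equivalence numerically.

```python
import math
from typing import Sequence


def rloss_as_written(avg_probs: Sequence[float]) -> float:
    # Mirrors the expression in losses.py, using floats instead of tensors.
    return -sum(math.log(p ** (-p)) for p in avg_probs) + math.log(float(len(avg_probs)))


def kl_to_uniform(avg_probs: Sequence[float]) -> float:
    # KL(avg_probs || uniform) = sum(p * log(p * K)); entries with p == 0 contribute 0.
    k = len(avg_probs)
    return sum(p * math.log(p * k) for p in avg_probs if p > 0)


uniform = [0.25] * 4
peaked = [0.7, 0.1, 0.1, 0.1]
print(rloss_as_written(uniform), kl_to_uniform(uniform))  # both approximately 0
print(rloss_as_written(peaked), kl_to_uniform(peaked))    # equal up to rounding, and positive
```

So minimizing `rloss` pushes the averaged prediction toward uniform, which is why the average needs to be taken over the whole global batch rather than per process.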

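I wonder why `dist.all_reduce(x)` is not called directly. It seems that using `AllReduce` multiplies the gradient by a factor of `world_size`: the forward pass divides `x` by `world_size`, but the backward pass returns `grads` unchanged instead of dividing them by `world_size`.
Am I correct, and if so, why does this make sense?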
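One way to check the factor-of-`world_size` observation without launching multiple processes is a finite-difference sketch in plain Python. Assumptions, all hypothetical and for illustration only: two simulated ranks, a loss consisting solely of the me-max term (rewritten as `sum(p * log(p)) + log(K)`), and an upstream gradient at `AllReduce.backward` equal to the analytic `d rloss / d avg_probs = log(avg_probs) + 1`. Perturbing one rank's local mean while holding the other rank fixed gives the derivative of that rank's local forward step, `(log(avg) + 1) / world_size`; the identity backward returns `log(avg) + 1`, so the ratio should come out as `world_size`.

```python
import math
from typing import List


def rloss(avg: List[float]) -> float:
    # me-max term from losses.py, rewritten: -sum(log(p ** -p)) == sum(p * log(p)).
    return sum(p * math.log(p) for p in avg if p > 0) + math.log(len(avg))


def global_avg(per_rank: List[List[float]]) -> List[float]:
    # Value produced by AllReduce.forward on every rank (mean across ranks).
    w = len(per_rank)
    return [sum(col) / w for col in zip(*per_rank)]


def numeric_grad_wrt_rank(per_rank: List[List[float]], rank: int, eps: float = 1e-6) -> List[float]:
    """Central finite difference of rloss w.r.t. per_rank[rank], other ranks held fixed."""
    grads = []
    for i in range(len(per_rank[rank])):
        plus = [list(x) for x in per_rank]
        minus = [list(x) for x in per_rank]
        plus[rank][i] += eps
        minus[rank][i] -= eps
        grads.append((rloss(global_avg(plus)) - rloss(global_avg(minus))) / (2 * eps))
    return grads


def identity_backward_grad(avg: List[float]) -> List[float]:
    # d rloss / d avg_i = log(avg_i) + 1; AllReduce.backward returns this unchanged.
    return [math.log(p) + 1.0 for p in avg]


per_rank = [[0.7, 0.2, 0.1], [0.3, 0.3, 0.4]]
w = len(per_rank)
identity = identity_backward_grad(global_avg(per_rank))
numeric = numeric_grad_wrt_rank(per_rank, rank=0)
ratios = [a / b for a, b in zip(identity, numeric)]
print(ratios)  # each entry is close to world_size (2 here)
```

This only confirms the local factor; it does not by itself say whether that factor is intended once the identical loss computed on every rank and DDP's gradient averaging are taken into account.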

Thanks!