Opened this issue 3 years ago · 1 comment
pytorch/pytorch#62140
"grouped comm on a set of unflattened tensors can be more performant than flattening+a single flat nccl call."
Could also use allgather_coalesced instead of the gradient/inverse broadcast.