gpauloski/kfac-pytorch

Use allreduce_coalesced for factor allreduce


From pytorch/pytorch#62140:

"grouped comm on a set of unflattened tensors can be more performant than flattening+a single flat nccl call."

We could also use all_gather_coalesced in place of the per-tensor gradient/inverse broadcasts.
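In that case each rank would contribute the blocks it computed locally and gather everyone else's in one grouped call. A minimal sketch, assuming every rank contributes the same number of tensors with matching shapes at each index; `allgather_inverses_coalesced` and `local_inverses` are hypothetical names:

```python
import torch
import torch.distributed as dist


def allgather_inverses_coalesced(
    local_inverses: list[torch.Tensor], group=None
) -> list[list[torch.Tensor]]:
    """Gather each rank's locally computed inverse blocks on every rank.

    Replaces per-tensor dist.broadcast() calls from each owning rank with a
    single grouped dist.all_gather_coalesced() call.
    """
    world_size = dist.get_world_size(group)
    # output[rank][i] receives the i-th tensor contributed by `rank`.
    # Assumes shapes match across ranks at each index, so empty_like on the
    # local tensors produces correctly sized receive buffers.
    output = [[torch.empty_like(t) for t in local_inverses] for _ in range(world_size)]
    dist.all_gather_coalesced(output, local_inverses, group=group)
    return output
```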