NVIDIA/Megatron-LM

[BUG] Wrong argument type passed to torch.distributed.broadcast


Describe the bug

def broadcast_params(self):
    """
    Syncs parameters across all DP ranks.
    """
    for param in self.module.parameters():
        is_expert_parallel = not getattr(param, 'allreduce', True)

        if is_expert_parallel:
            torch.distributed.broadcast(
                param.data,
                src=torch.distributed.get_process_group_ranks(self.expert_data_parallel_group),
                group=self.expert_data_parallel_group,
            )
        else:
            torch.distributed.broadcast(
                param.data,
                src=torch.distributed.get_process_group_ranks(self.data_parallel_group),
                group=self.data_parallel_group,
            )

The src parameter of torch.distributed.broadcast must be an int: the global rank of the process that acts as the broadcast root. In both branches above, however, the value passed as src is the list of all global ranks in the (expert) data parallel group returned by torch.distributed.get_process_group_ranks, not a single rank.

The above code snippet is from the DistributedDataParallel class in megatron/core/distributed/distributed_data_parallel.py.
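
For context, a minimal standalone sketch of the type mismatch, using a single-process gloo group (the address/port are arbitrary placeholders for this example):

import torch
import torch.distributed as dist

# Single-process group, just to inspect the types involved.
dist.init_process_group(
    backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
)
group = dist.new_group(ranks=[0])

ranks = dist.get_process_group_ranks(group)
print(ranks)                     # [0] -- a list of global ranks, not an int
assert isinstance(ranks, list)

# broadcast expects a single global rank for src, so passing the list above
# (as broadcast_params does) is the wrong type; an int element works.
dist.broadcast(torch.zeros(1), src=ranks[0], group=group)

dist.destroy_process_group()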

To Reproduce
N/A

Expected behavior
N/A

Stack trace/logs
N/A

Environment (please complete the following information):

Proposed fix
Pass the global rank of group rank 0 of the corresponding data parallel group (a single int) as src; see the sketch below.
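
A minimal sketch of one possible fix (not necessarily the exact upstream change; the local variable data_parallel_group is introduced here only for brevity): resolve group rank 0 of the corresponding group to its global rank via torch.distributed.get_global_rank and pass that int as src.

def broadcast_params(self):
    """
    Syncs parameters across all DP ranks.
    """
    for param in self.module.parameters():
        is_expert_parallel = not getattr(param, 'allreduce', True)
        data_parallel_group = (
            self.expert_data_parallel_group
            if is_expert_parallel
            else self.data_parallel_group
        )
        torch.distributed.broadcast(
            param.data,
            # src must be a single int: the global rank of group rank 0.
            src=torch.distributed.get_global_rank(data_parallel_group, 0),
            group=data_parallel_group,
        )

Equivalently, torch.distributed.get_process_group_ranks(data_parallel_group)[0] should give the same rank, since that list is ordered by group rank.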

Additional context
N/A