NVIDIA/Megatron-LM

[BUG] The gradient allreduce/reduce-scatter operation is performed twice when overlap_grad_reduce is False


Describe the bug

```python
def finish_grad_sync(self):
    """
    Finishes grad sync (all-reduce or reduce-scatter) communication operation
    for this bucket.

    When overlap_grad_reduce is set to True, waits for asynchronous communication
    call to complete. When overlap_grad_reduce is set to False, makes synchronous call.
    """
    # If overlap_grad_reduce is False, start (and finish) synchronous communication call here.
    if not self.overlap_grad_reduce:
        self.start_grad_sync()
        return
    assert self.communication_handle is not None and self.communication_issued, (
        f'Communication call has not been issued for this bucket '
        f'({len(self.params_with_grad)}/{len(self.params)} params have grad available)'
    )
    self.communication_handle.wait()
```

The above code snippet is from the Bucket class in megatron/core/distributed/param_and_grad_buffer.py.

The Bucket class performs the allreduce/reduce-scatter communication for its gradients through two functions: start_grad_sync initiates the communication operation, and finish_grad_sync waits for it to complete. However, in synchronous mode (overlap_grad_reduce is False), finish_grad_sync calls start_grad_sync again, which results in two allreduce/reduce-scatter operations instead of one.
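To make the concern concrete, here is a minimal pure-Python sketch of the two-phase start/finish pattern. `ToyBucket`, `FakeHandle`, and `comm_calls` are hypothetical stand-ins, not Megatron-LM code; the sketch only counts how many times the collective would be launched. It shows that a double launch happens only if some caller invokes `start_grad_sync` before `finish_grad_sync` while `overlap_grad_reduce` is False.

```python
class FakeHandle:
    """Hypothetical stand-in for the async work handle returned by a collective."""

    def __init__(self):
        self.waited = False

    def wait(self):
        self.waited = True


class ToyBucket:
    """Hypothetical model of the start/finish split; counts collective launches."""

    def __init__(self, overlap_grad_reduce):
        self.overlap_grad_reduce = overlap_grad_reduce
        self.communication_handle = None
        self.communication_issued = False
        self.comm_calls = 0  # number of all-reduce/reduce-scatter launches

    def start_grad_sync(self):
        # Toy launch: returns a handle when overlapping, runs "synchronously" otherwise.
        self.comm_calls += 1
        self.communication_handle = FakeHandle() if self.overlap_grad_reduce else None
        self.communication_issued = True

    def finish_grad_sync(self):
        # Mirrors the quoted control flow: in sync mode, start (and finish) the call here.
        if not self.overlap_grad_reduce:
            self.start_grad_sync()
            return
        assert self.communication_handle is not None and self.communication_issued
        self.communication_handle.wait()


# The scenario the report worries about: an explicit start in sync mode, then finish.
bucket = ToyBucket(overlap_grad_reduce=False)
bucket.start_grad_sync()
bucket.finish_grad_sync()
print(bucket.comm_calls)  # 2
```

Whether this double launch can actually occur depends on who else calls `start_grad_sync` in synchronous mode.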

To Reproduce
N/A

Expected behavior
N/A

Stack trace/logs
N/A

Environment (please complete the following information):

Proposed fix
N/A

Additional context
N/A

Hello, I don't believe this is an issue. Without overlap_grad_reduce, the backward hook does not call start_grad_sync, so start_grad_sync is only called once (from finish_grad_sync).
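The reply's reasoning can be sketched the same way. Assuming, as the reply states, that the backward hook calls `start_grad_sync` only when `overlap_grad_reduce` is True, the collective is launched exactly once in either mode. `ToyBucket`, `backward_hook`, and `run_step` are hypothetical names for this illustration, not Megatron-LM's actual API.

```python
class ToyBucket:
    """Hypothetical bucket that counts all-reduce/reduce-scatter launches."""

    def __init__(self, overlap_grad_reduce):
        self.overlap_grad_reduce = overlap_grad_reduce
        self.communication_issued = False
        self.comm_calls = 0

    def start_grad_sync(self):
        # Stand-in for launching the collective.
        self.comm_calls += 1
        self.communication_issued = True

    def finish_grad_sync(self):
        # Same control flow as the quoted method.
        if not self.overlap_grad_reduce:
            self.start_grad_sync()
            return
        assert self.communication_issued, "Communication call has not been issued"
        # A real implementation would wait on the async handle here.


def backward_hook(bucket):
    # Hypothetical hook: only starts communication early when overlapping.
    if bucket.overlap_grad_reduce:
        bucket.start_grad_sync()


def run_step(overlap):
    bucket = ToyBucket(overlap_grad_reduce=overlap)
    backward_hook(bucket)      # assumed to fire once the bucket's grads are ready
    bucket.finish_grad_sync()  # assumed to be called after the backward pass
    return bucket.comm_calls


print(run_step(overlap=True), run_step(overlap=False))  # 1 1
```

Because the early start is gated on `overlap_grad_reduce`, the unconditional `start_grad_sync` call inside `finish_grad_sync` is the only launch in synchronous mode.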

@deepakn94 Thanks for your reply; you are right. The backward hook only calls start_grad_sync when grad reduce is overlapped.