[BUG] The gradient allreduce/reduce-scatter operation is performed twice when overlap_grad_reduce is False
Describe the bug
```python
def finish_grad_sync(self):
    """
    Finishes grad sync (all-reduce or reduce-scatter) communication operation
    for this bucket.

    When overlap_grad_reduce is set to True, waits for asynchronous communication
    call to complete. When overlap_grad_reduce is set to False, makes synchronous call.
    """
    # If overlap_grad_reduce is False, start (and finish) synchronous communication call here.
    if not self.overlap_grad_reduce:
        self.start_grad_sync()
        return
    assert self.communication_handle is not None and self.communication_issued, (
        f'Communication call has not been issued for this bucket '
        f'({len(self.params_with_grad)}/{len(self.params)} params have grad available)'
    )
    self.communication_handle.wait()
```
The above code snippet is from the Bucket class in megatron/core/distributed/param_and_grad_buffer.py. The Bucket class performs the allreduce/reduce-scatter communication for its gradients through the start_grad_sync and finish_grad_sync methods: start_grad_sync initiates the communication operation, and finish_grad_sync waits for it to complete. However, in synchronous communication mode (overlap_grad_reduce is False), finish_grad_sync calls start_grad_sync again, which appears to result in two allreduce/reduce-scatter communication operations, which is not the expected behavior.
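For context, here is a minimal sketch of what the start_grad_sync counterpart is assumed to look like (simplified and not the exact Megatron-LM implementation; attribute names such as grad_data and data_parallel_group are taken on assumption, and the real method also handles gradient scaling and the reduce-scatter path). The async_op flag tied to overlap_grad_reduce is the detail the discussion below hinges on:

```python
import torch

def start_grad_sync(self):
    """Issue the all-reduce for this bucket's gradient data (illustrative sketch)."""
    assert self.communication_handle is None and not self.communication_issued, (
        'Communication call has already been issued for this bucket'
    )
    # When overlap_grad_reduce is True, launch asynchronously and keep the handle
    # so finish_grad_sync can wait on it; otherwise this call blocks until done.
    self.communication_handle = torch.distributed.all_reduce(
        self.grad_data,
        group=self.data_parallel_group,
        async_op=self.overlap_grad_reduce,
    )
    self.communication_issued = True
```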
To Reproduce
N/A
Expected behavior
N/A
Stack trace/logs
N/A
Environment (please complete the following information):
- Megatron-LM commit ID: f3a3020
Proposed fix
N/A
Additional context
N/A
Hello, I don't believe this is an issue, since without overlapping grad reduce, we don't have the backward hook call start_grad_sync (and so start_grad_sync is only called once).
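In other words, the per-parameter backward hook only triggers the grad sync when overlap_grad_reduce is enabled. A minimal sketch of that control flow (illustrative only, under the assumption of a _make_param_hook helper and a bucket.register_grad_ready method; not the exact Megatron-LM code):

```python
def _make_param_hook(self, param, bucket):
    """Create a post-backward hook for param (illustrative sketch)."""

    def param_hook(*unused):
        # Accumulate this param's gradient into the bucketed grad buffer.
        param.main_grad.add_(param.grad.data)
        param.grad = None
        # Only the overlapping path kicks off communication from the hook;
        # in the synchronous path, finish_grad_sync later issues the single blocking call.
        if self.overlap_grad_reduce:
            bucket.register_grad_ready(param)  # calls start_grad_sync once all grads are ready

    return param_hook
```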
@deepakn94 Thanks for your reply, and you are right: the backward hook only calls start_grad_sync when grad reduce is overlapped, so the communication is issued only once.