关于pp/tp/dp通信时间的推导
Rememberz opened this issue · 4 comments
T2:双向通信数据,包括了前向传播接受的input hidden_states/发送的output hidden_states,反向传播接收output hidden_states grad/input hidden_states grad,使用混合精度训练,每个数据占2Byte
T3:张量并行的通信数据量,参考Megatron的论文,由于使用重计算,每个block的Attn和MLP都有allreduce,一共通信6次,allreduce通信的单向数据量2(N-1)/N*2Byte,由于我们使用的带宽是双向带宽,所以还要乘2。
T4:直接是参数allreduce通信的数据量,2(Ds-1)/Ds * 2 Bytes,由于使用的双向带宽,计算使用双向通信数据量,需要再乘2
T2:双向通信数据,包括了前向传播接受的input hidden_states/发送的output hidden_states,反向传播接收output hidden_states grad/input hidden_states grad,使用混合精度训练,每个数据占2Byte T3:张量并行的通信数据量,参考Megatron的论文,由于使用重计算,每个block的Attn和MLP都有allreduce,一共通信6次,allreduce通信的单向数据量2(N-1)/N*2Byte,由于我们使用的带宽是双向带宽,所以还要乘2。 T4:直接是参数allreduce通信的数据量,2(Ds-1)/Ds * 2 Bytes,由于使用的双向带宽,计算使用双向通信数据量,需要再乘2
非常感谢回复,很大帮助了我的理解,我还有个疑问,请问为什么使用双向带宽需要乘2呀?不太理解,可否详细解释一下,感谢!