IEIT-Yuan/Yuan-2.0

关于pp/tp/dp通信时间的推导

Rememberz opened this issue · 4 comments

阅读YUAN2.0论文,对公式(1)很感兴趣,但是没有看太懂PP通信时间(T2)、TP通信时间(T3)、和DP通信时间(T4)是如何推导出来的,特别是前面的系数,例如T2中的8,T3中的48,请问是否能够给出这三个通信时间详细的推导说明,非常感谢!
image

T2:双向通信数据,包括了前向传播接受的input hidden_states/发送的output hidden_states,反向传播接收output hidden_states grad/input hidden_states grad,使用混合精度训练,每个数据占2Byte
T3:张量并行的通信数据量,参考Megatron的论文,由于使用重计算,每个block的Attn和MLP都有allreduce,一共通信6次,allreduce通信的单向数据量2(N-1)/N*2Byte,由于我们使用的带宽是双向带宽,所以还要乘2。
T4:直接是参数allreduce通信的数据量,2(Ds-1)/Ds * 2 Bytes,由于使用的双向带宽,计算使用双向通信数据量,需要再乘2

T2:双向通信数据,包括了前向传播接受的input hidden_states/发送的output hidden_states,反向传播接收output hidden_states grad/input hidden_states grad,使用混合精度训练,每个数据占2Byte T3:张量并行的通信数据量,参考Megatron的论文,由于使用重计算,每个block的Attn和MLP都有allreduce,一共通信6次,allreduce通信的单向数据量2(N-1)/N*2Byte,由于我们使用的带宽是双向带宽,所以还要乘2。 T4:直接是参数allreduce通信的数据量,2(Ds-1)/Ds * 2 Bytes,由于使用的双向带宽,计算使用双向通信数据量,需要再乘2

非常感谢回复,很大帮助了我的理解,我还有个疑问,请问为什么使用双向带宽需要乘2呀?不太理解,可否详细解释一下,感谢!

以allreduce为例,接受和发送的通信量是相同的,2*(n-1)/n的通信量是单向的通信数据量,所以在使用双向带宽的时候要将通信数据量翻倍。
image