alibaba/Megatron-LLaMA

大家好,请教一个关于GLOBAL_BATCH_SIZE值计算的问题,希望大家不吝赐教。

Opened this issue · 1 comments

1、训练脚本中有个公式GLOBAL_BATCH_SIZE=$((($WORLD_SIZE * $MICRO_BATCH_SIZE) / ($TP_SIZE * $PP_SIZE) * 8))
在这个公式中,最后乘以8是代表什么意思?

2、还有另外如果不使用脚本中的GLOBAL_BATCH_SIZE计算公式,将GLOBAL_BATCH_SIZE赋固定的值,这个值设置大或设置小有什么影响吗?在设置值这方面有什么讲究吗?

1、公式中的8是梯度累计值,你也可以调整成≥1的任意整数值。
2、GLOBAL_BATCH_SIZE的计算公式,本质上是拿【数据并行组大小 DP_SIZE】x【MICRO_BATCH_SIZE】x【梯度累计值】。如果你自行给GLOBAL_BATCH_SIZE赋值,那么依据上式反推出来的【梯度累计值】得是一个≥1的任意整数值;否则会报错