microsoft/DeepSpeed

How to turn on allgather overlapping in ZeRO-1/2 ?

Closed this issue · 5 comments

[ TARGET ]
ZeRO-1/2 requires an allgather after the backward pass, since the optimizer states are partitioned and therefore only part of the model parameters is updated on each rank. The allgather assembles the complete updated model on every rank before the forward pass of the next iteration/step (which relies on the updated parameters). In theory, if this is done bucket by bucket, the allgather communication can be largely overlapped with the forward computation of the next iteration. This is what I want to realize to further boost my training speed (together with reduce overlapping).
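To illustrate the idea (this is a hypothetical sketch, not DeepSpeed's implementation): while the forward pass consumes the parameters of bucket i, the gather for bucket i+1 can already be in flight. Here the collective is faked with a plain function and the overlap is simulated with a background thread; in real training it would be a collective on a side CUDA stream. All names (`fake_allgather`, `forward_with_overlap`) are made up for illustration.

```python
# Hypothetical sketch of bucket-by-bucket allgather overlapped with forward.
from concurrent.futures import ThreadPoolExecutor

def fake_allgather(bucket):
    # Stand-in for a per-bucket dist.all_gather; pretend remote shards arrive.
    return [p * 2 for p in bucket]

def forward_with_overlap(buckets):
    results = []
    with ThreadPoolExecutor(max_workers=1) as pool:
        # Prefetch the first bucket; while "forward" runs on bucket i,
        # the gather for bucket i+1 proceeds in the background.
        fut = pool.submit(fake_allgather, buckets[0])
        for i in range(len(buckets)):
            params = fut.result()                  # wait for bucket i
            if i + 1 < len(buckets):
                fut = pool.submit(fake_allgather, buckets[i + 1])
            results.append(sum(params))            # "forward" on bucket i
    return results

print(forward_with_overlap([[1, 2], [3, 4]]))  # [6, 14]
```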

[ ISSUE ]
Even though I found the related configuration options in the official DS_CONFIG docs, I still failed to run ZeRO-1/2 training with allgather overlapping. I reached this conclusion by carefully profiling the training and observing the allgather communication and forward computations. (By the way, the reduce overlap works well.)
My ZeRO configurations are set as follows:

#------ ZeRO-1 DS_CONFIG ------#
{
  "train_batch_size": 65536,
  "gradient_accumulation_steps": 1,
  "optimizer": {
    "type": "Adam",
    "params": {
      "lr": 0.00015
    }
  },
  "fp16": {
    "enabled": true,
    "loss_scale": 0
  },
  "zero_optimization": {
    "stage": 1,
    "overlap_comm": true,
    "contiguous_gradients": true,
    "reduce_scatter": true,
    "reduce_bucket_size": 1e6
  }
}

#------ ZeRO-2 DS_CONFIG ------#
{
  "train_batch_size": 65536,
  "gradient_accumulation_steps": 1,
  "optimizer": {
    "type": "Adam",
    "params": {
      "lr": 0.00015
    }
  },
  "fp16": {
    "enabled": true,
    "loss_scale": 0
  },
  "zero_optimization": {
    "stage": 2,
    "overlap_comm": true,
    "contiguous_gradients": true,
    "reduce_scatter": true,
    "reduce_bucket_size": 1e6,
    "allgather_bucket_size": 1e6,
    "allgather_partitions": true
  }
}

[ BUG ]
After the failure I dove into the ZeRO-1/2 source code. In stage_1_and_2.py #L1921, all_gather_dp_groups is called with allgather_bucket_size as one of its arguments. It seems that all_gather_dp_groups is meant to transmit the updated parameters in buckets (which would make communication overlap possible).
However, in utils.py #L965 it calls another function, all_gather_into_tensor_dp_groups, in which dist.all_gather_into_tensor is called directly without any bucketing.
I don't think this can result in any allgather communication overlapping, since the parameter allgather appears to transmit everything at once instead of bucket by bucket. I wonder how to turn on allgather overlapping, and I doubt whether this is supported in ZeRO-1/2 at all.
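To make the distinction concrete, here is a hypothetical sketch of the bucketing the issue expects: a flat parameter buffer split into allgather_bucket_size-sized chunks, so each chunk could be gathered (and overlapped) separately, versus a single one-shot gather over the whole buffer. The function name `split_into_buckets` is invented for illustration and does not come from DeepSpeed.

```python
# Hypothetical: split a flat parameter buffer into allgather buckets.
def split_into_buckets(flat_params, allgather_bucket_size):
    return [flat_params[i:i + allgather_bucket_size]
            for i in range(0, len(flat_params), allgather_bucket_size)]

flat = list(range(10))
buckets = split_into_buckets(flat, 4)
# Bucketed path: 3 separate collectives, each overlappable with compute.
print([len(b) for b in buckets])  # [4, 4, 2]
# One-shot path (as in all_gather_into_tensor_dp_groups): 1 collective,
# nothing to overlap bucket by bucket.
```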

Thanks for your time and contributions :)

@2012zzhao, if I understand correctly, this is a request to overlap forward pass with parameter update all-gather? While this makes sense, we currently don't have bandwidth to implement this optimization. However, to help us plan and prioritize, can you provide some evidence of the expected performance benefits? For example, have you done some profiling to estimate the speedup of forward pass and e2e iteration? Thanks!


Thanks a lot for your reply and contributions.

Our intuition is similar to the ZeRO-3 allgather overlap, where the allgather communication of updated model parameters is overlapped with the forward computation to save time. Even though the ZeRO-2 allgather is not as frequent as in ZeRO-3, I suspect there is still exposed communication time that could be hidden (by overlapping).

Anyway, I just wanted to make sure whether this was by design or my DS_CONFIG was not set correctly. Based on your reply, it seems there is no allgather-overlap design in DeepSpeed ZeRO-2 (stage_1_and_2.py), right?

btw, is it possible to simulate ZeRO-2 + allgather-overlap by carefully setting the DS_CONFIG of ZeRO-3 (e.g. keeping most parameters resident on device instead of partitioning them)? I think this could be a straightforward way to roughly profile and estimate its performance. Maybe I can start from here and do some prototypes.

Again, thanks a lot for your reply and contributions. :)

  1. Correct, ZeRO-2 is not designed to overlap forward + param_update-allgather.
  2. You can simulate ZeRO-2 with ZeRO-3 by setting param_persistence_threshold in ds_config to a very large value, e.g., model size. Parameters smaller than this threshold are not partitioned.
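For reference, the suggestion above might look like the following ds_config fragment; the config key is spelled stage3_param_persistence_threshold under zero_optimization, and the threshold value here is illustrative (it just needs to exceed the parameter count of the largest layer you want to keep unpartitioned):

```json
{
  "zero_optimization": {
    "stage": 3,
    "overlap_comm": true,
    "stage3_param_persistence_threshold": 1e9
  }
}
```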

Now I am clear. Thank you sooooo much !

I will attach further results if I find them efficient and necessary. Thanks Ruwase :)

@2012zzhao, I will close this ticket. Feel free to update with your findings. Thanks!