Why are ShardedDDP and OSS slower than vanilla DDP?
powermano opened this issue
I tested https://github.com/facebookresearch/fairscale/blob/main/benchmarks/oss.py on two 3080ti GPUs and, separately, on two 4080ti GPUs.
As described in https://fairscale.readthedocs.io/en/latest/deep_dive/oss_sdp_fsdp.html, the training process can be modified from that carried out by DDP as follows:
1. The wrapped optimizer shards the optimizer state in a greedy fashion based on the parameter size but not the order in which it is used. This is to ensure that each rank has almost the same optimizer memory footprint.
2. The training process is similar to that used by PyTorch’s Distributed Data Parallel (DDP). The forward pass completes on each of the ranks followed by the backward pass. During the backward pass, gradients are synchronized using allreduce.
3. Each rank updates the parameters for the shard of optimizer state that it is responsible for and then discards the rest.
4. After update, a broadcast or allgather follows to ensure all ranks receive the latest updated parameter values.
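Step 1's greedy, size-based assignment can be sketched in plain Python. This is a simplified illustration under stated assumptions, not fairscale's implementation: the function name `greedy_shard` is hypothetical, parameters are represented only by their element counts, and details such as per-param-group handling are ignored.

```python
from typing import Dict, List


def greedy_shard(param_sizes: List[int], world_size: int) -> Dict[int, List[int]]:
    """Hypothetical sketch: map each rank to the indices of the parameters it owns.

    Parameters are visited largest first (by size, not by the order in which
    they are used in forward/backward), and each one goes to the rank whose
    shard is currently the smallest, so every rank ends up with a similar
    optimizer-state footprint.
    """
    shards: Dict[int, List[int]] = {rank: [] for rank in range(world_size)}
    totals = [0] * world_size
    order = sorted(range(len(param_sizes)), key=lambda i: param_sizes[i], reverse=True)
    for idx in order:
        # min() returns the first rank with the smallest running total.
        rank = min(range(world_size), key=lambda r: totals[r])
        shards[rank].append(idx)
        totals[rank] += param_sizes[idx]
    return shards


if __name__ == "__main__":
    print(greedy_shard([1000, 10, 500, 400, 100], world_size=2))
```

With these made-up sizes, rank 0 owns parameters 0 and 1 (1010 elements) and rank 1 owns parameters 2, 3 and 4 (1000 elements), so the two shards are nearly balanced even though parameter 0 alone is as large as the other three combined.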
OSS is very useful when you are using an optimizer such as Adam that has additional state. Wrapping the optimizer is a one-line, non-intrusive change that provides memory savings.
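Why the savings depend on the optimizer can be shown with a rough back-of-the-envelope model. It assumes Adam keeps two full-size state tensors per parameter (`exp_avg` and `exp_avg_sq`), SGD with momentum keeps one momentum buffer, and plain SGD keeps none; it also assumes OSS splits the state evenly across ranks and ignores scalar entries such as `step` and allocator overhead. The names `STATE_COPIES` and `optimizer_state_mib` are hypothetical.

```python
# Hypothetical model: number of full-size state tensors each optimizer keeps per parameter.
STATE_COPIES = {"sgd": 0, "sgd_momentum": 1, "adam": 2}


def optimizer_state_mib(param_mib: float, optimizer: str, world_size: int, sharded: bool) -> float:
    """Approximate per-rank optimizer-state memory in MiB."""
    total = STATE_COPIES[optimizer] * param_mib
    # With OSS each rank only holds the state for its own shard (assumed evenly balanced).
    return total / world_size if sharded else total


if __name__ == "__main__":
    for opt in STATE_COPIES:
        full = optimizer_state_mib(100.0, opt, world_size=2, sharded=False)
        oss = optimizer_state_mib(100.0, opt, world_size=2, sharded=True)
        print(f"{opt:13s} DDP: {full:6.1f} MiB  OSS: {oss:6.1f} MiB  saved: {full - oss:6.1f} MiB")
```

For a hypothetical model with 100 MiB of parameters on 2 ranks, Adam's per-rank state drops from 200 MiB to 100 MiB, while plain SGD has no state to shard and saves nothing.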
If you are using SGD or any optimizer with a limited memory footprint, it is likely that you will see a slowdown when using multiple nodes, due to the additional communication in step 4. There is also some wasteful memory used to store gradients during allreduce in step 2 that is then discarded, although this also happens with normal PyTorch (nothing extraneous here).
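The extra traffic from step 4 can be estimated with the standard ring-algorithm cost model, in which an allreduce of S bytes sends about 2(N-1)/N·S bytes per rank and an allgather sends about (N-1)/N·S. This is an idealized sketch: it assumes OSS + DDP performs a full gradient allreduce followed by an allgather-equivalent parameter sync, and it ignores latency, bucketing, and compute/communication overlap. `per_rank_traffic` is a hypothetical helper.

```python
def per_rank_traffic(param_bytes: float, world_size: int) -> dict:
    """Idealized ring-algorithm bytes sent per rank per training step."""
    n = world_size
    if n < 2:
        return {"ddp": 0.0, "oss_ddp": 0.0}
    allreduce = 2 * (n - 1) / n * param_bytes  # step 2: gradient allreduce (DDP and OSS + DDP)
    allgather = (n - 1) / n * param_bytes      # step 4: sync of updated parameters (OSS only)
    return {"ddp": allreduce, "oss_ddp": allreduce + allgather}


if __name__ == "__main__":
    for n in (2, 4, 8):
        t = per_rank_traffic(100e6, n)  # hypothetical 100 MB of gradients/parameters
        print(f"N={n}: DDP {t['ddp'] / 1e6:.0f} MB, OSS + DDP {t['oss_ddp'] / 1e6:.0f} MB, "
              f"ratio {t['oss_ddp'] / t['ddp']:.2f}")
```

Under this model, OSS + DDP sends about 1.5x as much data per step as vanilla DDP for any world size. The step-4 sync also runs after the optimizer step, so it cannot be hidden behind the backward pass the way DDP's bucketed gradient allreduce can.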
Compared to DDP, OSS + DDP adds the extra communication in step 4. So why do the performance tips say that "On a single node, OSS should be always faster than vanilla PyTorch"?
From "Performance tips for fairscale.optim.oss":
1. On a single node, OSS should be always faster than vanilla PyTorch; memory savings will vary depending on the optimizer being used.
Results on 2x 3080ti:

| Optimizer | Median Throughput (img/s, rank 0) | Peak Memory (MiB) |
|---|---|---|
| Vanilla | 1795.03 +/- 34.88 | 1462.5 |
| OSS + DDP | 1645.64 +/- 31.78 | 1290.0 |
| OSS + ShardedDDP | 1468.54 +/- 12.97 | 1049.7 |
Results on 2x 4080ti (run with `export NCCL_P2P_DISABLE=1` to work around an NVIDIA driver issue that has not yet been resolved):

| Optimizer | Median Throughput (img/s, rank 0) | Peak Memory (MiB) |
|---|---|---|
| Vanilla | 2117.12 +/- 16.13 | 1556.4 |
| OSS + DDP | 1850.65 +/- 5.97 | 1377.8 |
| OSS + ShardedDDP | 1530.15 +/- 8.69 | 1158.6 |
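To compare the two GPU setups directly, the snippet below recomputes each configuration's throughput and peak memory relative to Vanilla on the same GPUs. It uses only the median values reported in the tables and ignores the +/- spread; `RESULTS` and `relative_to_vanilla` are hypothetical names introduced for this calculation.

```python
# Median throughput (img/s, rank 0) and peak memory (MiB) copied from the tables above.
RESULTS = {
    "3080ti": {
        "Vanilla": (1795.03, 1462.5),
        "OSS + DDP": (1645.64, 1290.0),
        "OSS + ShardedDDP": (1468.54, 1049.7),
    },
    "4080ti": {
        "Vanilla": (2117.12, 1556.4),
        "OSS + DDP": (1850.65, 1377.8),
        "OSS + ShardedDDP": (1530.15, 1158.6),
    },
}


def relative_to_vanilla(results):
    """Return {gpu: {config: (throughput_change, memory_change)}} as fractions vs. Vanilla."""
    out = {}
    for gpu, configs in results.items():
        base_tput, base_mem = configs["Vanilla"]
        out[gpu] = {
            name: (tput / base_tput - 1.0, mem / base_mem - 1.0)
            for name, (tput, mem) in configs.items()
            if name != "Vanilla"
        }
    return out


if __name__ == "__main__":
    for gpu, configs in relative_to_vanilla(RESULTS).items():
        for name, (dt, dm) in configs.items():
            print(f"{gpu:7s} {name:17s} throughput {dt:+.1%}  peak memory {dm:+.1%}")
```

Computed this way, OSS + ShardedDDP is roughly 18% slower than Vanilla on the 3080ti pair and roughly 28% slower on the 4080ti pair, while using roughly 28% and 26% less peak memory, respectively.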