FSDP2 incurs higher GPU memory usage in 2D compared to FSDP1
wanchaol opened this issue · 5 comments
Recently we found that when training in 2D, FSDP2 incurs much higher memory usage than FSDP1, triggering OOM for the 70B model.
A local test on an H100 devgpu (Llama 13B, global batch size 16, local batch size 8, 2-way DP, 4-way TP, selective op AC) shows memory increasing from 71 GiB to 87 GiB (a 16 GiB regression):
FSDP1: (memory profile screenshot not preserved)
Update: to remove complicating factors, using full AC instead of selective op AC gives the same regression.
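For context, here is a minimal sketch of how a 2D (TP + FSDP2) setup like the one in this repro is typically composed. The mesh shape matches the repro, but the model constructor and module names are illustrative, not the exact torchtitan code:

```python
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

# 2-way DP x 4-way TP over 8 GPUs, matching the repro above.
mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))

model = build_llama_13b()  # hypothetical constructor standing in for the real model

# Apply TP to each transformer block over the "tp" mesh dimension first...
for block in model.layers:
    parallelize_module(
        block,
        mesh_2d["tp"],
        {"attention.wq": ColwiseParallel(), "attention.wo": RowwiseParallel()},
    )

# ...then shard parameters with FSDP2 (fully_shard) over the "dp" dimension.
for block in model.layers:
    fully_shard(block, mesh=mesh_2d["dp"])
fully_shard(model, mesh=mesh_2d["dp"])
```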
The issue is that DTensor's async funcols have `recordStream` called on the collective tensors, holding onto their memory longer than they should. FSDP1's CPU rate limiter implicitly mitigated the `recordStream` issues from TP, but FSDP2 does not have this CPU rate limiter anymore.
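As an illustration of the allocator behavior at play (a standalone sketch, not DTensor internals): once a tensor's storage has `record_stream()` called on it for a side stream, the caching allocator will not reuse that memory until the side stream's pending work completes, so the memory stays reserved past the tensor's Python lifetime.

```python
import torch

comm_stream = torch.cuda.Stream()
x = torch.randn(1024, 1024, device="cuda")

with torch.cuda.stream(comm_stream):
    y = x.mul(2)  # stands in for an async collective reading x on the comm stream

# record_stream marks x's storage as in use by comm_stream: the caching
# allocator will only recycle it after comm_stream drains, even if x is
# deleted on the Python side right away.
x.record_stream(comm_stream)
del x  # memory is not reusable yet; it is freed lazily once comm_stream catches up
```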
If we run the job with `TORCH_NCCL_AVOID_RECORD_STREAMS=1`, then FSDP2's 2D run uses 68.97 GiB for the Llama-13B selective AC setup.
```
[rank0]:2024-04-04 08:02:19,445 - root - INFO - step: 1 loss: 10.8906 memory: 56.32GiB(59.26%) wps: 138 mfu: 1.23%
[rank0]:2024-04-04 08:02:19,445 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-04-04 08:02:54,632 - root - INFO - step: 10 loss: 9.4660 memory: 68.97GiB(72.57%) wps: 2,095 mfu: 18.68%
[rank0]:2024-04-04 08:03:24,819 - root - INFO - step: 20 loss: 7.9382 memory: 68.97GiB(72.57%) wps: 2,714 mfu: 24.19%
[rank0]:2024-04-04 08:03:55,005 - root - INFO - step: 30 loss: 7.1089 memory: 68.97GiB(72.57%) wps: 2,714 mfu: 24.19%
[rank0]:2024-04-04 08:04:25,228 - root - INFO - step: 40 loss: 6.6824 memory: 68.97GiB(72.57%) wps: 2,711 mfu: 24.16%
[rank0]:2024-04-04 08:04:55,514 - root - INFO - step: 50 loss: 6.7918 memory: 68.97GiB(72.57%) wps: 2,705 mfu: 24.11%
[rank0]:2024-04-04 08:05:25,759 - root - INFO - step: 60 loss: 6.5128 memory: 68.97GiB(72.57%) wps: 2,709 mfu: 24.15%
[rank0]:2024-04-04 08:05:55,993 - root - INFO - step: 70 loss: 6.2168 memory: 68.97GiB(72.57%) wps: 2,710 mfu: 24.15%
[rank0]:2024-04-04 08:06:26,297 - root - INFO - step: 80 loss: 6.0477 memory: 68.97GiB(72.57%) wps: 2,703 mfu: 24.10%
[rank0]:2024-04-04 08:06:56,527 - root - INFO - step: 90 loss: 5.9257 memory: 68.97GiB(72.57%) wps: 2,710 mfu: 24.16%
[rank0]:2024-04-04 08:07:27,035 - root - INFO - step: 100 loss: 5.8014 memory: 68.97GiB(72.57%) wps: 2,685 mfu: 23.94%
[rank0]:2024-04-04 08:08:02,221 - root - INFO - step: 110 loss: 5.7214 memory: 68.97GiB(72.57%) wps: 2,328 mfu: 20.75%
[rank0]:2024-04-04 08:08:32,497 - root - INFO - step: 120 loss: 5.6474 memory: 68.97GiB(72.57%) wps: 2,706 mfu: 24.12%
[rank0]:2024-04-04 08:09:02,770 - root - INFO - step: 130 loss: 5.6189 memory: 68.97GiB(72.57%) wps: 2,706 mfu: 24.12%
[rank0]:2024-04-04 08:09:33,034 - root - INFO - step: 140 loss: 5.6629 memory: 68.97GiB(72.57%) wps: 2,707 mfu: 24.13%
[rank0]:2024-04-04 08:10:03,782 - root - INFO - step: 150 loss: 5.5157 memory: 68.97GiB(72.57%) wps: 2,664 mfu: 23.75%
[rank0]:2024-04-04 08:10:34,085 - root - INFO - step: 160 loss: 5.4447 memory: 68.97GiB(72.57%) wps: 2,703 mfu: 24.10%
```
The MFU numbers look slightly lower though 😢
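For reference, a minimal sketch of applying the workaround from Python; exporting the variable in the launch environment (e.g. before torchrun) works equally well. As far as I understand, it needs to be set before the NCCL process group is created:

```python
import os

# Must be in the environment before the NCCL process group is constructed.
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

import torch.distributed as dist

dist.init_process_group(backend="nccl")
```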
@gnadathur Sorry for the confusion from the similar titles. These are not duplicates: this one is about GPU memory, and the other is about CPU memory 😅