meta-llama/llama-recipes

CUDA OOM during ckpt saving for Llama2-70b

lwmlyy opened this issue ยท 16 comments

lwmlyy commented

Hi, I am using 8*a100-80gb to lora-finetune Llama2-70b, the training and evaluation during epoch-1 went well, but went OOM when saving the peft model. The nightly version of pytorch is used.

The following command is used:
torchrun --nnodes 1 --nproc_per_node 8 llama_finetuning.py --enable_fsdp --low_cpu_fsdp --model_name ../Llama-2-70b-chat-hf --micro_batch_size 1 --batch_size_training 1 --dist_checkpoint_root_folder ../Llama-2-70b-chat-hf/ --dist_checkpoint_folder fine-tuned --use_peft --peft_method lora --lr 3e-4 --epoch 2 --pure_bf16 --alpaca_dataset --output_dir llama-70b-lorawallsft

"we are about to save the PEFT modules", it went CUDA OOM after this log is printed.

I have the same problem. I think everything is brought back to the first rank before saving and it causes CUDA OOM. I was able to save one .distcp file per GPU but I'm not sure how to get just the LoRA adapter file from there...

gongy commented

I'm also running into this (albeit with 4 A100 80GB). Wondering if there is a way we can work around it - happy to make a contribution if the direction is clear.

Seems like a shame to have this bug during save_pretrained when the rest of the training and evaluation works well.

  File "/opt/conda/lib/python3.9/site-packages/llama_recipes/utils/train_utils.py", line 142, in train
    model.save_pretrained(train_config.output_dir)
  File "/opt/conda/lib/python3.9/site-packages/peft/peft_model.py", line 167, in save_pretrained
    output_state_dict = get_peft_model_state_dict(
  File "/opt/conda/lib/python3.9/site-packages/peft/utils/save_and_load.py", line 41, in get_peft_model_state_dict
    state_dict = model.state_dict()
  File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1898, in state_dict
    module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
  File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1898, in state_dict
    module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
  File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1898, in state_dict
    module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
  [Previous line repeated 8 more times]
  File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1894, in state_dict
    hook(self, prefix, keep_vars)
  File "/opt/conda/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.9/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 774, in _pre_state_dict_hook
    _pre_state_dict_hook_fn[fsdp_state._state_dict_type](
  File "/opt/conda/lib/python3.9/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 293, in _full_pre_state_dict_hook
    _common_unshard_pre_state_dict_hook(
  File "/opt/conda/lib/python3.9/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 157, in _common_unshard_pre_state_dict_hook
    _enter_unshard_params_ctx(
  File "/opt/conda/lib/python3.9/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 118, in _enter_unshard_params_ctx
    fsdp_state._unshard_params_ctx[module].__enter__()
  File "/opt/conda/lib/python3.9/contextlib.py", line 119, in __enter__
    return next(self.gen)
  File "/opt/conda/lib/python3.9/site-packages/torch/distributed/fsdp/_unshard_param_utils.py", line 196, in _unshard_fsdp_state_params
    _unshard(state, handle, computation_stream, computation_stream)
  File "/opt/conda/lib/python3.9/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 329, in _unshard
    handle.unshard()
  File "/opt/conda/lib/python3.9/site-packages/torch/distributed/fsdp/flat_param.py", line 1250, in unshard
    unsharded_flat_param = self._alloc_padded_unsharded_flat_param()
  File "/opt/conda/lib/python3.9/site-packages/torch/distributed/fsdp/flat_param.py", line 1276, in _alloc_padded_unsharded_flat_param
    _alloc_storage(unsharded_flat_param, flat_param._padded_unsharded_size)  # type: ignore[attr-defined]
  File "/opt/conda/lib/python3.9/site-packages/torch/distributed/utils.py", line 166, in _alloc_storage
    tensor._typed_storage()._resize_(size.numel())
  File "/opt/conda/lib/python3.9/site-packages/torch/storage.py", line 921, in _resize_
    self._untyped_storage.resize_(size * self._element_size())
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 MiB. GPU 1 has a total capacty of 79.18 GiB of which 2.31 MiB is free. Process 138368 has 79.18 GiB memory in use. Of the allocated memory 76.72 GiB is allocated by PyTorch, and 245.31 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF```

Any updates on this issue? I'm encountering the same problem with 8 x A100 (80GB) for Lora 70B

I'm facing the same issue. Any workaround?

Also facing the same issue here. I noticed another thread facing a similar issue using LoRa fine-tuning (although with another model): philschmid/deep-learning-pytorch-huggingface#16
Seemed to be caused by peft versions after 0.2.0. Could this be related?

gongy commented

I found a workaround which involves allowing CPU offloading during the phase of saving the state dict.

I tested that end-to-end 70B training works with checkpointing on this repo.

I will try to find the time to merge in the changes soon, but you can find them here https://github.com/modal-labs/llama-recipes

I'm encountering the same problem with 6 x H800 (80GB) for Lora 70B

I'm encountering the same problem with 8 x H100 (80GB) for Lora 70B

sorry for the late reply @yuanzhedong and everyone, is this happening only alpaca? I believe some of the issue from Sep should be resolved. We have the CPU offload now if that can be helpful.

I don't have H100s off of my hand now, but looking to get access and repro the issue.

@yuanzhedong It seems like an issue with transformers, could repro this issue with transformers version of 4.38.1 which is from pip install, installing from src could resolve the issue transformers 4.39.0.dev0

git clone https://github.com/huggingface/transformers.git
cd transformers/

pip install -e .

can you pls give it a try.

@HamidShojanazeri Thank you very much for your reply. I tried with a new install of transformers 4.39.0.dev0 from the src but still encountered the same issue. I have also attempted to decrease the batch size but to no avail. I am using a custom dataset modelled after the structure of Alpaca dataset. I will now make some modifications to the dataset and proceed with a clean installation llama-recipes and see how it goes.

sure, it shouldn't have anything to do with your batch size, the version conflict was the only way I could repro and by pass it. Pls let know how it went.

I tried with the following package versions and a fresh install of llama-recipes, but still getting the same error.

accelerate 0.28.0, appdirs 1.4.4, bitsandbytes 0.43.0, black 24.3.0, datasets 2.18.0, fire 0.6.0, gradio 4.22.0, gradio_client 0.13.0 loralib 0.1.2, matplotlib 3.8.3, matplotlib-inline 0.1.6 ,optimum 1.17.1, peft 0.9.0 ,py7zr 0.21.0, scipy 1.12.0, sentencepiece 0.2.0, torch 2.3.0+cu118, transformers 4.39.0.dev0 /home/mugheera/transformers

Below is the command that I used to start the finetuning job

torchrun --nnodes 1 --nproc_per_node 4 finetuning.py --enable_fsdp --use_peft --peft_method lora --model_name /home/mugheera/llama-hf/Llama-2-70b-chat-hf --pure_bf16 --output_dir /home/mugheera/PEFT/model --use_fast_kernels

Training Epoch: 1/3, step 46/47 completed (loss: 0.15002813935279846): 100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 47/47 [34:52<00:00, 44.52s/it]
Training Epoch: 1/3, step 46/47 completed (loss: 0.17413900792598724): 100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 47/47 [35:40<00:00, 45.54s/it]
Training Epoch: 1/3, step 46/47 completed (loss: 0.1314418613910675): 100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 47/47 [34:16<00:00, 43.75s/it]
Training Epoch: 1/3, step 46/47 completed (loss: 0.18022574484348297): 100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 47/47 [35:46<00:00, 45.68s/it]
Max CUDA memory allocated was 31 GB
Max CUDA memory reserved was 39 GB
Peak active CUDA memory was 31 GB
CUDA Malloc retries : 0
CPU Total Peak Memory consumed during the train (max): 153 GB
evaluating Epoch: 100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 17/17 [01:08<00:00, 4.00s/it]
evaluating Epoch: 100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 17/17 [01:08<00:00, 4.02s/it]
evaluating Epoch: 100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 17/17 [01:08<00:00, 4.02s/it]
evaluating Epoch: 100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 17/17 [01:08<00:00, 4.02s/it]
eval_ppl=tensor(1.1495, device='cuda:0') eval_epoch_loss=tensor(0.1393, device='cuda:0')
we are about to save the PEFT modules
[rank2]: Traceback (most recent call last):
[rank2]: File "/home/mugheera/new/llama-recipes/recipes/finetuning/finetuning.py", line 8, in
[rank2]: fire.Fire(main)
[rank2]: File "/home/mugheera/miniconda3/envs/llama-recipes/lib/python3.12/site-packages/fire/core.py", line 143, in Fire
[rank2]: component_trace = _Fire(component, args, parsed_flag_args, context, name)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/home/mugheera/miniconda3/envs/llama-recipes/lib/python3.12/site-packages/fire/core.py", line 477, in _Fire
[rank2]: component, remaining_args = _CallAndUpdateTrace(
[rank2]: ^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/home/mugheera/miniconda3/envs/llama-recipes/lib/python3.12/site-packages/fire/core.py", line 693, in _CallAndUpdateTrace
[rank2]: component = fn(*varargs, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/home/mugheera/new/llama-recipes/src/llama_recipes/finetuning.py", line 265, in main
[rank2]: results = train(
[rank2]: ^^^^^^
[rank2]: File "/home/mugheera/new/llama-recipes/src/llama_recipes/utils/train_utils.py", line 187, in train
[rank2]: model.save_pretrained(train_config.output_dir)
[rank2]: File "/home/mugheera/miniconda3/envs/llama-recipes/lib/python3.12/site-packages/peft/peft_model.py", line 215, in save_pretrained
[rank2]: output_state_dict = get_peft_model_state_dict(
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/home/mugheera/miniconda3/envs/llama-recipes/lib/python3.12/site-packages/peft/utils/save_and_load.py", line 71, in get_peft_model_state_dict
[rank2]: state_dict = model.state_dict()
[rank2]: ^^^^^^^^^^^^^^^^^^
[rank2]: File "/home/mugheera/miniconda3/envs/llama-recipes/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1911, in state_dict
[rank2]: module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
[rank2]: File "/home/mugheera/miniconda3/envs/llama-recipes/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1911, in state_dict
[rank2]: module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
[rank2]: File "/home/mugheera/miniconda3/envs/llama-recipes/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1911, in state_dict
[rank2]: module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
[rank2]: [Previous line repeated 2 more times]
[rank2]: File "/home/mugheera/miniconda3/envs/llama-recipes/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1907, in state_dict
[rank2]: hook(self, prefix, keep_vars)
[rank2]: File "/home/mugheera/miniconda3/envs/llama-recipes/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
[rank2]: return func(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/home/mugheera/miniconda3/envs/llama-recipes/lib/python3.12/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 786, in _pre_state_dict_hook
[rank2]: _pre_state_dict_hook_fn[fsdp_state._state_dict_type](
[rank2]: File "/home/mugheera/miniconda3/envs/llama-recipes/lib/python3.12/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 307, in _full_pre_state_dict_hook
[rank2]: _common_unshard_pre_state_dict_hook(
[rank2]: File "/home/mugheera/miniconda3/envs/llama-recipes/lib/python3.12/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 174, in _common_unshard_pre_state_dict_hook
[rank2]: _enter_unshard_params_ctx(
[rank2]: File "/home/mugheera/miniconda3/envs/llama-recipes/lib/python3.12/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 138, in _enter_unshard_params_ctx
[rank2]: fsdp_state._unshard_params_ctx[module].enter()
[rank2]: File "/home/mugheera/miniconda3/envs/llama-recipes/lib/python3.12/contextlib.py", line 137, in enter
[rank2]: return next(self.gen)
[rank2]: ^^^^^^^^^^^^^^
[rank2]: File "/home/mugheera/miniconda3/envs/llama-recipes/lib/python3.12/site-packages/torch/distributed/fsdp/_unshard_param_utils.py", line 196, in _unshard_fsdp_state_params
[rank2]: _unshard(state, handle, computation_stream, computation_stream)
[rank2]: File "/home/mugheera/miniconda3/envs/llama-recipes/lib/python3.12/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 299, in _unshard
[rank2]: handle.unshard()
[rank2]: File "/home/mugheera/miniconda3/envs/llama-recipes/lib/python3.12/site-packages/torch/distributed/fsdp/_flat_param.py", line 1307, in unshard
[rank2]: unsharded_flat_param = self._alloc_padded_unsharded_flat_param()
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/home/mugheera/miniconda3/envs/llama-recipes/lib/python3.12/site-packages/torch/distributed/fsdp/_flat_param.py", line 1334, in _alloc_padded_unsharded_flat_param
[rank2]: _alloc_storage(unsharded_flat_param, flat_param._padded_unsharded_size) # type: ignore[attr-defined]
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/home/mugheera/miniconda3/envs/llama-recipes/lib/python3.12/site-packages/torch/distributed/utils.py", line 168, in _alloc_storage
[rank2]: tensor._typed_storage().resize(size.numel())
[rank2]: File "/home/mugheera/miniconda3/envs/llama-recipes/lib/python3.12/site-packages/torch/storage.py", line 972, in resize
[rank2]: self.untyped_storage.resize(size * self._element_size())
[rank2]: torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 1.60 GiB. GPU has a total capacity of 79.15 GiB of which 989.31 MiB is free. Process 4076404 has 4.03 GiB memory in use. Including non-PyTorch memory, this process has 74.00 GiB memory in use. Of the allocated memory 70.11 GiB is allocated by PyTorch, and 1.53 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management

And the following error in another iteration

evaluating Epoch: 100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 11/11 [00:41<00:00, 3.75s/it]
eval_ppl=tensor(2.6441, device='cuda:0') eval_epoch_loss=tensor(0.9723, device='cuda:0')
we are about to save the PEFT modulested (loss: 1.2340718507766724): 19%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–Œ | 9/48 [06:16<28:13, 43.41s/it]
/home/mugheera/miniconda3/envs/llama-recipes/lib/python3.12/site-packages/torch/distributed/fsdp/_state_dict_utils.py:348: UserWarning: Failed to clone() tensor with name base_model.model.model.layers.23.mlp.up_proj.weight on rank 2. This may mean that this state_dict entry could point to invalid memory regions after returning from state_dict() call if this parameter is managed by FSDP. Please check clone implementation of base_model.model.model.layers.23.mlp.up_proj.weight. Error: CUDA out of memory. Tried to allocate 448.00 MiB. GPU has a total capacity of 79.15 GiB of which 313.31 MiB is free. Process 4076404 has 4.03 GiB memory in use. Including non-PyTorch memory, this process has 74.82 GiB memory in use. Of the allocated memory 72.29 GiB is allocated by PyTorch, and 175.13 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management

@Mugheeera I wonder if you are installing llama-recipes from src?

This seems to be working on my end, running H100, but regardless it should work on both A100 and H100. Logs

I could use the latest transfromers as well so not from src anymore,

accelerate  0.28.0, bitsandbytes   0.43.0, transformers   4.38.2, torch   2.3.0+cu121

sounds to be stale issue, will close it for now but feel free to re-open is see same issues.

Hi @HamidShojanazeri , I also got the same OOM error when using 8xH100s.
I'm using the transformer '4.41.0.dev0' and build the llama-recipes from the source.
I observed this error will happen when training with 70B model (in my case, the llama3-70b) and it does not exist for llama 3 8B.
I'm wondering the reason to use use model.save_pretrained when using both PEFT and FSDP instead of using save_model_and_optimizer_sharded as did when not using PEFT? (

if train_config.save_model and eval_epoch_loss < best_val_loss:
if train_config.enable_fsdp:
dist.barrier()
if train_config.use_peft:
if train_config.enable_fsdp:
if rank==0:
print(f"we are about to save the PEFT modules")
else:
print(f"we are about to save the PEFT modules")
model.save_pretrained(train_config.output_dir)
if train_config.enable_fsdp:
if rank==0:
print(f"PEFT modules are saved in {train_config.output_dir} directory")
else:
print(f"PEFT modules are saved in {train_config.output_dir} directory")
else:
if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
save_model_checkpoint(
model, optimizer, rank, train_config, epoch=epoch
)
elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
print("=====================================================")
save_model_and_optimizer_sharded(model, rank, train_config)
if train_config.save_optimizer:
save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
print(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT")
print("=====================================================")
)
Would it possible the reason of OOM is because the model needs to gather weights across ranks before model.save_pretrained?