eric-mitchell/direct-preference-optimization

Running into CUDA out of Memory: DPO Pipeline for Custom Llama Model (FSDP Trainer)

yash7verma opened this issue · 4 comments

I have access to two GPUs on my machine, Quadro RTX 8000s with 45 GB of memory each. I am trying to run the DPO pipeline for a custom model (the Vicuna model, which is a Llama model with Vicuna weights). Somehow, on my machine I get 'nan' outputs in the metrics if I use a policy dtype of float16, so I am constrained to dtype=float32. I am using the FSDP trainer and have reduced the batch size to 2 (the bare minimum). Since the reference model does not need to be trained and is only used for inference, I have offloaded it to CPU using accelerate's infer_auto_device_map and dispatch_model. I also had to keep the reference dtype at float32, since otherwise I run into
RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'.
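For reference, the offloading step looks roughly like this (a sketch only; the model path and memory limits below are placeholders, not my exact values):

```python
import torch
from accelerate import dispatch_model, infer_auto_device_map
from transformers import AutoModelForCausalLM

# Force all reference-model weights onto CPU by giving the GPUs no memory budget.
# "path/to/vicuna" is a placeholder for the local Vicuna checkpoint.
reference = AutoModelForCausalLM.from_pretrained(
    "path/to/vicuna", torch_dtype=torch.float32
)
device_map = infer_auto_device_map(
    reference, max_memory={0: "0GiB", 1: "0GiB", "cpu": "120GiB"}
)
reference = dispatch_model(reference, device_map=device_map)
```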

Hence, to save as much GPU memory as possible, I have avoided using concatenated_inputs, which exists to avoid the redundant step of calling the FSDP model twice and thereby save some time. I have overridden that step internally and call the model twice instead. Yet I am still running into CUDA out of memory.

My question is: does this change really reduce GPU memory usage? Please help me out. I am attaching screenshots of the code change I made.
[Screenshots of the modified trainer code, 2023-07-17]

Somehow, on my machine I get 'nan' outputs in the metrics if I use a policy dtype of float16

This is not unusual: the representable range of float16 is much smaller than that of float32, so overflow (and hence nan) is common. You might have better luck with bfloat16.
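For illustration, here is a quick way to see the range difference (a minimal sketch, not code from the repo):

```python
import torch

# float16 overflows past ~65504, which then propagates as inf/nan through the loss;
# bfloat16 keeps roughly float32's exponent range at the cost of mantissa precision.
x = torch.tensor(1e5)
print(x.to(torch.float16))              # inf
print(x.to(torch.bfloat16))             # ~1e5
print(torch.finfo(torch.float16).max)   # 65504.0
print(torch.finfo(torch.bfloat16).max)  # ~3.39e+38
```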

since otherwise I run into RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'.

This is also expected, since not all ops have half-precision CPU implementations. This is a PyTorch limitation and unrelated to DPO.
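A minimal reproduction, independent of DPO (exact behavior depends on your PyTorch version):

```python
import torch

# Many CPU kernels historically lacked float16 implementations; bfloat16/float32 work.
layer = torch.nn.Linear(16, 16)
x = torch.randn(1, 16)

try:
    layer.half()(x.half())  # on older PyTorch builds this raises the addmm_impl_cpu_ error
except RuntimeError as e:
    print("float16 on CPU failed:", e)

print(layer.float()(x.float()).dtype)        # torch.float32 always works on CPU
print(layer.bfloat16()(x.bfloat16()).dtype)  # torch.bfloat16 generally works on CPU
```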

To get to your actual question:

Does this change really reduce GPU memory usage? Please help me out. I am attaching screenshots of the code change I made.

I wouldn't expect de-concatenating the policy inputs to significantly reduce GPU memory usage: either way, we need to store the whole computation graph for both the chosen and rejected samples so we can do the backward pass. Concatenating them into a single FSDP forward call is just a bit faster (it potentially reduces some FSDP communication).
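Here's a toy sketch of the two patterns (names and shapes are illustrative, not the repo's exact API) showing why the retained graph is the same size either way:

```python
import torch
import torch.nn as nn

# Tiny stand-in for the policy model.
policy = nn.Sequential(nn.Embedding(100, 32), nn.Linear(32, 100))
chosen = torch.randint(0, 100, (2, 16))    # "chosen" token ids
rejected = torch.randint(0, 100, (2, 16))  # "rejected" token ids

# (a) single concatenated forward, then split the logits back apart
logits = policy(torch.cat([chosen, rejected], dim=0))
chosen_logits_a, rejected_logits_a = logits.split([len(chosen), len(rejected)], dim=0)

# (b) two separate forwards
chosen_logits_b = policy(chosen)
rejected_logits_b = policy(rejected)

# In both cases, autograd keeps activations for all chosen *and* rejected sequences
# alive until backward(), so peak activation memory is roughly the same; (a) just
# amortizes per-call overhead (e.g. FSDP parameter gathering) over one call.
(chosen_logits_a.mean() - rejected_logits_a.mean()).backward()
```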

If you're fine-tuning a Llama 7B-sized model, fp32 is going to require roughly 7B params * 4 bytes/param * 3 = 84 GB of VRAM just to store the model, gradients, and optimizer state, which doesn't really leave any room for activations/CUDA contexts on two 45 GB cards (see the back-of-the-envelope sketch after the list below). So your options are:

  • drop the precision (try bfloat16)
  • use parameter-efficient fine-tuning (PEFT), like LoRA
  • use gradient/activation checkpointing

I'm hoping we'll be able to add some PEFT code soon, but the other choices are already supported.
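As a back-of-the-envelope sketch of the arithmetic above (this ignores activations, CUDA context, and the fact that Adam actually keeps two extra state tensors per parameter, so real usage is higher):

```python
# Rough full fine-tuning footprint for a 7B-parameter model.
params = 7e9
for dtype, bytes_per_param in {"float32": 4, "bfloat16": 2}.items():
    weights = grads = optimizer_state = params * bytes_per_param
    total_gb = (weights + grads + optimizer_state) / 1e9
    print(f"{dtype}: ~{total_gb:.0f} GB for weights + grads + optimizer state")
# float32: ~84 GB, bfloat16: ~42 GB -- fp32 cannot fit on two 45 GB cards.
```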

@yash7verma just checking in- any progress on this issue?

@yash7verma I am happy to assist with your problem- let's close this one, and you can open a new issue for your current blocker?