OOM with Phi-3-mini (3.8B) on 83.5GB RAM due to LoftQ
adamamer20 opened this issue · 4 comments
System Info
System: Linux-6.1.58+-x86_64-with-glibc2.35 / Google Colab
peft: 0.10.0
transformers: 4.40.1
accelerate: 0.30.0
Python 3.10.12
RAM: 83.5 GB
GPU: A100 40GB
CPU: Intel(R) Xeon(R) CPU @ 2.20GHz
Who can help?
No response
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder
- My own task or dataset (give details below)
Reproduction
import gc
import torch as th
from peft import LoraConfig, LoftQConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
)
checkpoint_path = "microsoft/Phi-3-mini-4k-instruct"
# checkpoint_path = "microsoft/Phi-3-mini-128k-instruct"
model_kwargs = dict(
    use_cache=False,
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
    torch_dtype=th.bfloat16,
    device_map="auto",
)
peft_config = LoraConfig(
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules="all-linear",
    modules_to_save=None,
    loftq_config=LoftQConfig(loftq_bits=8),
    init_lora_weights="loftq",
    use_rslora=True,
)
model = AutoModelForCausalLM.from_pretrained(checkpoint_path, **model_kwargs)
model = prepare_model_for_kbit_training(model).to("cpu")
th.cuda.empty_cache()
gc.collect()
Up to this point everything is fine. The model is about 7 GB and is loaded into RAM.
However, when I try to get the PEFT model:
model = get_peft_model(model, peft_config)
This crashes the entire system, despite there being plenty of free RAM.
Note that when I remove LoftQ, the problem does not occur:
peft_config = LoraConfig(
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules="all-linear",
    modules_to_save=None,
    use_rslora=True,
)
Expected behavior
The model should comfortably fit in the RAM.
Thanks for reporting. I dug a bit deeper. The offending line, at least in my setup, is:
peft/src/peft/utils/loftq_utils.py
Line 140 in 32f3878
With the incoming weight having a shape of (3072, 3072), we have:
weight_divabs => 147456, 64, 1
L_reshaped => 1, 256
abs_diff => 147456, 64, 256
So abs_diff tries to allocate 9 GB of memory (with float32). I wonder if we can avoid such a huge shape. Pinging @yxli2123.
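For reference, the 9 GB figure can be reproduced with plain arithmetic from the shapes above. This is only a back-of-the-envelope sketch, not code from PEFT: the helper name `abs_diff_bytes` is made up for illustration, the 256 lookup entries are assumed to correspond to `loftq_bits=8`, and the chunk size at the end is an arbitrary value chosen to show how slicing the blocks would shrink the peak allocation.

```python
def abs_diff_bytes(num_blocks, block_size=64, num_levels=256, itemsize=4):
    """Bytes for a dense (num_blocks, block_size, num_levels) intermediate.

    Hypothetical helper for illustration only; itemsize=4 means float32.
    """
    return num_blocks * block_size * num_levels * itemsize


rows, cols = 3072, 3072      # weight shape quoted in the comment above
block_size = 64              # second dimension of weight_divabs
num_levels = 2 ** 8          # assumption: one lookup entry per 8-bit code
num_blocks = rows * cols // block_size

full = abs_diff_bytes(num_blocks, block_size, num_levels)
print(num_blocks)            # 147456
print(full / 2**30)          # 9.0 (GiB)

# Assumed mitigation, not what the library does: handle the blocks in
# slices so only one slice of the broadcasted difference exists at a time.
chunk_blocks = 4096          # arbitrary illustrative chunk size
chunked = abs_diff_bytes(chunk_blocks, block_size, num_levels)
print(chunked / 2**20)       # 256.0 (MiB)
```

Note that this is the cost for a single (3072, 3072) weight, and it comes on top of the model that is already held in RAM.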
What you could try right now is to use the replace_lora_weights_loftq function. This allows you to load the model with bnb quantized weights, i.e. with lower memory requirement, and apply LoftQ on the fly with relatively little overhead. I tried this on my machine and memory was consistently < 5GB:
import time
import gc
import torch
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, replace_lora_weights_loftq
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
checkpoint_path = "microsoft/Phi-3-mini-4k-instruct"
model_kwargs = dict(
    use_cache=False,
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)
peft_config = LoraConfig(
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules="all-linear",
    modules_to_save=None,
    use_rslora=True,
)
model = AutoModelForCausalLM.from_pretrained(checkpoint_path, **model_kwargs)
model = prepare_model_for_kbit_training(model)
#model = model.to("cpu")
torch.cuda.empty_cache()
gc.collect()
model = get_peft_model(model, peft_config)
replace_lora_weights_loftq(model) # takes a couple of minutes
Note that using this approach is more memory efficient, but it might not perform as well, at least not without making use of the callback feature described in this LoftQ init notebook.
@BenjaminBossan Thank you for the advice, your method works! So the issue was that the weight matrix (3072,3072) was being quantized all at once and there wasn't enough space available for the necessary computations.
Can you clarify, however, what replace_lora_weights_loftq does? From the source code it seems to assume that the LoRA adapter weights are already quantized, but there's no mention of quantization in your LoraConfig. Are the LoRA weights initialized as quantized because the model weights are quantized?
The LoRA weights are never quantized, regardless of whether the base model is quantized or not. This is necessary because quantized weights cannot be trained, and we want the LoRA weights to be trained. But since the total number of parameters of the LoRA weights is typically small, this should still result in less memory being used than full fine-tuning.
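To give a rough sense of scale, here is a small sketch that counts the trainable LoRA parameters for a single (3072, 3072) linear layer with r = 8, the shape and rank mentioned earlier in this thread. The helper name `lora_param_count` is made up for illustration, and the other target modules of the model may have different shapes, so treat the ratio as indicative only.

```python
def lora_param_count(in_features, out_features, r):
    """Trainable parameters LoRA adds to one linear layer.

    LoRA learns two low-rank factors: A with shape (r, in_features)
    and B with shape (out_features, r). Hypothetical helper name.
    """
    return r * in_features + out_features * r


in_features = out_features = 3072   # layer shape mentioned in the thread
r = 8                               # rank from the LoraConfig above

full_params = in_features * out_features
lora_params = lora_param_count(in_features, out_features, r)
ratio = lora_params / full_params

print(full_params)      # 9437184
print(lora_params)      # 49152
print(f"{ratio:.2%}")   # 0.52%
```

So even when kept in full precision, the adapter for such a layer is well under one percent of the size of the frozen base weight.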
You're right, I forgot that quantization can be used only during inference. Thank you very much.