DoRA uses lots of GPU VRAM due to fp32 upcasting
rationalism opened this issue · 6 comments
System Info
peft 0.10.0, transformers 4.40.1, Python 3.10 on Ubuntu 22.04
Who can help?
No response
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder
- My own task or dataset (give details below)
Reproduction
Doing language model fine-tuning using QLoRA with DoRA, e.g. fine-tuning Meta-Llama-3-70B with https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/supervised_finetuning.py, with target_modules set to include all linear layers, uses much more GPU VRAM than training with ordinary LoRA.
Expected behavior
Fine-tuning a QLoRA language model using DoRA, with adapters applied to all linear layers, takes up much more GPU VRAM than ordinary LoRA and OOMed my machine. I think the issue is this line:
peft/src/peft/tuners/lora/layer.py
Line 238 in 608a90d
It looks like magnitude is in fp32, so the result computed from the input vector x is upcast to fp32 when it gets returned as result_dora. If both MLP and attention layers are added to target_modules, that fp32 output then causes the next DoRA module (in the MLP layer) to get an fp32 vector as input. This then causes the dequantized weight matrix to get upcast to fp32:
peft/src/peft/tuners/lora/layer.py
Line 229 in 608a90d
which means the algebra in _get_weight_norm is done in fp32:
peft/src/peft/tuners/lora/layer.py
Line 176 in 608a90d
which OOMs my machine. Adding a cast back to x.dtype here:
peft/src/peft/tuners/lora/layer.py
Line 253 in 608a90d
fixes the problem. (I also wrote a custom Triton kernel for _get_weight_norm(), but that's probably not necessary for most purposes.)
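The dtype promotion described above can be reproduced in isolation. The following is a minimal sketch, not the PEFT code: the variable names mirror the ones in the report, the shapes and values are arbitrary, and the computation is a simplified stand-in for the DoRA scaling step. It only shows that multiplying a bf16 tensor by an fp32 tensor yields fp32, and that casting back to x.dtype restores bf16.

```python
import torch

# Hypothetical stand-ins for the tensors in the report; shapes are arbitrary.
x = torch.randn(2, 8, dtype=torch.bfloat16)               # bf16 activations
magnitude = torch.ones(8, dtype=torch.float32)            # fp32 parameter
weight_norm = torch.full((8,), 2.0, dtype=torch.float32)  # fp32 norm

mag_norm_scale = (magnitude / weight_norm).view(1, -1)
result_dora = mag_norm_scale * x          # fp32 * bf16 is promoted to fp32

promoted_dtype = result_dora.dtype                 # dtype handed to the next layer
restored_dtype = result_dora.to(x.dtype).dtype     # dtype after the proposed cast
print(promoted_dtype, restored_dtype)
```

Without the final cast, every downstream module receives fp32 activations, which is the chain of upcasts described in the report.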
Thanks for investigating this issue. Could you provide a bit more information so that we can reproduce it (ideally with a smaller model, so that we can create a unit test based on this)? For instance, what dtype do you use, bf16?
> it looks like magnitude is in fp32, so the input vector x is upcast to fp32
What was the initial dtype for you?
> I also wrote a custom Triton kernel for _get_weight_norm(), but that's probably not necessary for most purposes
If you want to consider contributing this to PEFT, let us know.
@BenjaminBossan Thanks! I'm using bf16. Any broadly Llama-like model should work, e.g. I've run tests with Meta-Llama-3-8B and Mistral-7B-v0.1 (the original, non-MoE Mistral).
Here's the Triton code I wrote. When I benchmarked it, it reduced total training time by 20% and reduced GPU VRAM consumption by 1,974 MB (this is comparing it to the PyTorch implementation of _get_weight_norm() after I already added return result_dora.to(x.dtype) at the end of _apply_dora()). Settings were fine-tuning Llama-3-70B with QLoRA, 2x 4090 GPUs, BNB 4-bit quantization with nf4 data type, compute_dtype bfloat16, double quantization enabled, LoRA rank r = 12, target modules set to ["q_proj", "v_proj", "o_proj", "k_proj", "gate_proj", "up_proj", "down_proj"] (all linear layers), learning rate 1e-4, batch size = 1, gradient accumulation = 4, sequence length = 1024.
(this is my first time writing Triton code, so I'm sure I did something dumb somewhere, but the tests pass and training seems stable)
import torch
import triton
import triton.language as tl

# tunable
BLOCK_SIZE_N = 64


@triton.autotune(
    configs=[
        triton.Config(kwargs={'BLOCK_SIZE_M': 32}, num_warps=4, num_stages=4),
        triton.Config(kwargs={'BLOCK_SIZE_M': 32}, num_warps=2, num_stages=5),
        triton.Config(kwargs={'BLOCK_SIZE_M': 32}, num_warps=1, num_stages=6),
        triton.Config(kwargs={'BLOCK_SIZE_M': 64}, num_warps=4, num_stages=4),
        triton.Config(kwargs={'BLOCK_SIZE_M': 64}, num_warps=2, num_stages=5),
        triton.Config(kwargs={'BLOCK_SIZE_M': 64}, num_warps=1, num_stages=6),
        triton.Config(kwargs={'BLOCK_SIZE_M': 128}, num_warps=4, num_stages=4),
        triton.Config(kwargs={'BLOCK_SIZE_M': 128}, num_warps=2, num_stages=3),
        triton.Config(kwargs={'BLOCK_SIZE_M': 128}, num_warps=2, num_stages=4),
        triton.Config(kwargs={'BLOCK_SIZE_M': 128}, num_warps=2, num_stages=5),
    ],
    key=['M', 'N', 'K']
)
@triton.jit
def weight_norm_kernel(
    weight_ptr, lora_B_ptr, lora_A_ptr, output_ptr,
    scaling, M, N, K,
    stride_wm, stride_wn, stride_bm, stride_bn, stride_am, stride_an, stride_om, stride_on,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr
):
    # Each program handles one (BLOCK_SIZE_M x BLOCK_SIZE_N) tile of the weight matrix.
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    pid_k = tl.program_id(2)

    offm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)

    # Masks guard against reading past the edges of the matrices.
    mask_m = offm < M
    mask_n = offn < N
    mask_k = offk < K

    weight_ptrs = weight_ptr + offm[:, None] * stride_wm + offn[None, :] * stride_wn
    lora_B_ptrs = lora_B_ptr + offm[:, None] * stride_bm + offk[None, :] * stride_bn
    lora_A_ptrs = lora_A_ptr + offk[:, None] * stride_am + offn[None, :] * stride_an

    weight = tl.load(weight_ptrs, mask=mask_m[:, None] & mask_n[None, :], other=0.0)
    lora_B = tl.load(lora_B_ptrs, mask=mask_m[:, None] & mask_k[None, :], other=0.0)
    lora_A = tl.load(lora_A_ptrs, mask=mask_k[:, None] & mask_n[None, :], other=0.0)

    # Tile of (lora_B @ lora_A), added to the base weight tile.
    lora_weight_row = tl.dot(lora_B, lora_A)
    weight += scaling * lora_weight_row

    # Partial sum of squares per row for this column block.
    squared_sum = tl.sum(weight * weight, axis=1)

    # One output column per column block; the host sums over columns afterwards.
    output_ptrs = output_ptr + offm * stride_om + pid_n * stride_on
    tl.store(output_ptrs, squared_sum, mask=mask_m)


def get_weight_norm_triton(weight, lora_B, lora_A, scaling):
    M, N = weight.shape
    K = lora_B.shape[1]

    # Partial sums of squares, one column per N-block, kept in fp32.
    output = torch.empty((M, triton.cdiv(N, BLOCK_SIZE_N)), device=weight.device, dtype=torch.float32)

    # minimum allowed by tl.dot()
    BLOCK_SIZE_K = max(16, triton.next_power_of_2(K))

    grid = lambda META: (
        triton.cdiv(M, META['BLOCK_SIZE_M']),
        triton.cdiv(N, BLOCK_SIZE_N),
        triton.cdiv(K, BLOCK_SIZE_K),
    )

    weight_norm_kernel[grid](
        weight, lora_B, lora_A, output,
        scaling, M, N, K,
        weight.stride(0), weight.stride(1),
        lora_B.stride(0), lora_B.stride(1),
        lora_A.stride(0), lora_A.stride(1),
        output.stride(0), output.stride(1),
        BLOCK_SIZE_N=BLOCK_SIZE_N,
        BLOCK_SIZE_K=BLOCK_SIZE_K
    )

    # Reduce the partial sums over column blocks, then take the square root.
    return torch.sqrt(torch.sum(output, dim=1)).to(weight.dtype)


def test_weight_norm(weight, lora_b, lora_a):
    torch.manual_seed(42)
    scaling = 0.1

    # Expected result using PyTorch
    lora_weight = lora_b.to(torch.float32) @ lora_a.to(torch.float32)
    expected_norm = torch.linalg.norm(weight.to(torch.float32) + scaling * lora_weight, dim=1).to(weight.dtype)

    # Triton result
    triton_norm = get_weight_norm_triton(weight, lora_b, lora_a, scaling)

    assert torch.allclose(expected_norm, triton_norm, atol=1e-5), "Test failed!"
    print("Test passed with all norms close enough!")


def run_tests():
    M, N, K = 8192, 28672, 12

    weight = torch.ones(M, N, dtype=torch.bfloat16, device='cuda')
    lora_b = torch.ones(M, K, dtype=torch.bfloat16, device='cuda')
    lora_a = torch.ones(K, N, dtype=torch.bfloat16, device='cuda')
    test_weight_norm(weight, lora_b, lora_a)

    weight = torch.zeros(M, N, dtype=torch.bfloat16, device='cuda')
    lora_b = torch.zeros(M, K, dtype=torch.bfloat16, device='cuda')
    lora_a = torch.zeros(K, N, dtype=torch.bfloat16, device='cuda')
    test_weight_norm(weight, lora_b, lora_a)

    weight = torch.randn(M, N, dtype=torch.bfloat16, device='cuda')
    lora_b = torch.randn(M, K, dtype=torch.bfloat16, device='cuda')
    lora_a = torch.randn(K, N, dtype=torch.bfloat16, device='cuda')
    test_weight_norm(weight, lora_b, lora_a)


if __name__ == "__main__":
    run_tests()
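For a sense of scale, the 8192 x 28672 shape used in the tests above can be turned into a rough memory estimate. This is a back-of-the-envelope sketch: it counts only one dense copy of the matrix and ignores allocator overhead and any temporaries, and the helper name dense_size_mib is made up for illustration.

```python
def dense_size_mib(rows: int, cols: int, bytes_per_element: int) -> float:
    """Size in MiB of one dense rows x cols matrix."""
    return rows * cols * bytes_per_element / 2**20


M, N = 8192, 28672                      # shape taken from the tests above
bf16_mib = dense_size_mib(M, N, 2)      # bfloat16 uses 2 bytes per element
fp32_mib = dense_size_mib(M, N, 4)      # float32 uses 4 bytes per element
print(f"bf16: {bf16_mib:.0f} MiB, fp32: {fp32_mib:.0f} MiB")
```

Under these assumptions, a single matrix of this shape held in fp32 takes twice the space of its bf16 copy, before counting any intermediate results of the norm computation.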
> Thanks! I'm using bf16. Any broadly Llama-like model should work, eg. I've run tests with Meta-Llama-3-8B and Mistral-7B-v0.1 (the original, non-MoE Mistral).
Thanks for the info. I still couldn't quite replicate the issue. Would it be possible for you to share a code snippet? We already have some code that should take care of setting the right dtype here:
peft/src/peft/tuners/lora/bnb.py
Lines 468 to 478 in 77b7238
But apparently, you're not hitting that code. You mention QLoRA but the code you quote is from the normal LoRA layers, so I'm a bit confused.
> Here's the Triton code I wrote. When I benchmarked it, it reduced total training time by 20% and reduced GPU VRAM consumption by 1,974 MB
That sounds pretty sweet. I have never worked with Triton, so I'm not sure if we can just add this code and it'll work for everyone or if more work is necessary. If you're willing to contribute this but are unsure yourself, maybe I can find someone at HF with more expertise to take a look.
@rationalism gentle ping, do you have updates?