huggingface/peft

DoRA uses lots of GPU VRAM due to fp32 upcasting

rationalism opened this issue · 6 comments

System Info

peft 0.10.0, transformers 4.40.1, Python 3.10 on Ubuntu 22.04

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

Doing language model fine-tuning using QLoRA with DoRA, e.g. fine-tuning Meta-Llama-3-70B with https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/supervised_finetuning.py, with target_modules set to include all linear layers, uses much more GPU VRAM than training with ordinary LoRA.

Expected behavior

Fine-tuning a QLoRA language model using DoRA, with adapters applied to all linear layers, uses much more GPU VRAM than ordinary LoRA and OOMs my machine. I think the issue is this line:

mag_norm_scale = (magnitude / weight_norm).view(1, -1)

It looks like magnitude is in fp32, so the input vector x is upcast to fp32 and returned as result_dora. If both MLP and attention layers are added to target_modules, that fp32 output then causes the next DoRA module (in the MLP layer) to receive an fp32 vector as input. This in turn causes the dequantized weight matrix to get upcast to fp32:

weight = weight.to(x.dtype)

which means the algebra in _get_weight_norm is done in fp32:

def _get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor:

which OOMs my machine. Adding a cast back to x.dtype here:

return result_dora

fixes the problem. (I also wrote a custom Triton kernel for _get_weight_norm(), but that's probably not necessary for most purposes)

Thanks for investigating this issue. Could you provide a bit more information so that we can reproduce it (ideally with a smaller model, so that we can create a unit test based on this)? For instance, what dtype do you use, bf16?

> it looks like magnitude is in fp32, so the input vector x is upcast to fp32

What was the initial dtype for you?

> I also wrote a custom Triton kernel for _get_weight_norm(), but that's probably not necessary for most purposes

If you want to consider contributing this to PEFT, let us know.

@BenjaminBossan Thanks! I'm using bf16. Any broadly Llama-like model should work, e.g. I've run tests with Meta-Llama-3-8B and Mistral-7B-v0.1 (the original, non-MoE Mistral).

Here's the Triton code I wrote. When I benchmarked it, it reduced total training time by 20% and reduced GPU VRAM consumption by 1,974 MB (this is comparing it to the PyTorch implementation of _get_weight_norm() after I already added return result_dora.to(x.dtype) at the end of _apply_dora()). Settings were fine-tuning Llama-3-70B with QLoRA, 2x 4090 GPUs, BNB 4-bit quantization with nf4 data type, compute_dtype bfloat16, double quantization enabled, LoRA rank r = 12, target modules set to ["q_proj", "v_proj", "o_proj", "k_proj", "gate_proj", "up_proj", "down_proj"] (all linear layers), learning rate 1e-4, batch size = 1, gradient accumulation = 4, sequence length = 1024.
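
For reference, that setup corresponds roughly to the following quantization/adapter config (a sketch using the transformers/peft config objects; lora_alpha, dropout and the trainer arguments for learning rate, batch size and sequence length are omitted here):

import torch
from transformers import BitsAndBytesConfig
from peft import LoraConfig

# 4-bit NF4 quantization with bf16 compute and double quantization, as described above
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# DoRA adapters of rank 12 on all linear layers
lora_config = LoraConfig(
    r=12,
    use_dora=True,
    target_modules=["q_proj", "v_proj", "o_proj", "k_proj", "gate_proj", "up_proj", "down_proj"],
)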

(this is my first time writing Triton code, so I'm sure I did something dumb somewhere, but the tests pass and training seems stable)

import torch
import triton
import triton.language as tl

# tunable: number of weight columns each program reduces over in one tile
BLOCK_SIZE_N = 64

@triton.autotune(configs=[
    triton.Config(kwargs={'BLOCK_SIZE_M': 32}, num_warps=4, num_stages=4),
    triton.Config(kwargs={'BLOCK_SIZE_M': 32}, num_warps=2, num_stages=5),
    triton.Config(kwargs={'BLOCK_SIZE_M': 32}, num_warps=1, num_stages=6),
    triton.Config(kwargs={'BLOCK_SIZE_M': 64}, num_warps=4, num_stages=4),
    triton.Config(kwargs={'BLOCK_SIZE_M': 64}, num_warps=2, num_stages=5),
    triton.Config(kwargs={'BLOCK_SIZE_M': 64}, num_warps=1, num_stages=6),
    triton.Config(kwargs={'BLOCK_SIZE_M': 128}, num_warps=4, num_stages=4),
    triton.Config(kwargs={'BLOCK_SIZE_M': 128}, num_warps=2, num_stages=3),
    triton.Config(kwargs={'BLOCK_SIZE_M': 128}, num_warps=2, num_stages=4),
    triton.Config(kwargs={'BLOCK_SIZE_M': 128}, num_warps=2, num_stages=5),
  ],
  key=['M', 'N', 'K']
)
@triton.jit
def weight_norm_kernel(
    weight_ptr, lora_B_ptr, lora_A_ptr, output_ptr,
    scaling, M, N, K,
    stride_wm, stride_wn, stride_bm, stride_bn, stride_am, stride_an, stride_om, stride_on,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr
):
    # Each program handles one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of weight + scaling * lora_B @ lora_A
    # and writes that tile's partial sum of squares per row; the host-side wrapper reduces the partials
    # across column blocks and takes the square root to obtain the row-wise weight norm.
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    pid_k = tl.program_id(2)  # always 0: BLOCK_SIZE_K is chosen to cover all of K
    offm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
    mask_m = offm < M
    mask_n = offn < N
    mask_k = offk < K

    weight_ptrs = weight_ptr + offm[:, None] * stride_wm + offn[None, :] * stride_wn
    lora_B_ptrs = lora_B_ptr + offm[:, None] * stride_bm + offk[None, :] * stride_bn
    lora_A_ptrs = lora_A_ptr + offk[:, None] * stride_am + offn[None, :] * stride_an

    # load one tile of the base weight and the matching tiles of the LoRA factors
    weight = tl.load(weight_ptrs, mask=mask_m[:, None] & mask_n[None, :], other=0.0)
    lora_B = tl.load(lora_B_ptrs, mask=mask_m[:, None] & mask_k[None, :], other=0.0)
    lora_A = tl.load(lora_A_ptrs, mask=mask_k[:, None] & mask_n[None, :], other=0.0)

    # (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of lora_B @ lora_A, accumulated in fp32 by tl.dot
    lora_weight_row = tl.dot(lora_B, lora_A)
    weight += scaling * lora_weight_row
    # partial sum of squares along this tile's columns, one value per row
    squared_sum = tl.sum(weight * weight, axis=1)

    # store this tile's partial squared row sums at (row, column block) in the fp32 scratch output
    output_ptrs = output_ptr + offm * stride_om + pid_n * stride_on
    tl.store(output_ptrs, squared_sum, mask=mask_m)


def get_weight_norm_triton(weight, lora_B, lora_A, scaling):
    M, N = weight.shape
    K = lora_B.shape[1]  # LoRA rank

    # fp32 scratch buffer: one partial squared row sum per (row, column block)
    output = torch.empty((M, triton.cdiv(N, BLOCK_SIZE_N)), device=weight.device, dtype=torch.float32)

    # a single K block covers the whole rank dimension; 16 is the minimum allowed by tl.dot()
    BLOCK_SIZE_K = max(16, triton.next_power_of_2(K))
    grid = lambda META: (
        triton.cdiv(M, META['BLOCK_SIZE_M']),
        triton.cdiv(N, BLOCK_SIZE_N),
        triton.cdiv(K, BLOCK_SIZE_K),
    )

    weight_norm_kernel[grid](
        weight, lora_B, lora_A, output,
        scaling, M, N, K,
        weight.stride(0), weight.stride(1),
        lora_B.stride(0), lora_B.stride(1),
        lora_A.stride(0), lora_A.stride(1),
        output.stride(0), output.stride(1),
        BLOCK_SIZE_N=BLOCK_SIZE_N,
        BLOCK_SIZE_K=BLOCK_SIZE_K
    )

    # sum the partials across column blocks, then take the square root to get the row-wise norm
    return torch.sqrt(torch.sum(output, dim=1)).to(weight.dtype)



def test_weight_norm(weight, lora_b, lora_a):
    torch.manual_seed(42)
    scaling = 0.1

    # Expected result using PyTorch
    lora_weight = lora_b.to(torch.float32) @ lora_a.to(torch.float32)
    expected_norm = torch.linalg.norm(weight.to(torch.float32) + scaling * lora_weight, dim=1).to(weight.dtype)

    # Triton result
    triton_norm = get_weight_norm_triton(weight, lora_b, lora_a, scaling)

    assert torch.allclose(expected_norm, triton_norm, atol=1e-5), "Test failed!"

    print("Test passed with all norms close enough!")


def run_tests():
    M, N, K = 8192, 28672, 12  # roughly a Llama-3-70B MLP projection with LoRA rank 12

    weight = torch.ones(M, N, dtype=torch.bfloat16, device='cuda')
    lora_b = torch.ones(M, K, dtype=torch.bfloat16, device='cuda')
    lora_a = torch.ones(K, N, dtype=torch.bfloat16, device='cuda')
    test_weight_norm(weight, lora_b, lora_a)

    weight = torch.zeros(M, N, dtype=torch.bfloat16, device='cuda')
    lora_b = torch.zeros(M, K, dtype=torch.bfloat16, device='cuda')
    lora_a = torch.zeros(K, N, dtype=torch.bfloat16, device='cuda')
    test_weight_norm(weight, lora_b, lora_a)

    weight = torch.randn(M, N, dtype=torch.bfloat16, device='cuda')
    lora_b = torch.randn(M, K, dtype=torch.bfloat16, device='cuda')
    lora_a = torch.randn(K, N, dtype=torch.bfloat16, device='cuda')
    test_weight_norm(weight, lora_b, lora_a)
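
For context, the way I use it is to call get_weight_norm_triton() on the LoRA factor matrices directly instead of first materializing lora_weight and calling _get_weight_norm() (a sketch below, glossing over the weight transposition handling in the real method), and the tests above can be run straight from this file:

# Inside _apply_dora(), roughly (illustrative, not the exact PEFT code):
#
#   lora_weight = lora_B.weight @ lora_A.weight
#   weight_norm = self._get_weight_norm(weight, lora_weight, scaling)
#
# becomes
#
#   weight_norm = get_weight_norm_triton(weight, lora_B.weight, lora_A.weight, scaling)

if __name__ == "__main__":
    run_tests()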

> Thanks! I'm using bf16. Any broadly Llama-like model should work, e.g. I've run tests with Meta-Llama-3-8B and Mistral-7B-v0.1 (the original, non-MoE Mistral).

Thanks for the info. I still couldn't quite replicate the issue; would it be possible for you to share a code snippet? We already have some code that should take care of setting the right dtype here:

requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
    expected_dtype = result.dtype
    x = x.to(lora_A.weight.dtype)
if not self.use_dora[active_adapter]:
    output = lora_B(lora_A(dropout(x))) * scaling
else:
    output = self._apply_dora(x, lora_A, lora_B, scaling, active_adapter)
if requires_conversion:
    output = output.to(expected_dtype)

But apparently, you're not hitting that code. You mention QLoRA but the code you quote is from the normal LoRA layers, so I'm a bit confused.

> Here's the Triton code I wrote. When I benchmarked it, it reduced total training time by 20% and reduced GPU VRAM consumption by 1,974 MB

That sounds pretty sweet. I have never worked with Triton, so I'm not sure if we can just add this code and it'll work for everyone, or if more work is necessary. If you're willing to contribute this but are unsure yourself, maybe I can find someone at HF with more expertise to take a look.

@rationalism gentle ping, do you have updates?