matmul ordering in MergedLinear
zhiqi-0 opened this issue · 1 comment
Hi Team, thanks for the great work!
I'm trying to understand the benefits of LoRA but got confused when reading the forward implementation of `MergedLinear`.
I'm curious why its matmul ordering, `x @ (lora_A @ lora_B)`, is inconsistent with the ordering in `Linear`, which computes `x @ lora_A @ lora_B`. The latter seems more efficient in computation and should also save more memory during training.
To my understanding, during training with `MergedLinear`, PyTorch will additionally save the activation `lora_A @ lora_B`, which has the same shape as the original weight, while the `Linear` implementation only saves an activation of shape `[batch_size, seqlen, r]`, which should be far smaller than the weight?
For computation, `Linear` only needs about `2 * B * seqlen * d * r` FLOPs, while `MergedLinear` requires `d * r * d + B * seqlen * d * d` (where `d` is the hidden size and `B` is the batch size).
Is there any reason to do this?
I agree that it's inefficient. If possible, one should replace individual layers.
MergedLinear provides a way to apply LoRA to just q_proj and v_proj when q, k, and v are combined into a single Linear layer, as in the GPT-2 codebase. Grouping lora_A and lora_B first and using a grouped convolution made it easier to code up, but it is indeed inefficient.
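Roughly, the use case looks like this (a simplified sketch, not the actual MergedLinear code), assuming a fused qkv `nn.Linear` as in GPT-2's attention:

```python
import torch
import torch.nn as nn

# Fused qkv projection; LoRA is applied only to the q and v slices.
d, r, B, S = 768, 8, 2, 16
qkv = nn.Linear(d, 3 * d, bias=False)
qkv.weight.requires_grad_(False)        # pretrained weight stays frozen

enable_lora = [True, False, True]       # adapt q and v, leave k untouched
# one (d, r) / (r, d) pair per enabled slice; in a real module these would
# be registered as parameters of the layer
lora_A = {i: nn.Parameter(torch.randn(d, r) * 0.02)
          for i, on in enumerate(enable_lora) if on}
lora_B = {i: nn.Parameter(torch.zeros(r, d))   # B starts at zero, as in LoRA
          for i, on in enumerate(enable_lora) if on}

def forward(x):
    out = qkv(x)                        # (B, S, 3*d)
    for i in lora_A:                    # add the low-rank update per slice
        out[..., i * d:(i + 1) * d] += (x @ lora_A[i]) @ lora_B[i]
    return out

x = torch.randn(B, S, d)
print(forward(x).shape)                 # torch.Size([2, 16, 2304])
```

Writing it per slice like this keeps the intermediate activations small; the grouped-convolution formulation in MergedLinear trades that efficiency for simpler code.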