microsoft/LoRA

matmul ordering in MergedLinear

zhiqi-0 opened this issue · 1 comment

Hi Team, thanks for the great work!

I'm trying to understand the benefits of LoRA, but I got confused while reading the forward implementation of MergedLinear.

I'm curious why the matmul ordering in MergedLinear, x @ (lora_A @ lora_B), is different from the one in Linear, (x @ lora_A) @ lora_B. The latter seems more computationally efficient and should also save memory during training.

To my understanding, with the MergedLinear ordering PyTorch will additionally save an activation (lora_A @ lora_B) with the same shape as the original weight during training, while the Linear ordering only saves an activation of shape [batch_size, seqlen, r], which should be far smaller than the weight?
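To make the comparison concrete, here is a small standalone sketch (my own illustration, not the repository's actual code; the shapes and names are made up) showing that the two orderings produce the same output while keeping different intermediates around for backward:

```python
import torch

# Illustrative shapes (my own picks, not from the repo):
# batch size, sequence length, hidden size, LoRA rank
B, S, d, r = 8, 512, 768, 8
x = torch.randn(B, S, d)
lora_A = torch.randn(d, r) * 0.01
lora_B = torch.randn(r, d) * 0.01

# Linear-style ordering: contract x with the low-rank factors one at a time.
# For backward, autograd only needs the [B, S, r] intermediate (x @ lora_A).
out_linear = (x @ lora_A) @ lora_B

# MergedLinear-style ordering: materialize the [d, d] weight delta first.
# For backward, autograd additionally keeps this full d x d product around.
delta_w = lora_A @ lora_B
out_merged = x @ delta_w

# Same result either way; only the saved intermediates (and FLOPs) differ.
assert torch.allclose(out_linear, out_merged, atol=1e-5)
```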

For computation, Linear only needs 2 x B x Seqlen x d x r FLOPs, while MergedLinear requires d x r x d + B x Seqlen x d x d (where d is the hidden size, B the batch size, and Seqlen the sequence length).
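Plugging in some assumed numbers to show the gap (GPT-2-like d = 768, r = 8, B = 8, Seqlen = 512; all of these are just my own picks):

```python
# Rough multiply-accumulate counts for the two orderings (illustrative numbers).
B, S, d, r = 8, 512, 768, 8

linear_macs = 2 * B * S * d * r          # (x @ lora_A), then (. @ lora_B)
merged_macs = d * r * d + B * S * d * d  # (lora_A @ lora_B), then (x @ .)

print(f"Linear ordering:       {linear_macs:,}")   # ~50M
print(f"MergedLinear ordering: {merged_macs:,}")   # ~2.4B
print(f"ratio: {merged_macs / linear_macs:.1f}x")
```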

Is there any reason to do this?

I agree that it's inefficient. If possible, one should replace individual layers.

MergedLinear provides a way to apply LoRA to just q_proj and v_proj when q, k, and v are combined into a single Linear layer, as in the GPT-2 codebase. Grouping lora_A and lora_B first and using a grouped convolution made it easier to code up, but it is indeed inefficient.
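Roughly, the grouped-convolution trick looks like this. This is a simplified sketch, not the exact library code; dropout, scaling, and the weight layouts are omitted, and the shapes are just for illustration:

```python
import torch
import torch.nn.functional as F

# Apply LoRA to q and v only, inside a fused qkv projection of width 3*d.
B, S, d, r = 2, 16, 64, 4
enable_lora = [True, False, True]          # q and v get LoRA, k does not
n_enabled = sum(enable_lora)               # 2

x = torch.randn(B, S, d)
lora_A = torch.randn(n_enabled * r, d) * 0.01   # one r x d block per enabled projection
lora_B = torch.randn(n_enabled * d, r) * 0.01   # one d x r block per enabled projection

# Down-projection for both enabled groups at once: [B, S, 2r]
after_A = F.linear(x, lora_A)

# A grouped 1x1 conv treats the 2r channels as two independent groups of r,
# each mapped up by its own d x r block of lora_B: result is [B, S, 2d].
after_B = F.conv1d(
    after_A.transpose(-2, -1),             # [B, 2r, S]
    lora_B.unsqueeze(-1),                  # [2d, r, 1] grouped weight
    groups=n_enabled,
).transpose(-2, -1)                        # [B, S, 2d]

# Scatter the two d-sized updates into the 3d-wide qkv output, leaving k untouched.
update = torch.zeros(B, S, 3 * d)
lora_ind = torch.tensor(enable_lora).repeat_interleave(d)
update[..., lora_ind] = after_B
print(update.shape)                        # torch.Size([2, 16, 192])
```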