This codebase reimplements LoRA: Low-Rank Adaptation of Large Language Models (ICLR 2022) and is built on top of `loralib`.
The implementations of `loratorch` and `loralib` are very different. We take `nn.Linear` as an example.
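- For `loralib`,
  $h = x W_0^\top + \frac{\alpha}{r} x(BA)^\top,$
  where $x \in \mathbb{R}^{k \times n}$ is the input matrix, $W_0 \in \mathbb{R}^{m \times n}$ is the pre-trained weight matrix, $r$ is the predefined LoRA rank, $B \in \mathbb{R}^{m \times r}$ and $A \in \mathbb{R}^{r \times n}$ are the LoRA matrices, and $\alpha$ is a hyper-parameter.
- For `loratorch`,
  $h = x (W_0 + \frac{\alpha}{r} BA)^\top.$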
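To make the two formulas concrete, here is a small numerical check in plain PyTorch (it uses neither library, and the tensor sizes are arbitrary choices for illustration). It evaluates both expressions for random $x$, $W_0$, $B$, $A$ and confirms that they agree, then shows that the same split breaks for a hypothetical layer $L(x, W) = \tanh(xW^\top)$ that is not linear in its weight.

```python
import torch

torch.manual_seed(0)
k, n, m, r, alpha = 4, 6, 5, 2, 4.0  # batch, in features, out features, rank, alpha
scale = alpha / r

x = torch.randn(k, n, dtype=torch.float64)   # input x, shape (k, n)
W0 = torch.randn(m, n, dtype=torch.float64)  # pre-trained weight W_0, shape (m, n)
B = torch.randn(m, r, dtype=torch.float64)   # LoRA matrix B, shape (m, r)
A = torch.randn(r, n, dtype=torch.float64)   # LoRA matrix A, shape (r, n)

# loralib-style: two separate products, then add the results
h_split = x @ W0.T + scale * x @ (B @ A).T
# loratorch-style: merge the weights first, then a single product
h_merged = x @ (W0 + scale * B @ A).T
print(torch.allclose(h_split, h_merged))  # True


def L(x: torch.Tensor, W: torch.Tensor) -> torch.Tensor:
    """Hypothetical layer that is NOT linear in its weight W."""
    return torch.tanh(x @ W.T)


delta = scale * B @ A
# Splitting the computation no longer matches merging the weights
print(torch.allclose(L(x, W0) + L(x, delta), L(x, W0 + delta)))  # False
```

The first check holds for any input because matrix multiplication distributes over addition; the second fails because $\tanh$ does not.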
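`loralib` computes $xW_0^\top$ and $x(BA)^\top$ separately and then adds the results, whereas `loratorch` first merges the pre-trained weight $W_0$ with its LoRA weight $BA$ and then computes the output by simply calling `nn.Linear.forward()`. For linear layers the two are equivalent, since $xW_0^\top + \frac{\alpha}{r}x(BA)^\top = x(W_0 + \frac{\alpha}{r}BA)^\top$, so there is no difference between `loralib` and `loratorch` there. For some non-linear or complex layers, however, it is not clear whether a layer $L$ satisfies $L(x, W_0) + L(x, \frac{\alpha}{r}BA) = L(x, W_0 + \frac{\alpha}{r}BA)$, which makes it hard to extend LoRA to such layers with `loralib`. In contrast, merging the weights first, as `loratorch` does, is more general and extensible: you just call `merge_lora_param()` in `loratorch` to merge the weights and then call `forward()` of the original layer to compute the results. With the help of `loratorch`, you can easily apply LoRA to any type of layer in `torch.nn`.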
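A minimal sketch of the merge-then-forward idea for an arbitrary layer, assuming PyTorch 2.x (for `torch.func.functional_call`). `MergeThenForward` is a hypothetical name and is not how `loratorch` implements it; it only assumes that the wrapped layer keeps its weight in an attribute called `weight`, and it factors a 2-D view of that weight into $BA$.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call


class MergeThenForward(nn.Module):
    """Hypothetical wrapper (not part of loratorch): adds LoRA to `base.weight`
    by merging W0 + (alpha / r) * BA, then running the layer's own forward()."""

    def __init__(self, base: nn.Module, r: int = 4, lora_alpha: float = 8.0):
        super().__init__()
        self.base = base
        w0 = base.weight
        rows, cols = w0.shape[0], w0[0].numel()
        # LoRA factors on a 2-D (rows, cols) view of the weight
        self.lora_A = nn.Parameter(torch.randn(r, cols) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(rows, r))  # zero init: no change at start
        self.scaling = lora_alpha / r
        w0.requires_grad_(False)  # freeze the pre-trained weight

    def merged_weight(self) -> torch.Tensor:
        w0 = self.base.weight
        return w0 + self.scaling * (self.lora_B @ self.lora_A).view_as(w0)

    def forward(self, *args, **kwargs):
        # Swap in the merged weight for this call only; bias and other attributes stay as-is
        return functional_call(self.base, {"weight": self.merged_weight()}, args, kwargs)


conv = nn.Conv2d(3, 8, kernel_size=3)
lora_conv = MergeThenForward(conv, r=4, lora_alpha=8)
x = torch.randn(2, 3, 16, 16)
print(lora_conv(x).shape)  # torch.Size([2, 8, 14, 14])
```

Because `lora_B` starts at zero, the wrapped layer initially reproduces the pre-trained output exactly. The same wrapper also applies to `nn.Linear`, `nn.Conv1d` or `nn.Embedding`, since only the weight is replaced before the layer's own `forward()` runs; the layer never needs to satisfy the additive identity above.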
| Layer | `loralib` | `loratorch` | Example |
|---|:---:|:---:|---|
| `nn.Linear` | ✓ | ✓ | linear.ipynb |
| `nn.Embedding` | ✓ | ✓ | embedding.ipynb |
| `nn.Conv1d` | ✓ | ✓ | |
| `nn.Conv2d` | ✓ | ✓ | |
| `nn.Conv3d` | ✓ | ✓ | |
| `nn.MultiheadAttention` | ✘ | ✓ | |
| `MergedLinear` | ✓ (Error) | ✓ | mergedlinear.ipynb |
| Other layers | hard to extend | easy to extend | |
We compare the results of `loralib` and `loratorch` in the examples to demonstrate the correctness of the `loratorch` implementation.
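In the spirit of those comparisons, the sketch below checks the `nn.Embedding` case in plain PyTorch. It assumes a `loralib`-style split in which $A \in \mathbb{R}^{r \times N}$ and $B \in \mathbb{R}^{d \times r}$ (for $N$ embeddings of dimension $d$): look up $W_0$ and $A^\top$ separately and add the low-rank term, versus a single lookup in the merged table $W_0 + \frac{\alpha}{r}(BA)^\top$. The shapes are an assumption for illustration, not taken from either library's source.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_embeddings, dim, r, alpha = 10, 6, 2, 4.0
scale = alpha / r

W0 = torch.randn(num_embeddings, dim, dtype=torch.float64)  # pre-trained table, (N, d)
A = torch.randn(r, num_embeddings, dtype=torch.float64)      # LoRA A, (r, N) (assumed shape)
B = torch.randn(dim, r, dtype=torch.float64)                 # LoRA B, (d, r) (assumed shape)
ids = torch.tensor([[1, 4, 4], [0, 9, 2]])                   # token ids, shape (2, 3)

# Split computation: look up W0 and A^T separately, then add the low-rank term
h_split = F.embedding(ids, W0) + scale * F.embedding(ids, A.T) @ B.T
# Merged computation: build W0 + scale * (BA)^T once, then a single lookup
h_merged = F.embedding(ids, W0 + scale * (B @ A).T)

print(torch.allclose(h_split, h_merged))  # True
```

An embedding lookup is just row selection, which is linear in the table, so both routes give the same result; this is why `nn.Embedding` is supported by both libraries.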
The usage of `loratorch` is the same as that of `loralib`.
- Install `loratorch`.

  ```shell
  pip install git+https://github.com/Baijiong-Lin/LoRA-Torch

  # Alternatively, for developers:
  # git clone https://github.com/Baijiong-Lin/LoRA-Torch
  # cd LoRA-Torch
  # pip install -e .
  ```
- Replace the layers where you would like to use LoRA by using `loratorch`.

  ```python
  # ===== Before =====
  # layer = nn.Linear(in_features, out_features)

  # ===== After ======
  import loratorch as lora

  # Add a pair of low-rank adaptation matrices with rank r=16 and alpha=32
  layer = lora.Linear(in_features, out_features, r=16, lora_alpha=32)
  ```
- Mark only the LoRA parameters as trainable before the training loop.

  ```python
  import torch
  import loratorch as lora

  # Model, dataloader and forward_fun are placeholders for your own code
  model = Model()
  # (!!!) This sets requires_grad to False for all parameters without the string "lora_" in their names
  lora.mark_only_lora_as_trainable(model)
  optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

  # Training loop
  for batch in dataloader:
      model.train()
      # forward process
      loss = forward_fun(model, batch)
      # backward process
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
      # (!!!) Re-register the model parameters to ensure they are in model.state_dict() and model.parameters()
      # (!!!) Without this line, performance is not affected, but some weights will be missing from model.state_dict() and model.parameters()
      lora.register_model_param_after_backward(model)
  ```
- Save the LoRA model (only the LoRA matrices are saved).

  ```python
  # ===== Before =====
  # torch.save(model.state_dict(), checkpoint_path)

  # ===== After =====
  torch.save(lora.lora_state_dict(model), checkpoint_path)
  ```
- Load the LoRA model (the pre-trained model must be loaded first).

  ```python
  # Load the pre-trained checkpoint first
  model.load_state_dict(torch.load('ckpt_pretrained.pt'), strict=False)
  # Then load the LoRA checkpoint
  model.load_state_dict(torch.load('ckpt_lora.pt'), strict=False)
  ```
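To see what each of these steps does, they can be mimicked end to end in plain PyTorch. The helpers below (`LoRALinearSketch`, `mark_only_lora_as_trainable_sketch`, `lora_state_dict_sketch`, `build`) are hypothetical stand-ins written for illustration: they follow the behavior described in the comments above (freeze every parameter without `"lora_"` in its name, save only the LoRA entries, load two checkpoints with `strict=False`), not `loratorch`'s actual source. In-memory `io.BytesIO` buffers replace the checkpoint files.

```python
import io

import torch
import torch.nn as nn
import torch.nn.functional as F


class LoRALinearSketch(nn.Linear):
    """Hypothetical LoRA linear layer for illustration (not loratorch's code)."""

    def __init__(self, in_features: int, out_features: int, r: int = 16, lora_alpha: float = 32):
        super().__init__(in_features, out_features)
        self.lora_A = nn.Parameter(torch.randn(r, in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))  # zero init: starts as W0
        self.scaling = lora_alpha / r  # 32 / 16 = 2.0 with the defaults

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Merge W0 + (alpha / r) * BA first, then run the plain linear computation
        return F.linear(x, self.weight + self.scaling * self.lora_B @ self.lora_A, self.bias)


def mark_only_lora_as_trainable_sketch(model: nn.Module) -> None:
    # Freeze every parameter whose name does not contain "lora_"
    for name, param in model.named_parameters():
        param.requires_grad = "lora_" in name


def lora_state_dict_sketch(model: nn.Module) -> dict:
    # Keep only the LoRA entries of the state dict
    return {k: v for k, v in model.state_dict().items() if "lora_" in k}


def build() -> nn.Module:
    return nn.Sequential(LoRALinearSketch(8, 4), nn.ReLU(), LoRALinearSketch(4, 2))


model = build()
mark_only_lora_as_trainable_sketch(model)
print([n for n, p in model.named_parameters() if p.requires_grad])
# ['0.lora_A', '0.lora_B', '2.lora_A', '2.lora_B']

# In-memory stand-ins for 'ckpt_pretrained.pt' and 'ckpt_lora.pt'
pretrained_buf, lora_buf = io.BytesIO(), io.BytesIO()
torch.save({k: v for k, v in model.state_dict().items() if "lora_" not in k}, pretrained_buf)
torch.save(lora_state_dict_sketch(model), lora_buf)
pretrained_buf.seek(0)
lora_buf.seek(0)

# Two-stage loading into a fresh model, pre-trained weights first
fresh = build()
res1 = fresh.load_state_dict(torch.load(pretrained_buf), strict=False)
res2 = fresh.load_state_dict(torch.load(lora_buf), strict=False)
print(res1.missing_keys)  # the LoRA keys, absent from the pre-trained checkpoint
print(res2.missing_keys)  # the weight/bias keys, absent from the LoRA checkpoint
```

Each `load_state_dict(..., strict=False)` call lists the keys it did not provide in `missing_keys` instead of raising an error, so loading the pre-trained checkpoint and then the LoRA checkpoint sets every parameter from one of the two files.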
`loratorch` is developed and maintained by Baijiong Lin.
If you have any questions or suggestions, please feel free to contact us by raising an issue or sending an email to bj.lin.email@gmail.com.
`loratorch` is heavily based on `loralib`. We thank its authors for their wonderful open-source codebase.
If you find `loratorch` useful for your research or development, please cite the following:

```bibtex
@inproceedings{hu2022lora,
  title={Lo{RA}: Low-Rank Adaptation of Large Language Models},
  author={Edward J Hu and Yelong Shen and Phillip Wallis and Zeyuan Allen-Zhu and Yuanzhi Li and Shean Wang and Lu Wang and Weizhu Chen},
  booktitle={International Conference on Learning Representations},
  year={2022},
}

@software{lin2023loratorch,
  author = {Baijiong Lin},
  title = {{LoRA-Torch}: {PyTorch} Reimplementation of {LoRA}},
  url = {https://github.com/Baijiong-Lin/LoRA-Torch},
  year = {2023}
}
```
`loratorch` is released under the MIT license.