This is the code base for the project Layer-Condensed KV Cache, a new variant of transformer decoders in which the queries of all layers are paired with the keys and values of just the top layer. It reduces memory and computation costs as well as the number of parameters, and significantly improves inference throughput while achieving comparable or better task performance. The paper "Layer-Condensed KV Cache for Efficient Inference of Large Language Models" was accepted to the ACL 2024 main conference.
This work is inspired by Probabilistic Transformer, in which we view the stacked layers of a transformer as an iterative process of refining token representations.
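For intuition, here is a minimal, purely conceptual sketch (not the repo's implementation) of the decoding-time idea: at every layer, the new token's queries attend over a single KV cache built from the top-layer hidden states of previous tokens, rather than one cache per layer. Dimensions, projections, and the omission of MLPs and attention heads are illustrative simplifications.
import torch
import torch.nn.functional as F

num_layers, d = 4, 64
W_q = [torch.randn(d, d) for _ in range(num_layers)]   # per-layer query projections
W_k, W_v = torch.randn(d, d), torch.randn(d, d)        # a single KV projection pair
top_cache = []                                          # one shared cache of top-layer states

def decode_step(x):
    """x: (d,) embedding of the new token; returns its top-layer hidden state."""
    h = x
    if top_cache:
        past = torch.stack(top_cache)                   # (seq, d) top-layer states of past tokens
        K, V = past @ W_k.T, past @ W_v.T               # keys/values come from the top layer only
    for l in range(num_layers):
        q = W_q[l] @ h                                  # queries still come from every layer
        if top_cache:
            attn = F.softmax(K @ q / d ** 0.5, dim=0)   # attend to past tokens; the current token
            h = h + attn @ V                            # does not attend to itself (residual update)
    top_cache.append(h)                                 # cache only the top layer's output
    return h

for token_embedding in torch.randn(5, d):               # decode 5 dummy tokens
    decode_step(token_embedding)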
You may install the dependencies with the following commands:
conda install pytorch pytorch-cuda=12.1 -c pytorch -c nvidia
pip install xformers --index-url https://download.pytorch.org/whl/cu121
pip install -r requirements.txt
where the CUDA version is set to 12.1. For other CUDA versions, please refer to the installation instructions of PyTorch and xFormers. See Troubleshooting for more details.
Our implementation is based on HuggingFace transformers, where we register a new model opt-llama that supports the Layer-Condensed KV Cache.
import models  # register the opt-llama model with transformers
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
config = AutoConfig.from_pretrained("configs/tinyllama_opt.json")
model = AutoModelForCausalLM.from_config(config)
and now you have a randomly initialized model with the Layer-Condensed KV Cache.
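As a quick sanity check (the output will be gibberish since the weights are random), you can run generation directly; the prompt and decoding settings below are just illustrative:
inputs = tokenizer("The Layer-Condensed KV Cache", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20, do_sample=False)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))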
We follow all the acceleration tricks in tinyllama, with minimal modifications to the HuggingFace transformers code, so the model can be trained with the HuggingFace Trainer at a speed comparable to the original tinyllama code.
To enable the optimizations, set the following environment variables before running the training script (a Python alternative is sketched below the variables):
# improvement: huge
export LCKV_FLASH_ATTN=1
# improvement: significant
export LCKV_FUSED_RMSNORM=1
# improvement: none
export LCKV_FUSED_CROSSENTROPY=1
# improvement: none
export LCKV_FUSED_ROTARY=1
# improvement: slight
export LCKV_FUSED_SWIGLU=1
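If you prefer not to export the variables in your shell, a hedged alternative is to set them from Python before importing models; whether the flags are read at import time or at model-construction time is an implementation detail, so setting them as early as possible is the safe choice:
import os
os.environ["LCKV_FLASH_ATTN"] = "1"     # largest speedup
os.environ["LCKV_FUSED_RMSNORM"] = "1"

import models  # register opt-llama only after the flags are set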
We've done this for you in the provided training scripts. You may also refer to our tinyllama repo for a pure PyTorch implementation of the Llama model.
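For reference, here is a minimal sketch of what training with the stock HuggingFace Trainer could look like; the dataset slice, sequence length, and hyperparameters are illustrative only, and the provided run_clm.sh remains the supported entry point:
import models  # register opt-llama
from datasets import load_dataset
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
                          DataCollatorForLanguageModeling, Trainer, TrainingArguments)

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained("configs/tinyllama_opt.json"))

raw = load_dataset("wikitext", "wikitext-103-raw-v1", split="train[:1%]")
train_ds = raw.map(lambda b: tokenizer(b["text"], truncation=True, max_length=512),
                   batched=True, remove_columns=["text"])

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="out", per_device_train_batch_size=8,
                           max_steps=100, bf16=True),
    train_dataset=train_ds,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
trainer.train()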
We provide some sample configuration files in the configs folder. The config settings are defined in models/configuration_llama.py; you may refer to this file for more details.
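For example, you can load a sample config and inspect the LCKV-specific fields; the field names are defined in models/configuration_llama.py, so the printout is the authoritative reference:
import models  # register opt-llama so AutoConfig can resolve the model type
from transformers import AutoConfig

config = AutoConfig.from_pretrained("configs/tinyllama_opt.json")
print(config)  # shows the LCKV-specific settings alongside the usual Llama fields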
Notice that some of the settings have different names and meanings than those in our paper. The following figure explains the correspondence:
We use the same training script as the original transformers library. You may refer to the official documentation for more details.
We provide a training script run_clm.sh for training a 50M-parameter model on the wikitext-103 dataset. You may run the script with:
bash run_clm.sh
See the script for more details. For pretraining on SlimPajama, please follow the instructions in tinyllama-zh and replace the dataset with SlimPajama.
We use the same inference script as the original transformers library. To perform inference, you may run the following command:
bash run_generation.sh
See the script for more details.
We integrate our model with StreamingLLM. To perform streaming inference, you may run the following command:
bash run_streaming.sh
See the script for more details. The code follows the official implementation with minimal modifications.
Warning: The script run_streaming.py is not supported yet.
We use the LM Evaluation Harness to evaluate the model. You may run the following command:
python test_harness.py
Change the model_args and tasks in the script to evaluate different models and datasets.
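If you prefer calling the harness from your own code, a hedged sketch could look like the following; the exact API and backend name may differ across lm-eval versions, the checkpoint path is a placeholder, and test_harness.py remains the supported entry point:
import models  # register opt-llama
from lm_eval import evaluator

results = evaluator.simple_evaluate(
    model="hf",                                      # HuggingFace causal-LM backend (name varies by version)
    model_args="pretrained=path/to/your/checkpoint", # placeholder checkpoint directory
    tasks=["hellaswag", "piqa"],
    batch_size=8,
)
print(results["results"])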
To test the latency of the model, you may run the following command:
python test_latency.py
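Under the hood, a latency measurement boils down to timing generate with proper CUDA synchronization. Here is a minimal sketch; the checkpoint path and generation lengths are placeholders, and the provided test_latency.py is the authoritative script:
import time, torch
import models  # register opt-llama
from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt = "path/to/your/checkpoint"  # placeholder: any LCKV checkpoint directory
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.bfloat16).cuda()

prompt_ids = tokenizer("Hello", return_tensors="pt").input_ids.cuda()
model.generate(prompt_ids, max_new_tokens=32)        # warm-up to exclude one-time setup cost
torch.cuda.synchronize()

start = time.perf_counter()
model.generate(prompt_ids, max_new_tokens=256)
torch.cuda.synchronize()                              # wait for all CUDA kernels to finish
print(f"latency: {time.perf_counter() - start:.2f} s")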
Troubleshooting
Behavior: Runtime error.
ImportError: /home/.../flash_attn_2_cuda.cpython-38-x86_64-linux-gnu.so: undefined symbol: _ZN2at4_ops9_pad_enum4callERKNS_6TensorEN3c108ArrayRefINS5_6SymIntEEElNS5_...
Solution:
pip uninstall flash-attn
FLASH_ATTENTION_FORCE_BUILD=TRUE pip install flash-attn
The CUDA version may affect the installation of PyTorch, xFormers, and flash-attn. Please make sure to install the correct versions of these packages (as long as their CUDA versions are consistent, the code will work). Also make sure that nvcc is installed and available in your PATH.
Our experiment environment uses CUDA 12.1, and you may install the dependencies with:
conda install pytorch==2.1.0 pytorch-cuda=12.1 -c pytorch -c nvidia
pip install xformers==0.0.22.post7 --index-url https://download.pytorch.org/whl/cu121
pip install -r requirements.txt