LongHeads: Multi-Head Attention is Secretly a Long Context Processor [paper]
- [2024/03/25] We now offload the KV cache to CPU during inference. This significantly reduces memory usage and enables LLaMA-2-7b to achieve 100% accuracy on the passkey retrieval task with 128k context!
- [2024/03/25] We release the evaluation code for the passkey retrieval task.
- [2024/03/24] We release the example code for LongHeads.
LongHeads is a training-free framework that extends the context window of large language models (LLMs) to more than 32× their original pre-training length. LongHeads runs in linear time, fits seamlessly with LLMs that use relative positional encoding, and can be combined with popular extrapolation methods such as Position Interpolation (PI) and NTK-Dynamic RoPE.
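At a high level, each attention head in LongHeads attends only to a query-dependent selection of context chunks, so no head ever covers more positions than it saw during pre-training. Below is a minimal sketch of that chunk-selection idea (illustrative only: the function, the mean-pooled chunk representation, and the shapes are assumptions here, not the repo's implementation):

```python
import torch

def select_chunks(query, keys, window_size=256, atten_length=4096):
    """For one head: pick the chunks whose keys best match the query,
    so the selected tokens fit inside the pre-trained attention window."""
    seq_len, head_dim = keys.shape
    num_chunks = seq_len // window_size
    chunks = keys[: num_chunks * window_size].reshape(num_chunks, window_size, head_dim)
    # Represent each chunk by a single vector (mean of its keys -- an assumption).
    chunk_repr = chunks.mean(dim=1)              # (num_chunks, head_dim)
    scores = chunk_repr @ query                  # similarity of each chunk to the query
    k = min(atten_length // window_size, num_chunks)  # chunks that fit in the window
    return torch.topk(scores, k=k).indices      # indices of chunks this head attends to
```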
pip install -r requirements.txt
# We use flash-attn==2.3.6; FlashAttention >= 2.3.6 is required
pip install flash-attn --no-build-isolation
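To sanity-check the install (recent flash-attn releases expose a `__version__` attribute):

```python
import flash_attn
print(flash_attn.__version__)  # should print 2.3.6 or later
```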
# load the LongHeads model
import torch
from modeling_longheads import LlamaForCausalLM

longheads_config = {
    # chunk size used by LongHeads
    'window_size': 256,
    # attention window length of LongHeads (atten_length must not exceed the model's pre-trained length)
    'atten_length': 4096,
    # during the encoding phase, contexts longer than this are encoded in a streaming fashion with the chunk selection strategy
    'begin_selective_length': 4096,
    # whether to offload the KV cache to CPU memory; if True, LongHeads can generate with 128k+ context lengths
    'cpu_offload': False,
    # whether to use batch encoding to accelerate the encoding phase; if True, more memory is needed
    'batch_encoding': False,
    # batch size used for batch encoding
    'encoding_batch_size': 128,
}
model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16, **longheads_config)
python example.py
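If you prefer to drive the model directly instead of running example.py, here is a minimal generation sketch (standard Hugging Face tokenizer/generate usage; the prompt and variable names are placeholders, not code from this repo):

```python
import torch
from transformers import AutoTokenizer
from modeling_longheads import LlamaForCausalLM

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
# longheads_config as defined in the snippet above
model = LlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16, **longheads_config
).cuda()

long_document = "..."  # placeholder: your long input text
inputs = tokenizer("Summarize the following document:\n" + long_document,
                   return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=256)
# Decode only the newly generated tokens, skipping the prompt.
print(tokenizer.decode(output[0][inputs["input_ids"].shape[1]:],
                       skip_special_tokens=True))
```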
We extend LLaMA-2-7b to a 128k context window with LongHeads, without any additional training, and achieve 100% accuracy on the passkey retrieval task at 128k context! With the KV cache offloaded to CPU, peak GPU memory usage is 26.51 GB for 64k-context inference and 44.48 GB for 128k-context inference.
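To take a peak-memory reading like the numbers above, torch's allocator statistics can be used, for example (a sketch; assumes the model and a long prompt from the snippets above):

```python
import torch

# Set longheads_config['cpu_offload'] = True before loading
# to match the offloaded numbers reported here.
torch.cuda.reset_peak_memory_stats()
output = model.generate(**inputs, max_new_tokens=32)  # inputs: a 64k- or 128k-token prompt
peak_gb = torch.cuda.max_memory_allocated() / 1024 ** 3
print(f"peak GPU memory: {peak_gb:.2f} GB")
```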
bash passkey_retrieval/passkey_retrieval_script.sh
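For context, passkey retrieval hides a short random code inside long filler text and asks the model to repeat it. A sketch of one common prompt construction (the exact wording and lengths used by the evaluation script may differ):

```python
import random

def build_passkey_prompt(n_filler=4000):
    """Build a long prompt with a hidden passkey; returns (prompt, passkey)."""
    passkey = random.randint(10000, 99999)
    filler = ("The grass is green. The sky is blue. The sun is yellow. "
              "Here we go. There and back again. ") * n_filler
    half = len(filler) // 2
    prompt = ("There is important info hidden inside a lot of irrelevant text. "
              "Find it and memorize it.\n"
              f"{filler[:half]}"
              f"The pass key is {passkey}. Remember it. {passkey} is the pass key.\n"
              f"{filler[half:]}\n"
              "What is the pass key?")
    return prompt, passkey
```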
We will release the code in the following order, please stay tuned!
- Release core code of LongHeads.
- Release example code for usage.
- Release passkey retrieval evaluation code.
- Release code of LongHeads with efficient implementation.
If you find LongHeads useful or relevant to your project and research, please kindly cite our paper:
@misc{lu2024longheads,
    title={LongHeads: Multi-Head Attention is Secretly a Long Context Processor},
    author={Yi Lu and Xin Zhou and Wei He and Jun Zhao and Tao Ji and Tao Gui and Qi Zhang and Xuanjing Huang},
    year={2024},
    eprint={2402.10685},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}