/Quest

[ICML 2024] Quest: Query-Aware Sparsity for Efficient Long-Context LLM Inference

Primary LanguageCuda

Quest: Query-Aware Sparsity for Efficient Long-Context LLM Inference

[paper] [poster(draft)]

TL;DR

Quest is an efficient long-context LLM inference framework that leverages query-aware sparsity in KV cache to reduce memory movement during attention and thus boost throughput.

Abstract

As the demand for long-context large language models (LLMs) increases, models with context windows of up to 128k or 1M tokens are becoming increasingly prevalent. However, long-context LLM inference is challenging since the inference speed decreases significantly as the sequence length grows. This slowdown is primarily caused by loading a large KV cache during self-attention. Previous works have shown that a small portion of critical tokens will dominate the attention outcomes. However, we observe the criticality of a token highly depends on the query.

To this end, we propose Quest, a query-aware token criticality estimation algorithm. Quest keeps track of the minimal and maximal Key values in KV cache pages and estimates the criticality of a given page using Query vectors. By only loading the Top-K critical KV cache pages for attention, Quest significantly speeds up self-attention without sacrificing accuracy. We show that Quest can achieve up to 7.03× self-attention speedup, which reduces inference latency by 2.23× while performing well on tasks with long dependencies with negligible accuracy loss.

Installation

  1. Clone this repo (also clone submodules)
git clone --recurse-submodules https://github.com/mit-han-lab/quest
cd quest
  1. Install dependency libraries
conda create -yn quest python=3.10
conda activate quest
pip install -e . && pip install flash-attn==2.3.0 --no-build-isolation

# Install CMake (with version >= 3.26.4)
conda install cmake

# build libraft
cd kernels/3rdparty/raft
./build.sh libraft
  1. Compile kernel benchmarks (Optional)
cd kernels
mkdir build && cd build
cmake ..
make -j
  1. Build end-to-end operators with PyBind
# This will automatically build and link the operators
cd quest/ops
bash setup.sh

Accuracy Evaluation

Our evaluations are based on LongChat-7B-v1.5-32K and Yarn-Llama2-7B-128K models, which are capable of handling long-context text generations. We evaluate both passkey retrieval and LongBench benchmarks. We provide several scripts to reproduce our results in the paper:

To get the Passkey Retrieval results, please modify and execute:

bash scripts/passkey.sh

To reproduce the LongBench results, please modify and execute:

bash scripts/longbench.sh

To evaluate the perplexity result of PG-19, please execute:

bash scripts/ppl_eval.sh

Efficiency Evaluation

Kernels and end-to-end effiency are evaluated on NVIDIA Ada6000 and RTX4090 GPUs with CUDA version of 12.4. We provide several scripts to reproduce our results in the paper:

Kernel-level Efficiency

We also release the unit tests and benchmarks used for kernel implementations. Correctness of kernel is verified by unit tests in kernels/src/test, while performance is evaluated by NVBench in kernels/src/bench. We also test the correctness of PyBind operators in quest/tests with PyTorch results via PyTest.

To test the correctness of kernels, please execute:

cd kernels/build
./test_batch_decode # or any other operator

Or utilize PyTest:

cd quest/tests
PYTHONPATH=$PYTHONPATH:../../ pytest

To reproduce the kernel performance shown in paper, please execute:

cd kernels/build
./bench_batch_decode -a seqlen=4096 -a page_budget=[64,512]
# or any other operator

With sample output:

End-to-end Efficiency

Quest can achieve up to 2.23× end-to-end speedup while performing well on tasks with long dependencies with negligible accuracy loss:

We incorporate all implemented operators into a full pipeline to evaluate the end-to-end efficiency in text generations. Based on the Huggingface Transformers, we enable a KV-Cache manager which supports query-aware sparsity as shown in quest/models/QuestAttention.py.

To reproduce the end-to-end efficiency results in Figure.10, please execute:

bash scripts/bench_efficiency_e2e.sh

For the qualitative analysis of baselines, we use FlashInfer kernel to estimate the performance of H2O and TOVA. To reproduce the results in Figure.11, please execute:

bash scripts/bench_kernels.sh

Examples

We provide several examples to demonstrate the usage of Quest. These examples are implemented with the end-to-end integration of Quest operators, and can be executed with the following commands (please make sure you have setup all the operators):

python3 scripts/example_textgen.py

With example output of long-context summarization under LongChat-7B-v1.5-32K model:

You can also try scripts/example_demo.py to test the performance of Quest on your own text generation tasks. We provide a simple interface to load the model and generate text with Quest operators. The above demo is an example with 32K input on FP16 LongChat-7B-v1.5-32K. Quest with 2048 token budget achieves 1.7x speedup compared to full cache FlashInfer version.

TODOs

  • Support GQA models

Reference

If you find this project is helpful to your research, please consider to cite our paper:

@misc{tang2024quest,
      title={Quest: Query-Aware Sparsity for Efficient Long-Context LLM Inference}, 
      author={Jiaming Tang and Yilong Zhao and Kan Zhu and Guangxuan Xiao and Baris Kasikci and Song Han},
      year={2024},
      eprint={2406.10774},
      archivePrefix={arXiv},
      primaryClass={id='cs.CL' full_name='Computation and Language' is_active=True alt_name='cmp-lg' in_archive='cs' is_general=False description='Covers natural language processing. Roughly includes material in ACM Subject Class I.2.7. Note that work on artificial languages (programming languages, logics, formal systems) that does not explicitly address natural-language issues broadly construed (natural-language processing, computational linguistics, speech, text retrieval, etc.) is not appropriate for this area.'}
}

Related Projects

This codebase utilizes lm_eval to evaluate perplexity and zero-shot accuracy. It also adapts code snippets from H2O, StreamingLLM and Punica. Our kernels are implemented based on FlashInfer (a performant and extensible kernel library for LLM serving) and tested by NVBench. Thanks for the great works from our community!

H2O: Heavy-Hitter Oracle for Efficient Generative Inference of Large Language Models

TOVA: Transformers are Multi-State RNNs

StreamingLLM: Efficient Streaming Language Models with Attention Sinks

AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration