/chai

CHAI is a library for dynamic pruning of attention heads for efficient LLM inference.

Primary LanguagePythonGNU General Public License v3.0GPL-3.0

CHAI

CHAI is an inference time pruning method which clusters attention heads that have similar output together with dynamic determination of clusters. Details can be found in our paper: CHAI: Clustered Head Attention for Efficient LLM Inference (Agarwal et al, 2024).

This repository as intended as a reference implementation for implementing CHAI. To run CHAI please download [LLaMa]((https://ai.facebook.com/blog/large-language-model-llama-meta-ai/) (arXiv)) models and run inference.

You can download the checkpoints and tokenizer, fill this google form. The repository follows the same code base as LLama-v1.

Setup

Apply CHAI patch to the Llama model:

cd llama
git apply ../llama.patch
cd ..

In a conda env with pytorch / cuda available, run:

pip install -r requirements.txt

Then in this repository:

pip install -e .

Inference

torchrun --nproc_per_node 1 example_chai.py --ckpt_dir <ModelFolder> --tokenizer_path <tokenizer.model>

Implementation Details

CHAI is implemented primarily in the Forward Function, in model.py for attention.

 

Citation

CHAI is accepted at ICML'24. Please cite as:

@inproceedings{
agarwal2024chai,
title={{CHAI}: Clustered Head Attention for Efficient {LLM} Inference},
author={Saurabh Agarwal and Bilge Acun and Basil Homer and Mostafa Elhoushi and Yejin Lee and Shivaram Venkataraman and Dimitris Papailiopoulos and Carole-Jean Wu},
booktitle={Forty-first International Conference on Machine Learning},
year={2024},
url={https://openreview.net/forum?id=xcDRx8vzCa}
}

License

See the LICENSE file.