
When Linear Attention Meets Autoregressive Decoding: Towards More Effective and Efficient Linearized Large Language Models

License: Apache 2.0

Haoran You, Yichao Fu, Zheng Wang, Amir Yazdanbakhsh, Yingyan (Celine) Lin

Accepted by ICML 2024. More Info: [ Paper | GitHub ]


News 🔥🔥 !

  • [ ✅ New ] Jun. 11, 2024. 💥 Released our trained LLaMA-2-7B model checkpoints on Hugging Face!
  • [ ✅ New ] Jun. 11, 2024. 💥 Released Linearized-LLM's PyTorch implementation code!

Table of Contents

Brief Introduction

Basic Usage

Train Your Own Linearized-LLM

Citation & Acknowledgement

Basic Usage

The main implementation can be found in the autoregressive_wrapper.py and flash_pytorch.py files. The code is adapted from FLASH.
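For orientation, the snippet below sketches how a FLASH model and an autoregressive wrapper are typically combined in the upstream FLASH-pytorch codebase. The class names, constructor arguments, and generate signature here are assumptions and may differ from this repository's autoregressive_wrapper.py and flash_pytorch.py:

# Hypothetical sketch following the upstream FLASH-pytorch interface;
# names and arguments may differ in this repository's adapted files.
import torch
from flash_pytorch import FLASHTransformer
from autoregressive_wrapper import AutoregressiveWrapper

model = FLASHTransformer(num_tokens=20000, dim=512, depth=12, causal=True)
wrapper = AutoregressiveWrapper(model)

prompt = torch.randint(0, 20000, (1, 32))   # dummy prompt token ids
generated = wrapper.generate(prompt, 128)   # autoregressively decode 128 tokens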

Set up Environment

Please set up the environment using the following commands and ensure that CUDA is included in your PATH:

export PATH=/PATH-TO-CUDA/:$PATH
conda create -n LinearLLM python==3.10
conda activate LinearLLM
pip install -r requirements.txt
pip install flash-attn
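After installation, you can optionally run a quick sanity check to confirm that PyTorch sees your GPU and that flash-attn imports cleanly:

# Optional sanity check for the environment.
import torch
import flash_attn

print(torch.__version__, torch.version.cuda)  # PyTorch build and its CUDA version
print(torch.cuda.is_available())              # should print True on a GPU machine
print(flash_attn.__version__)                 # confirms flash-attn installed correctly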

Download Trained Models

We provide our trained model checkpoints in this Hugging Face repository. Use the bash script below to download them:

# Linearized LLaMA-2 weights
huggingface-cli download LinearizedLLM/llama-2-7b-aug-linear --local-dir llama-2-7b-aug-linear

# Medusa Head for Linearized LLaMA-2 weights
huggingface-cli download LinearizedLLM/llama-2-7b-medusa-head-aug-linear --local-dir llama-2-7b-medusa-head-aug-linear
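Once downloaded, the checkpoint can be loaded with the standard transformers API. The sketch below assumes the exported weights follow the usual Hugging Face causal-LM format; the scripts under experiments/ handle any model-specific setup:

# Minimal loading sketch; assumes a standard Hugging Face causal-LM checkpoint.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("llama-2-7b-aug-linear")
model = AutoModelForCausalLM.from_pretrained("llama-2-7b-aug-linear", device_map="auto")

inputs = tokenizer("Linear attention makes decoding", return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=32)[0]))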

Reproduce Results

To reproduce Table 8 from the paper, which demonstrates the speedup of augmented linearized LLaMA-2 with speculative decoding, use the following bash script. The code is adapted from the Medusa repository.

cd experiments
bash run_medusa.sh

To reproduce Table 4, which shows latency and memory improvements with our augmented linear attention, use the following bash script. Note that we use transformers==4.37.0.

pip install transformers==4.37.0
cd experiments
bash run_benchmark.sh
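The benchmark reports end-to-end decoding latency and peak GPU memory. The helper below illustrates the general measurement pattern in plain PyTorch; it is a generic sketch, not the repository's benchmark code:

# Generic latency / peak-memory measurement pattern (illustrative only).
import time
import torch

def benchmark_generate(model, inputs, max_new_tokens=128):
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    start = time.time()
    model.generate(**inputs, max_new_tokens=max_new_tokens)
    torch.cuda.synchronize()
    latency_s = time.time() - start
    peak_mem_gb = torch.cuda.max_memory_allocated() / 1e9
    return latency_s, peak_mem_gb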

Train Your Own Linearized-LLM

FLASH Training from Scratch

Use the bash script below to train a 24-layer FLASH model from scratch:

bash runall-125k.sh
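For reference, training such a causal FLASH language model reduces to next-token cross-entropy. The step below is a hypothetical sketch; the actual hyperparameters and model construction are set in runall-125k.sh and flash_pytorch.py, and the constructor arguments may differ:

# Hypothetical single training step for a causal FLASH LM; constructor
# arguments and hyperparameters are illustrative, not the paper's settings.
import torch
import torch.nn.functional as F
from flash_pytorch import FLASHTransformer

model = FLASHTransformer(num_tokens=50257, dim=768, depth=24, causal=True).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

tokens = torch.randint(0, 50257, (4, 1024)).cuda()           # dummy batch of token ids
logits = model(tokens[:, :-1])                                # predict each next token
loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)),   # standard LM loss
                       tokens[:, 1:].reshape(-1))
loss.backward()
optimizer.step()
optimizer.zero_grad()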

T5 Fine-tuning

Use the bash script below to fine-tune T5 with augmented linear attention. The code is adapted from the transformers repository.

cd experiments
bash tasks_run-t5.sh

GPT-2 Fine-tuning

Use the bash script below to fine-tune GPT-2 with augmented linear attention. The code is adapted from the transformers repository.

cd experiments
bash tasks_run-gpt2.sh

LLaMA-2 Fine-tuning

Use the bash script below to fine-tune LLaMA-2 with augmented linear attention. The code is adapted from the LongLoRA repository.

cd experiments
bash tasks_run-llama2.sh

Citation & Acknowledgement

@inproceedings{you2024linear,
  title={When Linear Attention Meets Autoregressive Decoding: Towards More Effective and Efficient Linearized Large Language Models},
  author={You, Haoran and Fu, Yichao and Wang, Zheng and Yazdanbakhsh, Amir and Lin, Yingyan (Celine)},
  booktitle={Proceedings of the 41st International Conference on Machine Learning (ICML 2024)},
  year={2024},
}

Thanks to the developers of FLASH, transformers, LongLoRA, and Medusa for providing their codebases!