
Official implementation of the paper "Linear Transformers with Learnable Kernel Functions are Better In-Context Models"

Primary language: Python. License: Apache License 2.0 (Apache-2.0).

Linear Transformers with Learnable Kernel Functions are Better In-Context Models

About

The recent Based model is a hybrid architecture that combines a Linear Transformer, whose kernel is inspired by the Taylor expansion of the exponential function, with convolutional networks. Mirroring the Transformer's in-context abilities, it became a strong contender in the field. In our work, we present a single, elegant modification to the Based kernel, ReBased, that improves its in-context learning and overall language modeling. This repository is a modification of the original Zoology codebase. We also provide Triton kernels for ReBased, built on the Flash Linear Attention repository.
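The Based kernel described above can be illustrated with a small pure-Python sketch (not the repository's PyTorch or Triton code). The Taylor feature map is chosen so that phi(q) . phi(k) = 1 + q.k + (q.k)^2 / 2, the second-order Taylor expansion of exp(q.k). The ReBased-style map, including the hypothetical names normalize_affine, rebased_style_feature_map, gamma and beta, is an assumed illustration of a learnable kernel (normalize the input, apply a learnable affine transform, keep only the quadratic term); the exact ReBased kernel is defined in the paper and in the code under src/.

```python
import math
from itertools import product


def dot(a, b):
    """Dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def taylor_feature_map(x):
    """Based-style map [1, x, vec(x x^T) / sqrt(2)].

    Its inner product is phi(q) . phi(k) = 1 + q.k + (q.k)^2 / 2,
    the second-order Taylor expansion of exp(q.k).
    """
    quadratic = [xi * xj / math.sqrt(2) for xi, xj in product(x, x)]
    return [1.0] + list(x) + quadratic


def normalize_affine(x, gamma, beta, eps=1e-5):
    """Hypothetical: normalize x to zero mean / unit variance, then scale by gamma and shift by beta."""
    mean = sum(x) / len(x)
    var = sum((xi - mean) ** 2 for xi in x) / len(x)
    return [g * (xi - mean) / math.sqrt(var + eps) + b
            for xi, g, b in zip(x, gamma, beta)]


def rebased_style_feature_map(x, gamma, beta):
    """Hypothetical ReBased-style map: quadratic term only, applied after normalize_affine.

    Its inner product is (z_q . z_k)^2 with z = normalize_affine(x, gamma, beta).
    """
    z = normalize_affine(x, gamma, beta)
    return [zi * zj for zi, zj in product(z, z)]


q = [0.3, -0.2, 0.5]
k = [0.1, 0.4, -0.3]
s = dot(q, k)  # -0.2
print(dot(taylor_feature_map(q), taylor_feature_map(k)))  # ~0.82, i.e. 1 + s + s**2 / 2
print(math.exp(s))                                        # ~0.8187
```

For a head dimension d, the Taylor map produces 1 + d + d^2 features, while the quadratic-only variant produces d^2.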

Figure: Based architecture (left) and ReBased architecture (right).
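Both architectures replace softmax attention with kernelized linear attention: weights come from phi(q) . phi(k), so a causal layer can be computed with running sums in time linear in the sequence length. The sketch below shows this recurrence in plain Python, assuming the normalized Linear Transformer formulation and the Taylor feature map, and checks it against the explicit quadratic computation. It omits the convolutions, projections, and chunked Triton kernels of the real models; the function names are hypothetical.

```python
import math
from itertools import product


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def phi(x):
    """Taylor feature map: phi(q) . phi(k) = 1 + q.k + (q.k)^2 / 2, which is always >= 0.5."""
    return [1.0] + list(x) + [xi * xj / math.sqrt(2) for xi, xj in product(x, x)]


def causal_linear_attention(queries, keys, values, feature_map=phi):
    """O(T) causal linear attention using running sums.

    state[i][j] accumulates sum_s phi(k_s)[i] * v_s[j];
    norm[i] accumulates sum_s phi(k_s)[i] for the denominator.
    """
    outputs, state, norm = [], None, None
    for q, k, v in zip(queries, keys, values):
        fk = feature_map(k)
        if state is None:
            state = [[0.0] * len(v) for _ in fk]
            norm = [0.0] * len(fk)
        for i, fki in enumerate(fk):
            norm[i] += fki
            for j, vj in enumerate(v):
                state[i][j] += fki * vj
        fq = feature_map(q)
        denom = dot(fq, norm)
        outputs.append([sum(fq[i] * state[i][j] for i in range(len(fq))) / denom
                        for j in range(len(v))])
    return outputs


def quadratic_reference(queries, keys, values, feature_map=phi):
    """O(T^2) reference: explicit kernel weights over all past positions."""
    outputs = []
    for t, q in enumerate(queries):
        w = [dot(feature_map(q), feature_map(keys[s])) for s in range(t + 1)]
        outputs.append([sum(w[s] * values[s][j] for s in range(t + 1)) / sum(w)
                        for j in range(len(values[t]))])
    return outputs
```

The recurrent form keeps a fixed-size state per layer instead of a growing key-value cache, which is what makes linear attention attractive for long sequences.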

Launching experiments

Install PyTorch and the other dependencies listed in requirements.txt, then install the Flash Linear Attention kernels:

pip install -r requirements.txt
cd flash_linear_attention && FLA_SKIP_CUDA_BUILD=TRUE pip install -e .

Fill in the missing wandb key, entity, and project IDs in the config files. You can then start by running a simple MQAR (multi-query associative recall) experiment with sequence length 48:

export WANDB_API_KEY="API_KEY"
python3 src/train.py --config_path configs/simple_mqar.yaml
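MQAR asks the model to recall the value bound to each key after reading a list of key-value pairs. The toy generator below only illustrates the task: the token layout, the split of the vocabulary into key and value halves, and the name make_mqar_example are assumptions, and the repository's Zoology-based data pipeline, configured through the YAML files, differs in its details.

```python
import random


def make_mqar_example(num_pairs, vocab_size, seed=0):
    """Hypothetical toy MQAR example.

    Context: k1 v1 k2 v2 ... ; queries: the same keys in shuffled order.
    The target for each query is the value that followed that key.
    """
    rng = random.Random(seed)
    half = vocab_size // 2
    keys = rng.sample(range(half), num_pairs)                  # distinct keys from the lower half
    values = [rng.randrange(half, vocab_size) for _ in keys]   # values from the upper half
    context = [tok for pair in zip(keys, values) for tok in pair]
    query_keys = keys[:]
    rng.shuffle(query_keys)
    lookup = dict(zip(keys, values))
    targets = [lookup[key] for key in query_keys]
    return context + query_keys, targets


tokens, targets = make_mqar_example(num_pairs=4, vocab_size=64)
print(len(tokens), len(targets))  # 12 4
```

Because several keys are queried in one sequence, a model must keep all bindings available at once, which is why recall on this task is sensitive to the attention kernel.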

To reproduce Figure 1 from our paper, run the sweep. We used the wandb agent command to run sweeps; wandb sweep prints the sweep ID to use in place of entity/project/id:

wandb sweep configs/main_fig_sweep.yaml
wandb agent entity/project/id

When the sweep has finished, run all cells in the Jupyter notebook to produce the plot and the MQAR results for the Based, ReBased, RWKV, Mamba, and ConvAttention architectures across different model dimensions and sequence lengths.

Figure: main figure (MQAR results).
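The notebook's exact processing is not described here. As a hypothetical sketch of the aggregation step, assuming each finished run yields a record with the architecture name, sequence length, model dimension, and final MQAR accuracy, the best accuracy per configuration (for example, across learning rates) and one curve per architecture can be collected like this; the keys and function names are illustrative only.

```python
from collections import defaultdict


def best_accuracy_per_config(runs):
    """Best accuracy for each (model, seq_len, d_model) over all runs.

    `runs` is an iterable of dicts with the hypothetical keys
    "model", "seq_len", "d_model" and "accuracy".
    """
    best = defaultdict(lambda: float("-inf"))
    for run in runs:
        key = (run["model"], run["seq_len"], run["d_model"])
        best[key] = max(best[key], run["accuracy"])
    return dict(best)


def series_for_plot(best, model, seq_len):
    """(d_model, accuracy) points for one curve, sorted by model dimension."""
    return sorted((d, acc) for (m, length, d), acc in best.items()
                  if m == model and length == seq_len)
```

Taking the maximum over runs that differ only in optimization hyperparameters compares architectures at their best tuned setting rather than at a single fixed learning rate.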

Citation

You can cite ReBased as:

@article{rebased,
  title={Linear Transformers with Learnable Kernel Functions are Better In-Context Models},
  author={Yaroslav Aksenov and Nikita Balagansky and Sofia Maria Lo Cicero Vaina and Boris Shaposhnikov and Alexey Gorbatovski and Daniil Gavrilov},
  year={2024},
  eprint={2402.10644},
  archivePrefix={arXiv},
  primaryClass={cs.LG}
}

This repository was built on top of Zoology and the FLA Triton kernels; please cite them as well:

@article{zoology2023,
  title={Zoology: Measuring and Improving Recall in Efficient Language Models},
  author={Arora, Simran and Eyuboglu, Sabri and Timalsina, Aman and Johnson, Isys and Poli, Michael and Zou, James and Rudra, Atri and Ré, Christopher},
  journal={arXiv:2312.04927},
  year={2023}
}

@software{yang2024fla,
  title  = {FLA: A Triton-Based Library for Hardware-Efficient Implementations of Linear Attention Mechanism},
  author = {Yang, Songlin and Zhang, Yu},
  url    = {https://github.com/sustcsonglin/flash-linear-attention},
  month  = jan,
  year   = {2024}
}