Simplex Random Feature attention, in PyTorch
Softmax attention ate the world. But now it's eating our wallets. Luckily for us wordcels, those nifty shape rotators realized that even though softmax isn't stationary, it's amenable to Monte Carlo methods. Translation: we can retrofit pretrained LLMs for recurrent inference, with a little fine-tuning! Smarter men than I proceeded to publish this, this, and that. This repo is a PyTorch implementation of "that", with some syntactic sugar added to aid digestion. Just drop the Attention module into your code in place of your SDP implementation, and fine-tune under the ordinary training objective.
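For the curious, the estimator underneath all of this is, roughly, the positive random-feature identity from the FAVOR+ line of work (as I understand it, SRFs change how the ω's are drawn, coupling their directions on a simplex to cut estimator variance):

$$\exp(q^\top k) \;=\; \mathbb{E}_{\omega \sim \mathcal{N}(0, I_d)}\!\left[\exp\!\big(\omega^\top q - \tfrac{\|q\|^2}{2}\big)\,\exp\!\big(\omega^\top k - \tfrac{\|k\|^2}{2}\big)\right]$$

Averaging over a finite batch of ω's gives a non-negative, unbiased estimate of each softmax score, and writing attention in terms of those feature maps is what makes the recurrent form below possible.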
Dropping that pesky KV cache from O(L·D) per head down to an O(D²) recurrent state that doesn't grow with context length.
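Here's a minimal sketch of why that's possible (not this repo's code; phi is a stand-in for whatever positive feature map you like): once the softmax score is replaced by a dot product of feature maps, causal attention collapses into running sums whose size depends only on the feature and head dimensions, not on the sequence length.

import torch

def linear_attention_recurrent(phi_q, phi_k, v):
    # phi_q, phi_k: (L, M) feature-mapped queries and keys; v: (L, D) values
    L, M = phi_k.shape
    D = v.shape[-1]
    S = torch.zeros(M, D)   # running sum of outer products phi(k_t) v_t^T
    z = torch.zeros(M)      # running sum of phi(k_t), for the softmax normalizer
    out = torch.empty(L, D)
    for t in range(L):
        S = S + torch.outer(phi_k[t], v[t])
        z = z + phi_k[t]
        out[t] = (phi_q[t] @ S) / (phi_q[t] @ z + 1e-6)
    return out  # the carried state is just (S, z): M*D + M numbers, regardless of L

At inference time you only ever carry (S, z) forward, which is where the KV cache savings come from.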
First, do the appropriate model surgery. Then, resume the original training objective. Here's a dataset we used internally for a Llama 2 retrofit that's now in production.
pip install git+https://github.com/notarussianteenager/srf-attention
import torch
from srf_attention import Attention
device = 'cpu'
B, H, L, D = (1, 8, 1024, 128)
q, k, v = [torch.randn(B, H, L, D) for _ in range(3)]
# CHUNK_SIZE controls the memory consumption of the attention computation
CHUNK_SIZE = 256
# Simplex Random Feature (SRF) Attention module
# All intermediate computations done in FP32, but cached values are FP16.
# Recomputes the attention matrix in the backward pass instead of storing it:
attn = Attention(d=D, n_features=D, causal=True, device=device)
# Use 1 instance for each layer,
# and disable auto-redraw of random features prior to beginning training:
attn.redraw_on_call_(False)
# During fine-tuning, replace your softmax attention function with this:
o = attn(q, k, v, mode='train', attn_fn='torch', chunk_size=CHUNK_SIZE)
# On each training step, call redraw_() first (i.e. before the forward pass) to resample the random features:
attn.redraw_()
# That's it! Now just fine-tune.
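And for reference, a hedged sketch of the "drop it in place of your SDP implementation" step inside your own block. The projections, names, and shapes here are illustrative, and it assumes the module returns outputs in the same (B, H, L, D) layout as its inputs; only the Attention calls mirror the snippet above.

import torch.nn as nn
from srf_attention import Attention

class ToySelfAttention(nn.Module):
    def __init__(self, d_model=1024, n_heads=8, chunk_size=256, device='cpu'):
        super().__init__()
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.chunk_size = chunk_size
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.proj = nn.Linear(d_model, d_model)
        # One SRF Attention instance per layer, with auto-redraw disabled (resample manually per step)
        self.attn = Attention(d=self.d_head, n_features=self.d_head, causal=True, device=device)
        self.attn.redraw_on_call_(False)

    def forward(self, x):
        B, L, _ = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q, k, v = [t.view(B, L, self.n_heads, self.d_head).transpose(1, 2) for t in (q, k, v)]
        # Previously: o = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
        o = self.attn(q, k, v, mode='train', attn_fn='torch', chunk_size=self.chunk_size)
        return self.proj(o.transpose(1, 2).reshape(B, L, -1))

Remember to call redraw_() on every instance once per training step, exactly as above.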
Here's an example using the HF Transformers diff we wrote to retrofit Llama 2 with SRF attention:
# Make sure the TILE_SIZE env var is set; we use TILE_SIZE=256
import os
os.environ.setdefault('TILE_SIZE', '256')  # or export it in your shell

import torch
# Install with: pip install git+https://github.com/notarussianteenager/transformers-llama-srf
import transformers
from transformers import LlamaForCausalLM, LlamaTokenizer
tokenizer = LlamaTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf')
model = LlamaForCausalLM.from_pretrained('meta-llama/Llama-2-7b-hf')
### Enable SRF attention and disable random feature auto-redraw
for module in model.modules():
    if isinstance(module, transformers.models.llama.modeling_llama.LlamaAttention):
        module.use_fast_attn_(True)
        module.attn_fn.redraw_on_call_(False)
### Utility function for resampling features
def resample_rfs(model):
    for module in model.modules():
        if isinstance(module, transformers.models.llama.modeling_llama.LlamaAttention):
            module.attn_fn.redraw_(next(model.parameters()).device)
### Pseudo-code:
optimizer = YourOptimizerHere()
for step, batch in enumerate(imaginary_dataset):
    inputs, targets = batch
    # Always resample random features manually,
    # because auto-resampling causes issues with checkpointing
    resample_rfs(model)
    optimizer.zero_grad()
    outputs = model(inputs)
    logits = outputs.logits.reshape(-1, outputs.logits.shape[-1])
    loss = torch.nn.functional.cross_entropy(logits, targets['input_ids'].reshape(-1))
    loss.backward()
    optimizer.step()
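If you fine-tune with the HF Trainer rather than a hand-rolled loop, the same manual resampling can hang off a callback. This is our own sketch (ResampleRFsCallback is not part of either repo); it assumes the Trainer passes the model into callback events as usual.

from transformers import TrainerCallback

class ResampleRFsCallback(TrainerCallback):
    def on_step_begin(self, args, state, control, model=None, **kwargs):
        # Resample the SRF random features once per optimizer step, before the forward pass
        if model is not None:
            resample_rfs(model)

# trainer = Trainer(model=model, args=training_args, train_dataset=..., callbacks=[ResampleRFsCallback()])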