
Neurocache: A library for augmenting language models with external caching mechanisms

Primary language: Python. License: Apache License 2.0 (Apache-2.0).

Neurocache


Requirements

  • Python 3.6+
  • PyTorch 1.13.0+
  • Transformers 4.25.0+
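The minimum versions above can be checked against an existing environment before installing. The snippet below is a minimal sketch that uses only the standard library. `version_tuple`, `meets_minimum` and `check_requirements` are hypothetical helpers, not part of neurocache. The comparison looks only at the leading numeric part of each version string, so local or pre-release suffixes such as `+cu118` or `rc1` are ignored.

```python
# Hypothetical helpers (not part of neurocache): compare installed package
# versions against the minimums listed in the README.
import re
from importlib import metadata

MINIMUMS = {"torch": "1.13.0", "transformers": "4.25.0"}


def version_tuple(version: str) -> tuple:
    """Parse the leading numeric release part, e.g. '2.1.0+cu118' -> (2, 1, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def meets_minimum(installed: str, required: str) -> bool:
    """Return True if `installed` is at least `required` (numeric parts only)."""
    a, b = version_tuple(installed), version_tuple(required)
    # Pad with zeros so that '1.13' compares equal to '1.13.0'.
    width = max(len(a), len(b))
    return a + (0,) * (width - len(a)) >= b + (0,) * (width - len(b))


def check_requirements(minimums=MINIMUMS) -> dict:
    """Map each package to True/False, or None if it is not installed."""
    report = {}
    for package, required in minimums.items():
        try:
            installed = metadata.version(package)
        except metadata.PackageNotFoundError:
            report[package] = None
            continue
        report[package] = meets_minimum(installed, required)
    return report


if __name__ == "__main__":
    print(check_requirements())
```

A `False` or `None` entry in the printed report points to a package that needs installing or upgrading.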

Installation

pip install neurocache

Getting started

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from neurocache import (
    NeurocacheModelForCausalLM,
    OnDeviceCacheConfig,
)

model_name = "facebook/opt-350m"

model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

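# Index of the fifth-from-last decoder layer (19 for the 24-layer OPT-350m).
cache_layer_idx = model.config.num_hidden_layers - 5

# cache_layers and attention_layers both take lists of decoder-layer indices.
config = OnDeviceCacheConfig(
    cache_layers=[cache_layer_idx, cache_layer_idx + 3],
    attention_layers=list(range(cache_layer_idx, model.config.num_hidden_layers)),
    compression_factor=8,
    topk=8,
)

# Wrap the base Hugging Face model with its cache-augmented counterpart.
model = NeurocacheModelForCausalLM(model, config)

input_text = ["Hello, my dog is cute", " is cute"]
# padding=True is required: the two strings tokenize to different lengths,
# and return_tensors="pt" cannot build a tensor from ragged rows.
tokenized_input = tokenizer(input_text, return_tensors="pt", padding=True)
# One flag per batch row: True if the row starts a new sequence,
# False if it continues the previous one.
tokenized_input["start_of_sequence"] = torch.tensor([1, 0]).bool()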

outputs = model(**tokenized_input)
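For reference, the layer arithmetic in the example works out as follows for facebook/opt-350m, which has 24 decoder layers: cache_layer_idx is 19, cache_layers is [19, 22] and attention_layers is [19, 20, 21, 22, 23]. The sketch below reproduces that arithmetic for any model depth and adds a second helper for feeding a long token sequence in consecutive segments. Both helpers (`plan_layers`, `segment_with_flags`) are hypothetical and not part of neurocache. The segmenting helper assumes, as the 1/0 flags in the example suggest, that start_of_sequence should be True only for the first segment of each document.

```python
# Hypothetical helpers (not part of the neurocache API).


def plan_layers(num_hidden_layers: int, offset: int = 5, gap: int = 3) -> dict:
    """Reproduce the cache/attention layer choice from the Getting started example."""
    cache_layer_idx = num_hidden_layers - offset
    if cache_layer_idx < 0 or cache_layer_idx + gap >= num_hidden_layers:
        raise ValueError("model has too few layers for this offset/gap")
    return {
        "cache_layers": [cache_layer_idx, cache_layer_idx + gap],
        "attention_layers": list(range(cache_layer_idx, num_hidden_layers)),
    }


def segment_with_flags(token_ids: list, segment_length: int) -> list:
    """Split one token sequence into consecutive segments.

    Returns (segment, is_start) pairs; only the first segment is flagged True,
    under the assumption that start_of_sequence marks the start of a document.
    """
    if segment_length <= 0:
        raise ValueError("segment_length must be positive")
    return [
        (token_ids[start:start + segment_length], start == 0)
        for start in range(0, len(token_ids), segment_length)
    ]


if __name__ == "__main__":
    print(plan_layers(24))
    print(segment_with_flags(list(range(7)), 3))
```

Each (segment, flag) pair could then be passed in order as input_ids=torch.tensor([segment]) together with start_of_sequence=torch.tensor([flag]). Splitting token IDs rather than raw text avoids the tokenizer inserting a fresh beginning-of-sequence token at the start of every segment.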

Supported model types

from neurocache.utils import NEUROCACHE_SUPPORTED_MODELS
print(NEUROCACHE_SUPPORTED_MODELS)

[
  "opt",
  "llama",
  "mistral",
  "gptj",
]
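Before wrapping a model, it can help to fail early if its architecture is not on this list. A Hugging Face model exposes its architecture name as model.config.model_type (for example "opt" for facebook/opt-350m), which uses the same strings as the list above. The helper below is a hypothetical sketch, not part of neurocache. It hard-codes a copy of the list so that it runs on its own, whereas real code would import NEUROCACHE_SUPPORTED_MODELS instead.

```python
# Hypothetical pre-flight check (not part of neurocache). The list is copied
# from the output above; in practice, import NEUROCACHE_SUPPORTED_MODELS.
SUPPORTED_MODEL_TYPES = ["opt", "llama", "mistral", "gptj"]


def check_model_type(model_type: str, supported=SUPPORTED_MODEL_TYPES) -> str:
    """Return model_type unchanged if supported, otherwise raise ValueError."""
    if model_type not in supported:
        raise ValueError(
            f"model_type {model_type!r} is not supported; "
            f"expected one of {sorted(supported)}"
        )
    return model_type


if __name__ == "__main__":
    print(check_model_type("opt"))
```

In the Getting started example, calling check_model_type(model.config.model_type) just before NeurocacheModelForCausalLM(model, config) would surface an unsupported architecture with a clear message.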

TODO

  • Benchmark the implementation and identify bottlenecks.
  • Add support for more models and for grouped-query attention (needed for Mistral and the larger LLaMA models).
  • Add a chunked storage function for generation (enables faster processing of long prompts).
  • Add support for masking padding tokens in the cache (required only for the global cache).