
An efficient Mamba implementation in PyTorch and MLX.


mamba.py 🐍 : a simple and efficient Mamba implementation

A straightforward implementation of Mamba in PyTorch with a simple parallel scan implementation. It offers a major speedup over a sequential implementation, because the parallel scan allows parallelization over the time dimension. It combines ease of reading with good performance.
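Why can the time dimension be parallelized at all? The selective scan is a linear recurrence $h_t = a_t h_{t-1} + b_t$. Each step is an affine map $h \mapsto a h + b$, and composing two such maps gives another one. Composition is associative, so a scan may regroup the steps into a tree and evaluate many of them at once. The minimal sketch below checks that property numerically. `combine` is a hypothetical helper name, not a function from this repo.

```python
import torch

def combine(left, right):
    """Compose two steps of h -> a * h + b: apply `left` first, then `right`.

    (a1, b1) followed by (a2, b2) gives h -> a2 * (a1 * h + b1) + b2,
    i.e. the pair (a2 * a1, a2 * b1 + b2).
    """
    a1, b1 = left
    a2, b2 = right
    return a2 * a1, a2 * b1 + b2

torch.manual_seed(0)
p, q, r = [(torch.rand(4), torch.randn(4)) for _ in range(3)]

# Associativity: (p . q) . r == p . (q . r), so the steps can be grouped in any tree shape.
lhs = combine(combine(p, q), r)
rhs = combine(p, combine(q, r))
assert all(torch.allclose(x, y) for x, y in zip(lhs, rhs))
```

The order of the two arguments matters (the operator is not commutative), which is why a scan must preserve the time ordering even while regrouping.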

Figure: speed comparison (training time vs. sequence length)

This graph shows the training time (forward and backward pass) of a single Mamba layer (d_model=16, d_state=16) using 3 different methods: CUDA, the official Mamba implementation; mamba.py, this repo; and sequential, a sequential (RNN-like) implementation of the selective scan.
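For reference, the "sequential" baseline can be pictured as the loop below. This is a minimal sketch, not code from this repo: the function name `selective_scan_seq` and the tensor names are hypothetical, and it assumes the discretized `deltaA` and `BX` (shape (B, L, ED, N), with ED inner channels and N = d_state) have already been computed. The Python loop over L is exactly what a parallel scan removes.

```python
import torch

def selective_scan_seq(deltaA, BX, C):
    """Sequential (RNN-like) selective scan.

    deltaA, BX: (B, L, ED, N) discretized A and B*x (assumed precomputed)
    C:          (B, L, N)
    returns y:  (B, L, ED)
    """
    B, L, ED, N = deltaA.shape
    h = torch.zeros(B, ED, N, dtype=deltaA.dtype, device=deltaA.device)
    ys = []
    for t in range(L):  # one step per time index: L dependent steps in a row
        h = deltaA[:, t] * h + BX[:, t]                      # (B, ED, N)
        ys.append((h @ C[:, t].unsqueeze(-1)).squeeze(-1))   # (B, ED, N) @ (B, N, 1) -> (B, ED)
    return torch.stack(ys, dim=1)                            # (B, L, ED)
```

Each iteration needs the previous `h`, so the L steps cannot run concurrently in this form.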

This repo contains simple and readable code implementing the Mamba architecture in pure PyTorch as well as MLX. Its primary goal is educational.


The repo is organized as follows:

  • pscan.py : a PyTorch implementation of Blelloch's parallel scan
  • mamba.py : the Mamba model, as described in the paper. It is numerically equivalent to the official implementation (initialization, forward and backward pass).
  • mamba_lm.py : encapsulates a Mamba model in order to use it as a language model
  • 📁 mlx : basically the same code as above, but in MLX.
  • 📁 docs : a folder containing annotated explanations about the code, focusing on the parallel scan
  • 📁 examples : two examples of how to use the Mamba model in PyTorch.
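To give an idea of what a Blelloch scan looks like, here is a minimal sketch of a work-efficient (up-sweep then down-sweep) scan for the recurrence $h_t = a_t h_{t-1} + b_t$ with $h_{-1} = 0$. It is not the code of pscan.py: the function name `blelloch_linear_scan` is hypothetical, it scans along the last dimension, and it assumes a power-of-two length. Each of the $2\log_2 T$ levels is vectorized over all pairs at that level.

```python
import math
import torch

def blelloch_linear_scan(a, b):
    """Inclusive scan of h_t = a_t * h_{t-1} + b_t (h_{-1} = 0) along the last dim.

    Assumes the last dim T is a power of two (sketch only).
    """
    T = a.shape[-1]
    k = int(math.log2(T))
    assert 2 ** k == T, "sketch assumes a power-of-two length"
    A, X = a.clone(), b.clone()

    # Up-sweep (reduce): node r accumulates the composition of its segment.
    for d in range(k):
        r = torch.arange(2 ** (d + 1) - 1, T, 2 ** (d + 1))
        l = r - 2 ** d
        X[..., r] = A[..., r] * X[..., l] + X[..., r]
        A[..., r] = A[..., r] * A[..., l]

    # Down-sweep: turn the tree into exclusive prefixes, starting from the identity (1, 0).
    A[..., T - 1] = 1.0
    X[..., T - 1] = 0.0
    for d in reversed(range(k)):
        r = torch.arange(2 ** (d + 1) - 1, T, 2 ** (d + 1))
        l = r - 2 ** d
        Al, Xl = A[..., l], X[..., l]   # total of the left subtree (copies)
        Ar, Xr = A[..., r], X[..., r]   # prefix coming down from the parent (copies)
        A[..., l], X[..., l] = Ar, Xr
        # prefix first, then the left subtree (order matters: the operator is not commutative)
        A[..., r], X[..., r] = Al * Ar, Al * Xr + Xl

    # X holds the exclusive prefix applied to h_{-1} = 0; one last step makes it inclusive.
    return a * X + b
```

Unrolling these loops for fixed sizes is the kind of optimization the performance TODOs mention; here they are kept as loops for readability.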

Usage

The most basic usage is the Mamba object (mamba.py), which implements a simple Mamba model given a configuration. No embedding, no head: the input is (B, L, D) and the output is (B, L, D) as well.

import torch
from mamba import Mamba, MambaConfig

config = MambaConfig(d_model=16, n_layers=2)
model = Mamba(config)

B, L, D = 2, 64, 16
x = torch.randn(B, L, D)
y = model(x)

assert y.shape == x.shape

The class MambaLM (mamba_lm.py) builds on the Mamba object and offers a classic API for language models. It can be used as follows:

import torch
from mamba_lm import MambaLM, MambaLMConfig

config = MambaLMConfig(d_model=16, n_layers=4, vocab_size=32000)
model = MambaLM(config)

x = torch.randint(high=32000, size=(16, 64))
logits = model(x) # (B, L, vocab_size)

It simply encapsulates a Mamba object with an embedding layer, a final normalization and a language modeling head.
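The wrapping described above can be sketched as follows. This is a hypothetical stand-in (`TinyLM` is not a class of this repo), shown only to make the data flow explicit: embedding, then a backbone mapping (B, L, D) to (B, L, D), then a final norm and an LM head. `nn.LayerNorm` is used as the final normalization for simplicity; the repo's choice of norm may differ. `nn.Identity` replaces the Mamba backbone so the snippet runs on its own.

```python
import torch
import torch.nn as nn

class TinyLM(nn.Module):
    """Hypothetical wrapper: embedding -> backbone -> final norm -> LM head."""

    def __init__(self, backbone: nn.Module, d_model: int, vocab_size: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.backbone = backbone                  # any (B, L, D) -> (B, L, D) module, e.g. a Mamba object
        self.norm_f = nn.LayerNorm(d_model)       # stand-in for the final normalization
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        x = self.embedding(tokens)                # (B, L) -> (B, L, D)
        x = self.backbone(x)                      # (B, L, D)
        return self.lm_head(self.norm_f(x))       # (B, L, vocab_size)

# nn.Identity stands in for the Mamba backbone so the sketch is self-contained.
model = TinyLM(nn.Identity(), d_model=16, vocab_size=32000)
logits = model(torch.randint(high=32000, size=(2, 64)))
```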

Examples

Two basic examples are available:

  • example_llm.ipynb : load a Mamba model with pretrained weights (from 130M to 2.8B parameters, from HuggingFace)
  • example_e2e_training.ipynb : an end-to-end training example where a Mamba model is employed as a world model for a simple 3-3 grid game (training is not completed, the model should be larger).

Performances

This section provides a more comprehensive performance comparison between mamba.py and the official Mamba implementation. Overall, as the first graph of this file shows, both have approximately the same asymptotic performance with respect to the sequence length. You can think of mamba.py as a regular Transformer implementation, while the official Mamba implementation is more like FlashAttention v1. Both have their own advantages.

That being said, do the two implementations have the same asymptotic performance with respect to the other parameters?

d_model asymptotic performances

Figure: training time of both implementations as d_model increases

We can see that both implementations behave the same as d_model increases: the gap between the two stays roughly constant (mamba.py is overall ~2x slower).

d_state asymptotic performances

Figure: training time of both implementations as d_state increases

This graph is important. Here, the asymptotic performance is not the same as d_state increases. As a reminder, d_state, or $N$ in the paper, is the state expansion factor: each channel of the input is expanded into $N$ channels of the hidden state.
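The expansion is easiest to see on tensor shapes. The sketch below uses illustrative sizes and hypothetical names: ED is the number of inner channels, `delta` is the per-timestep step size, `A` has one row of N state dimensions per channel, and `Bmat` is the input-dependent B. It assumes the zero-order-hold form $\bar{A} = \exp(\Delta A)$ and the simplified $\Delta B x$ for the input term; the exact discretization in this repo may differ in details.

```python
import torch

B, L, ED, N = 1, 64, 32, 16              # illustrative sizes; N = d_state

delta = torch.rand(B, L, ED)             # per-channel, per-timestep step size (>= 0)
A = -torch.rand(ED, N)                   # negative entries keep exp(delta * A) <= 1
x = torch.randn(B, L, ED)
Bmat = torch.randn(B, L, N)              # input-dependent B (the "selective" part)

# Each of the ED channels gets N hidden-state channels:
deltaA = torch.exp(delta.unsqueeze(-1) * A)                         # (B, L, ED, N)
BX = delta.unsqueeze(-1) * Bmat.unsqueeze(2) * x.unsqueeze(-1)      # (B, L, ED, N)

print(tuple(x.shape), "->", tuple(deltaA.shape))  # prints (1, 64, 32) -> (1, 64, 32, 16)
```

Both expanded tensors are N times larger than the input, which is why an implementation that stores them in GPU main memory pays a cost that grows with d_state.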

Does it matter in practice? As of now, all the pretrained Mamba models (up to 2.8B parameters) use d_state=16, so this difference in scaling with d_state isn't important in this case. As d_state is not something that is supposed to grow (contrary to the sequence length or d_model), this isn't a catastrophic result, but it is something to consider.

However, it is interesting to relate this observation to a claim made by Albert Gu and Tri Dao in the Mamba paper: the main idea is to leverage properties of modern accelerators (GPUs) to materialize the state $h$ only in more efficient levels of the memory hierarchy. They also describe (Appendix D) the main data movements of their selective scan: by working mainly in SRAM, they reduce memory reads/writes by a factor of $O(N)$. This explains what we're seeing here.
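A back-of-the-envelope count shows where the $O(N)$ factor comes from. The model below is a deliberately crude assumption of ours, not the accounting from the paper: both versions read x, delta, B, C and write y, while a non-fused scan additionally writes the expanded state $h$ of shape (B, L, ED, N) to GPU main memory and reads it back. `hbm_traffic_ratio` is a hypothetical helper.

```python
def hbm_traffic_ratio(B, L, ED, N):
    """Crude estimate of (elements moved when h is materialized) / (elements moved when fused)."""
    # Common I/O: read x and delta (B, L, ED), read B and C (B, L, N), write y (B, L, ED).
    io = 3 * B * L * ED + 2 * B * L * N
    # Non-fused: also write the expanded state h (B, L, ED, N) to main memory and read it back.
    materialized = io + 2 * B * L * ED * N
    return materialized / io

for n in (16, 64):
    print(f"d_state={n}: ~{hbm_traffic_ratio(1, 2048, 32, n):.1f}x more memory traffic")
```

When ED is large compared to N, the ratio behaves like $1 + 2N/3$, i.e. it grows linearly with d_state, which is consistent with the widening gap in the d_state graph.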

With d_state=16 (as in state-spaces/mamba-2.8b-slimpj), the gap between the two is relatively small, but with d_state=64 (currently not used in any model), the gap widens (note the OOM on the second graph).

Figure: comparison with d_state=16 (first graph) and d_state=64 (second graph)

All the previous graphs were computed with a batch size of 1, on an A100 80GB. They measure both the forward and backward pass of a single Mamba block.
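A forward + backward timing of a single block can be measured along these lines. This is a hypothetical harness, not the script used to produce the graphs above: `time_fwd_bwd` is an illustrative name, and a summed output is used as a dummy loss only to drive the backward pass. On GPU, the `torch.cuda.synchronize()` calls matter because CUDA kernels are launched asynchronously.

```python
import time
import torch
import torch.nn as nn

def time_fwd_bwd(module: nn.Module, x: torch.Tensor, n_warmup: int = 3, n_iters: int = 10) -> float:
    """Average wall-clock seconds for one forward + backward pass of `module` on `x`."""
    def step():
        module.zero_grad(set_to_none=True)
        module(x).sum().backward()    # dummy scalar loss to trigger the backward pass

    for _ in range(n_warmup):         # warm-up: kernel selection, allocator, caches
        step()
    if x.is_cuda:
        torch.cuda.synchronize()      # wait for queued kernels before starting the clock
    start = time.perf_counter()
    for _ in range(n_iters):
        step()
    if x.is_cuda:
        torch.cuda.synchronize()      # make sure all timed kernels have finished
    return (time.perf_counter() - start) / n_iters
```

With a Mamba block as `module` and `x` of shape (1, L, d_model) on the GPU, this matches the batch-size-1 setting described above; varying L, d_model or d_state gives one point per configuration.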

The previous analysis showed the importance of kernel fusion, which reduces memory accesses by a factor of $O(N)$ and thus makes the whole process faster.

But memory requirements should also be considered: the official Mamba implementation uses recomputation in the backward pass. Rather than keeping in memory the activations computed during the forward pass, it simply recomputes them in the backward pass when needed. This greatly reduces the memory requirement of the Mamba model during training. This is not implemented in this repo.
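In pure PyTorch, the same idea (trading compute for memory) is available through `torch.utils.checkpoint`. The sketch below is only an analogy: the official implementation does its recomputation inside the fused CUDA kernel, whereas this hypothetical `RecomputedBlock` wrapper recomputes a whole module in the backward pass. The small MLP in the usage part stands in for a Mamba block.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class RecomputedBlock(nn.Module):
    """Hypothetical wrapper: the block's intermediate activations are not kept after
    the forward pass and are recomputed during the backward pass instead."""

    def __init__(self, block: nn.Module):
        super().__init__()
        self.block = block

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training and torch.is_grad_enabled():
            return checkpoint(self.block, x, use_reentrant=False)
        return self.block(x)          # inference: nothing to save, call the block directly

torch.manual_seed(0)
inner = nn.Sequential(nn.Linear(16, 64), nn.GELU(), nn.Linear(64, 16))  # stand-in block
wrapped = RecomputedBlock(inner)
x = torch.randn(2, 8, 16)
wrapped(x).sum().backward()           # backward re-runs `inner` to rebuild its activations
```

Gradients are the same as without checkpointing; only peak memory and compute change.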

Hence, this repo implements one of the three techniques mentioned in the Mamba paper that form the so-called "hardware-aware selective scan": the parallel scan. We saw how kernel fusion impacts speed, while recomputation impacts memory requirements.

Sources and where to learn more

  • the Mamba paper : describes the Mamba architecture as implemented in this repo, which models sequences in linear time.
  • the Mamba implementation, which is written in PyTorch but uses a parallel scan written in CUDA. This is the fastest version.
  • a minimal PyTorch implementation of Mamba, which implements the scan operation as a sequential loop (its performance is a bit worse than the 'sequential' line in the first graph). This code closely follows this file from the official Mamba implementation, but replaces the CUDA convolution with torch.nn.Conv1d, and the selective scan written in CUDA with a sequential loop. The code of this repo follows the structure of these 2 files.
  • Prefix Sums and Their Applications, by Guy E. Blelloch (1993).
  • Parallelizing Linear Recurrent Neural Nets Over Sequence Length : applies a parallel scan over the sequence in order to get rid of the sequential for-loop.
  • x.com/fchollet : original pscan implementation.

TODOs

  • docs
  • more tests with an increased d_model (add a Performances section)
  • a step function, used for (auto-regressive) inference.
  • a training function, similar to llama2.c

perfs :

  • unfold the for-loops in pscan.py to achieve better performance (see François Fleuret's pscan), although this will sacrifice readability a bit.
  • write a reverse parallel scan specifically for the backward pass. (For now, we have to flip the array before and after the scan).
  • use torch.compile(). As far as I tested, it doesn’t work for now. It seems it isn’t happy with the custom PScan autograd function. Need to investigate. (see PR#1)
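About the reverse scan: the gradient of a linear recurrence is itself a linear recurrence running backward in time, which is why the backward pass can reuse the forward scan on flipped arrays. The sketch below illustrates this under our own assumptions (hypothetical names `scan` and `reverse_scan`, scan along the last dimension, a sequential forward scan for clarity). For $h_t = a_t h_{t-1} + b_t$ and upstream gradients $g_t = \partial \mathcal{L} / \partial h_t$, the gradient with respect to $b_t$ is $G_t = a_{t+1} G_{t+1} + g_t$.

```python
import torch

def scan(a, b):
    """Forward linear scan along the last dim: h_t = a_t * h_{t-1} + b_t, h_{-1} = 0.
    (Sequential for clarity; a parallel scan computes the same result.)"""
    h = torch.zeros_like(b[..., 0])
    out = []
    for t in range(b.shape[-1]):
        h = a[..., t] * h + b[..., t]
        out.append(h)
    return torch.stack(out, dim=-1)

def reverse_scan(a, g):
    """Backward recurrence G_t = a_{t+1} * G_{t+1} + g_t, computed by flipping the
    inputs, running the forward scan, and flipping the result back."""
    # G_t depends on a_{t+1}: shift a left by one; the last step has no successor (0).
    a_next = torch.cat([a[..., 1:], torch.zeros_like(a[..., :1])], dim=-1)
    return scan(a_next.flip(-1), g.flip(-1)).flip(-1)

torch.manual_seed(0)
a = torch.rand(3, 10)
b = torch.randn(3, 10, requires_grad=True)
w = torch.randn(3, 10)
(scan(a, b) * w).sum().backward()     # autograd gradient w.r.t. b, for comparison
```

A dedicated reverse scan would avoid the two flips (and the copies they create), which is the point of this TODO.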