mamba-pytorch

A Mamba model implementation whose only dependency is PyTorch.

Primary language: Python. License: Apache License 2.0.

Mamba for CPU

Mamba

Mamba: Linear-Time Sequence Modeling with Selective State Spaces
Albert Gu*, Tri Dao*
Paper: https://arxiv.org/abs/2312.00752

About

Mamba is a new state space model architecture showing promising performance on information-dense data such as language modeling, where previous subquadratic models fall short of Transformers. It is based on the line of progress on structured state space models, with an efficient hardware-aware design and implementation in the spirit of FlashAttention.

Installation

No separate installation step is needed. PyTorch is the only dependency.

Pretrained Models

Pretrained models are uploaded to Hugging Face: mamba-130m, mamba-370m, mamba-790m, mamba-1.4b, mamba-2.8b, trained on 300B tokens on the Pile, as well as mamba-2.8b-slimpj (trained on 600B tokens on the SlimPajama dataset).

The models will be autodownloaded by the generation script below.

These models were trained on the Pile, and follow the standard model dimensions described by GPT-3 and followed by many open source models:

Parameters   Layers   Model dim.
130M         24       768
370M         48       1024
790M         48       1536
1.4B         48       2048
2.8B         64       2560

(The layer count of Mamba doubles that of a Transformer with similar size, as two Mamba blocks are needed for each "layer" (MHA block + MLP block) of a Transformer.)
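As a rough sanity check on the table above, the parameter counts can be approximated from the layer count and model dimension. The sketch below rests on assumptions that are not stated in this README: an expansion factor of 2, about 3 * expand * d_model^2 weights per Mamba block (the input and output projections), and a tied embedding over a vocabulary of roughly 50k tokens. It ignores the smaller convolution, SSM, and normalization parameters, so it slightly underestimates each model.

```python
# Rough parameter estimate for each model size in the table above.
# ASSUMPTIONS (not from this README): expansion factor 2, about
# 3 * expand * d_model**2 weights per block, tied embedding, ~50k vocabulary.
EXPAND = 2
VOCAB_SIZE = 50_280  # assumed vocabulary size

# name -> (number of Mamba blocks, model dimension), copied from the table
CONFIGS = {
    "130M": (24, 768),
    "370M": (48, 1024),
    "790M": (48, 1536),
    "1.4B": (48, 2048),
    "2.8B": (64, 2560),
}


def estimate_params(n_layers, d_model, expand=EXPAND, vocab_size=VOCAB_SIZE):
    """Approximate parameter count from the dominant linear projections."""
    per_block = 3 * expand * d_model ** 2   # input + output projections
    embedding = vocab_size * d_model        # shared with the output head
    return n_layers * per_block + embedding


for name, (n_layers, d_model) in CONFIGS.items():
    approx = estimate_params(n_layers, d_model)
    # Two Mamba blocks correspond to one Transformer layer (MHA + MLP).
    print(f"{name}: ~{approx / 1e6:.0f}M parameters, "
          f"comparable to a {n_layers // 2}-layer Transformer")
```

Under these assumptions the estimates land within about 10% of the nominal sizes.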

Note: these are base models trained only for 300B tokens, without any form of downstream modification (instruction tuning, etc.). Performance is expected to be comparable to or better than that of other architectures trained on similar data, but not to match larger or fine-tuned models.

Inference

The script mamba_generate.py:

  1. autoloads a model from the Hugging Face Hub,
  2. generates completions of a user-specified prompt,
  3. measures the inference speed of this generation.
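The speed measurement in the third step can be sketched as a small timing harness. This is a minimal illustration, not the code in mamba_generate.py: `measure_tokens_per_second` and `dummy_generate` are hypothetical names, and the dummy generator merely stands in for a real model call.

```python
import time


def measure_tokens_per_second(generate_fn, prompt, max_new_tokens, repeats=3):
    """Time a generation callable and return generated tokens per second.

    generate_fn is a hypothetical callable taking (prompt, max_new_tokens)
    and returning the list of newly generated token ids.
    """
    generate_fn(prompt, max_new_tokens)  # warm-up run, excluded from timing
    total_tokens = 0
    start = time.perf_counter()
    for _ in range(repeats):
        total_tokens += len(generate_fn(prompt, max_new_tokens))
    elapsed = time.perf_counter() - start
    return total_tokens / elapsed


def dummy_generate(prompt, max_new_tokens):
    """Stand-in for a real model: sleeps briefly and returns fake token ids."""
    time.sleep(0.001)
    return list(range(max_new_tokens))


speed = measure_tokens_per_second(dummy_generate, "My cat wrote", 16)
print(f"{speed:.0f} tokens/s")
```

The warm-up call keeps one-time costs, such as loading weights, out of the measurement.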

Example

python mamba_generate.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
python mamba_generate.py --model-name "EleutherAI/pythia-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
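The sampling flags in these commands control how each next token is drawn. The sketch below shows one common way to combine a repetition penalty, temperature scaling, and top-p (nucleus) filtering. It is an assumption about the conventions, not a copy of what mamba_generate.py does, and it uses plain Python lists instead of tensors so that it stays dependency-free. `sample_next_token` is a hypothetical helper.

```python
import math
import random


def sample_next_token(logits, previous_tokens, temperature=0.7, top_p=0.9,
                      repetition_penalty=1.2, rng=random):
    """Draw one token id from raw logits (a list of floats)."""
    logits = list(logits)
    # 1. Repetition penalty: push down logits of tokens already generated.
    for tok in set(previous_tokens):
        if logits[tok] > 0:
            logits[tok] /= repetition_penalty
        else:
            logits[tok] *= repetition_penalty
    # 2. Temperature: values below 1 sharpen the distribution.
    scaled = [x / temperature for x in logits]
    # 3. Numerically stable softmax.
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    # 4. Top-p: keep the smallest set of most likely tokens whose
    #    cumulative probability reaches top_p.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    # 5. Sample among the kept tokens, weighted by their probabilities.
    return rng.choices(kept, weights=[probs[i] for i in kept], k=1)[0]


rng = random.Random(0)
token = sample_next_token([2.0, 1.0, 0.5, -1.0], previous_tokens=[0], rng=rng)
```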

Troubleshooting

Precision

Our models were trained using PyTorch AMP for mixed precision. AMP keeps model parameters in float32 and casts to half precision when necessary. On the other hand, other frameworks like DeepSpeed store parameters in float16 and upcast when necessary (e.g. for optimizer accumulation).

We've observed that higher precision for the main model parameters may be necessary, because SSMs are sensitive to their recurrent dynamics. If you are experiencing instabilities, as a first step please try a framework storing parameters in fp32 (such as AMP).
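A toy example can illustrate why a recurrence is sensitive to the precision of its parameters. This is not code from this repository: it assumes a scalar recurrence h_t = a * h_{t-1} with a decay factor a = exp(-delta) close to 1, and it uses Python's 64-bit floats as a stand-in for full precision. Rounding `a` to float16 turns it into exactly 1.0, so the state no longer decays at all.

```python
import math
import struct


def to_fp16(x):
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def run_decay(a, steps):
    """Apply the toy recurrence h_t = a * h_{t-1}, starting from h_0 = 1."""
    h = 1.0
    for _ in range(steps):
        h = a * h
    return h


delta = 1e-4                  # assumed small step size, for illustration only
a_full = math.exp(-delta)     # about 0.9999, kept at full precision
a_half = to_fp16(a_full)      # float16 cannot represent it: rounds to 1.0

steps = 10_000
h_full = run_decay(a_full, steps)   # decays to roughly exp(-1)
h_half = run_decay(a_half, steps)   # stays at 1.0, the decay is lost

print(f"a_full={a_full:.6f} a_half={a_half:.6f}")
print(f"h_full={h_full:.4f} h_half={h_half:.4f}")
```

Storing parameters in float32 and casting only during computation, as AMP does, avoids baking this kind of rounding into the stored parameters.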

Initialization

Some parts of the model have initializations inherited from prior work on S4 models. For example, the $\Delta$ parameter is given a targeted range by initializing the bias of its linear projection. However, some frameworks may have post-initialization hooks (e.g. setting all bias terms in nn.Linear modules to zero). If this is the case, you may have to add custom logic that is specific to the training framework (e.g. this line turns off re-initializing in our trainer, but would be a no-op in any other framework).
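The effect of such a hook on $\Delta$ can be sketched in a few lines. The example assumes that $\Delta$ is obtained by applying softplus to the projection output plus its bias, and that the targeted range is [0.001, 0.1]. Both are assumptions for illustration, and the helper names are hypothetical. The bias is set to the inverse softplus of a log-uniform sample, so that softplus(bias) falls inside the range. Zeroing the bias afterwards moves $\Delta$ to softplus(0) = ln 2, far outside it.

```python
import math
import random

DT_MIN, DT_MAX = 1e-3, 1e-1  # assumed target range for Delta


def softplus(x):
    return math.log1p(math.exp(x))


def inverse_softplus(y):
    """Return x such that softplus(x) == y, for y > 0."""
    return y + math.log(-math.expm1(-y))


def init_dt_bias(n, rng):
    """Sample n bias values so that softplus(bias) is log-uniform in range."""
    biases = []
    for _ in range(n):
        dt = math.exp(rng.uniform(math.log(DT_MIN), math.log(DT_MAX)))
        biases.append(inverse_softplus(dt))
    return biases


rng = random.Random(0)
biases = init_dt_bias(8, rng)
deltas = [softplus(b) for b in biases]   # all inside [DT_MIN, DT_MAX]

# A post-initialization hook that zeroes every bias discards this range.
delta_after_zeroing = softplus(0.0)      # ln 2, well above DT_MAX

print([round(d, 4) for d in deltas])
print(round(delta_after_zeroing, 4))
```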

Citation

If you use this codebase, or otherwise find our work valuable, please cite Mamba:

@article{mamba,
  title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces},
  author={Gu, Albert and Dao, Tri},
  journal={arXiv preprint arXiv:2312.00752},
  year={2023}
}