Kira is a suite of LLMs built with JAX and Equinox.
It is designed to be as clean and simple as possible, while still being flexible and powerful. It also provides a simple training loop, which interoperates with the Jaxonloader library.
Currently, Kira provides the following models:
Kira
: A standard transformer that can interpolate between multi-head attention (MHA) and multi-query attention (MQA).

Mamba
: The new selective state space model (SSM).
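One common way to interpolate between MHA and MQA is grouped-query attention, where the number of key/value heads is a free parameter. The sketch below illustrates that idea in plain JAX. It is an assumption about the technique, not Kira's actual implementation, and the function name `grouped_query_attention` and its arguments are hypothetical. With `num_kv_heads == num_heads` it reduces to MHA; with `num_kv_heads == 1` it reduces to MQA.

```python
import jax
import jax.numpy as jnp


def grouped_query_attention(x, wq, wk, wv, num_heads, num_kv_heads):
    """Hypothetical causal self-attention sketch (not Kira's API).

    x:      (seq_len, d_model)
    wq:     (d_model, num_heads * head_dim)
    wk, wv: (d_model, num_kv_heads * head_dim)
    num_kv_heads == num_heads -> MHA; num_kv_heads == 1 -> MQA.
    """
    assert num_heads % num_kv_heads == 0
    seq_len = x.shape[0]
    head_dim = wq.shape[1] // num_heads

    q = (x @ wq).reshape(seq_len, num_heads, head_dim)
    k = (x @ wk).reshape(seq_len, num_kv_heads, head_dim)
    v = (x @ wv).reshape(seq_len, num_kv_heads, head_dim)

    # Each key/value head is shared by a group of consecutive query heads.
    group = num_heads // num_kv_heads
    k = jnp.repeat(k, group, axis=1)  # (seq_len, num_heads, head_dim)
    v = jnp.repeat(v, group, axis=1)

    # Scaled dot-product scores per head: (num_heads, seq_len, seq_len).
    scores = jnp.einsum("qhd,khd->hqk", q, k) / jnp.sqrt(head_dim)
    # Causal (decoder) mask: position i attends only to positions 0..i.
    causal = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
    scores = jnp.where(causal, scores, -jnp.inf)
    weights = jax.nn.softmax(scores, axis=-1)

    out = jnp.einsum("hqk,khd->qhd", weights, v)
    return out.reshape(seq_len, num_heads * head_dim)
```

Sharing key/value heads shrinks the key/value projections (and, at inference time, the key/value cache) by a factor of `num_heads // num_kv_heads`, which is the usual motivation for moving toward MQA.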
Kira can also be used as an encoder. Simply pass `mask=None` when you call `Kira`, and no masking is applied in the MHA, so every token can attend to every other token (i.e. the model acts as an encoder).
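The difference between decoder-style and encoder-style attention comes down to whether a causal mask is applied. The minimal sketch below shows that difference in plain JAX; the helper names `attention` and `causal_mask` are hypothetical, and how Kira builds or applies its mask internally is not assumed here. Passing `mask=None` skips masking entirely, as in an encoder.

```python
import jax
import jax.numpy as jnp


def attention(q, k, v, mask=None):
    """Hypothetical single-head scaled dot-product attention.

    q, k, v: (seq_len, head_dim)
    mask: optional boolean (seq_len, seq_len) array, True where attention
          is allowed. mask=None means no masking (encoder-style).
    """
    scores = q @ k.T / jnp.sqrt(q.shape[-1])
    if mask is not None:
        # Disallowed positions get -inf, i.e. zero weight after softmax.
        scores = jnp.where(mask, scores, -jnp.inf)
    return jax.nn.softmax(scores, axis=-1) @ v


def causal_mask(seq_len):
    # Lower-triangular mask: position i may attend to positions 0..i.
    return jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
```

With the causal mask, the output at position 0 depends only on token 0; with `mask=None`, changing any token can change every output position.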
To get started with Kira, you can either install it with
pip3 install kira_llm
or simply clone the repository and cherry-pick what you need.