MLX RAG With GGUF Model Weights

Minimal, clean code implementation of RAG with mlx using gguf model weights.

The code here builds on https://github.com/vegaluisjose/mlx-rag and has been optimized to support RAG-based inference for .gguf models. I am using BAAI/bge-small-en as the embedding model, TinyLlama-1.1B-Chat-v1.0-GGUF as the base model (you can choose from the supported models below), and the custom vector database script for indexing the text of a PDF file. Inference speeds can reach ~413 tokens/sec for prompt processing and ~36 tokens/sec for generation on my 8GB M2 Air.

Update

  • Added support for phi-3-mini-4k-instruct.gguf and other Q4_0, Q4_1, and Q8_0 quantized models. Download the model and save it in the models/phi-3-mini-instruct folder.

Demo

mlx_rag_gguf_demo.mp4

Usage

Download the models and save them in the models folder. You can use Hugging Face's snapshot_download, but I recommend downloading them separately to save time.

Note

MLX currently supports only a few quantizations: Q4_0, Q4_1, and Q8_0. Unsupported quantizations will be cast to float16.
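
Because unsupported quantizations are silently cast to float16, it can help to check a file before loading it. The sketch below is a hypothetical helper, not part of this repo: it assumes the common naming convention in which the quantization tag sits right before the .gguf extension (for example model.Q4_0.gguf) and does not inspect the file contents.

```python
import re

# Quantization types the note above lists as supported by MLX.
SUPPORTED_QUANTS = {"Q4_0", "Q4_1", "Q8_0"}


def quant_from_filename(path):
    """Return the quantization tag found in a GGUF filename, or None.

    Hypothetical helper: it relies on the naming convention
    '<name>.<QUANT>.gguf', not on the data stored in the file.
    """
    match = re.search(r"(Q\d+_[0-9A-Z_]+)\.gguf$", path, flags=re.IGNORECASE)
    return match.group(1).upper() if match else None


def is_supported(path):
    """True if the filename's tag is one of the natively supported types."""
    return quant_from_filename(path) in SUPPORTED_QUANTS


if __name__ == "__main__":
    for name in ("phi-3-mini-4k-instruct.Q4_0.gguf", "model.Q4_K_M.gguf"):
        print(name, "supported" if is_supported(name) else "will be cast to float16")
```

A filename check is only a convenience. A file that is named differently would need its metadata read instead.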

Tested/Supported models

TinyLlama Q4_0 and Q8_0

Phi-3-mini Q4_0

Mistral Q4_0 and Q8_0

Embedding models

  • mlx-bge-small-en: BAAI/bge-small-en converted to MLX format. Save it in the mlx-bge-small-en folder.
  • bge-small-en: Only the model.safetensors file is needed. Save it in the bge-small-en folder.

Install requirements

python3 -m pip install -r requirements.txt

Convert a PDF into an MLX-compatible vector database

python3 create_vdb.py --pdf mlx_docs.pdf --vdb vdb.npz
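
Before a document can be embedded, its text is usually split into chunks. The sketch below shows one common way to do that with overlapping word windows. It is an illustration only: the function name chunk_text, the word-based splitting, and the default sizes are assumptions, and vdb.py may chunk the PDF differently.

```python
def chunk_text(text, chunk_size=200, overlap=50):
    """Split text into overlapping windows of words.

    chunk_size and overlap are counted in words. Both defaults are
    illustrative values, not settings taken from this repo.
    """
    if not 0 <= overlap < chunk_size:
        raise ValueError("overlap must be >= 0 and smaller than chunk_size")
    words = text.split()
    step = chunk_size - overlap  # how far each window advances
    chunks = []
    for start in range(0, len(words), step):
        chunks.append(" ".join(words[start:start + chunk_size]))
        if start + chunk_size >= len(words):
            break  # this window already reached the end of the text
    return chunks


if __name__ == "__main__":
    sample = "MLX is an array framework for machine learning on Apple silicon"
    for chunk in chunk_text(sample, chunk_size=5, overlap=2):
        print(chunk)
```

The overlap keeps a sentence that straddles a chunk boundary retrievable from either side, at the cost of storing some words twice.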

Query the model

python3 rag_vdb.py \
    --question "Teach me the basics of mlx" \
    --vdb "vdb.npz" \
    --gguf "models/phi-3-mini-instruct/phi-3-mini-4k-instruct.Q4_0.gguf"
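
Conceptually, a query like the one above ranks the stored chunks by similarity to the question and places the best ones in the prompt. The sketch below illustrates that idea in plain Python with hand-written toy vectors. In the real pipeline the vectors would come from the embedding model, and the function names and prompt template here are assumptions, not the ones used in rag_vdb.py.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity between two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def top_k(query_vec, chunk_vecs, k=2):
    """Return the indices of the k chunks most similar to the query, best first."""
    scores = [cosine_similarity(query_vec, vec) for vec in chunk_vecs]
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return ranked[:k]


def build_prompt(question, chunks):
    """Join retrieved chunks and the question into one prompt (hypothetical template)."""
    context = "\n\n".join(chunks)
    return f"Context:\n{context}\n\nQuestion: {question}\nAnswer:"


if __name__ == "__main__":
    chunks = ["MLX arrays live in unified memory.", "Bananas are yellow.", "MLX is lazy."]
    chunk_vecs = [[1.0, 0.0], [0.0, 1.0], [0.9, 0.1]]  # toy stand-ins for embeddings
    best = top_k([1.0, 0.0], chunk_vecs, k=2)
    print(build_prompt("Teach me the basics of mlx", [chunks[i] for i in best]))
```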

The files in the repo work as follows:

  • gguf.py: Contains the stubs for loading .gguf models and running inference with them.
  • vdb.py: Holds the logic for creating a vector database from a PDF file and saving it in MLX format (.npz).
  • create_vdb.py: Builds on vdb.py and defines the command-line arguments used to create a vector database in MLX format (.npz) from a PDF file.
  • rag_vdb.py: Retrieves data from the vector database and uses it to query the base model.
  • model.py: Houses the logic for the base model (with configs), the embedding model, and the transformer encoder.
  • utils.py: Utility function for accessing GGUF tokens.
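
For orientation on what a .gguf loader has to parse first, here is a minimal standard-library sketch that reads the fixed-size header of a GGUF file. It is not taken from gguf.py. It assumes a little-endian file in GGUF version 2 or 3, where the header is the 4-byte magic "GGUF", a uint32 version, a uint64 tensor count, and a uint64 metadata key/value count. The function name read_gguf_header is hypothetical.

```python
import struct


def read_gguf_header(path):
    """Read the fixed-size header at the start of a GGUF file.

    Assumes little-endian byte order and GGUF version 2 or later,
    where both counts are stored as unsigned 64-bit integers.
    """
    with open(path, "rb") as f:
        magic = f.read(4)
        if magic != b"GGUF":
            raise ValueError(f"not a GGUF file: magic={magic!r}")
        (version,) = struct.unpack("<I", f.read(4))
        if version < 2:
            raise ValueError("this sketch only handles GGUF version 2 or later")
        tensor_count, metadata_kv_count = struct.unpack("<QQ", f.read(16))
    return {
        "version": version,
        "tensor_count": tensor_count,
        "metadata_kv_count": metadata_kv_count,
    }


if __name__ == "__main__":
    import sys

    if len(sys.argv) > 1:
        print(read_gguf_header(sys.argv[1]))
```

Run it with the path to a model file as the only argument. The metadata entries and tensor descriptions follow the header and are not parsed here.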

Queries use both the .gguf file (base model) and the .npz file (retrieval model) simultaneously, resulting in much higher inference speeds.

Check out other cool MLX projects here: ml-explore/mlx#654 (comment)

License

MIT