📚 Paper • 🚀 Getting Started • ✏️ Documentation
LocalRQA is an open-source toolkit that enables researchers and developers to easily train, test, and deploy retrieval-augmented QA (RQA) systems using techniques from recent research. Given a collection of documents, you can use pre-built pipelines in our framework to quickly assemble an RQA system using the best off-the-shelf models. Alternatively, you can create your own training data, train open-source models using algorithms from the latest research, and deploy your very own local RQA system!
You can either install the package from GitHub or use our pre-built Docker image.
From GitHub
First, clone our repository
git clone https://github.com/jasonyux/LocalRQA
cd LocalRQA
Then run
pip install --upgrade pip
pip install -e .
From Docker
docker pull jasonyux/localrqa
docker run -it jasonyux/localrqa bash
The code base is located at `/workspace/LocalRQA`.
In essence, a retrieval-augmented QA (RQA) system is composed of two parts:
- a document database (a collection of documents)
- an embedding model + a generative model
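Conceptually, answering a question means embedding the query, retrieving the most similar documents from the database, and conditioning the generative model on them. The sketch below illustrates this flow with a toy bag-of-words "embedding" and cosine similarity; it is only an illustration of the idea, not the LocalRQA API:

```python
import re
from collections import Counter
from math import sqrt

def embed(text):
    """Toy 'embedding': a word-count vector over lowercased tokens."""
    return Counter(re.findall(r"[a-z0-9]+", text.lower()))

def cosine(a, b):
    dot = sum(a[w] * b[w] for w in a)
    norm = sqrt(sum(v * v for v in a.values())) * sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0

def retrieve(query, docs, k=1):
    """Return the k documents most similar to the query."""
    q = embed(query)
    return sorted(docs, key=lambda d: cosine(q, embed(d)), reverse=True)[:k]

docs = [
    "DBFS is the Databricks File System, a distributed file system.",
    "Delta Lake is a storage layer that brings ACID transactions.",
]
context = retrieve("What is DBFS?", docs, k=1)
prompt = f"Context: {context[0]}\nQuestion: What is DBFS?\nAnswer:"
# A real RQA system would now pass `prompt` to the generative model.
```

In LocalRQA, the toy `embed` above is replaced by a real embedding model (e.g., `intfloat/e5-base-v2`) and the prompt is completed by a generative model, as shown in the quick-start example below.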
As a quick start, we provide a simple example that obtains a document database from a website and builds an RQA system using off-the-shelf models from Hugging Face. For reference, the full example code can be found in the `demo.py` script at the root of the repository.
LocalRQA integrates with frameworks such as LangChain and LlamaIndex to easily ingest text data in various formats, such as JSON, HTML, and data from Google Drive. For example, you could load data from a website using `SeleniumURLLoader` from LangChain, then save and parse it into a collection of documents (`docs`):
from langchain_community.document_loaders import SeleniumURLLoader
from langchain.text_splitter import CharacterTextSplitter
from local_rqa.text_loaders.langchain_text_loader import LangChainTextLoader
# specify how to load the data and how to chunk them
# note: this requires selenium to read the web page
# if your selenium is not working, you can SKIP this entire section.
# We have already provided the `example/demo/databricks_web.pkl` file in this repo.
loader_func, split_func = SeleniumURLLoader, CharacterTextSplitter
loader_parameters = {'urls': ["https://docs.databricks.com/en/dbfs/index.html"]}
splitter_parameters = {'chunk_size': 400, 'chunk_overlap': 50, 'separator': "\n\n"}
kwargs = {"loader_params": loader_parameters, "splitter_params": splitter_parameters}
# load the data, chunk them, and save them
docs = LangChainTextLoader(
save_folder="example/demo", # where data is saved
save_filename="documents.pkl",
loader_func=loader_func,
splitter_func=split_func
).load_data(**kwargs)
This list of documents (`docs`) is now your document database, which will be used to create an embedding index for the RQA system.
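To get a feel for what the `chunk_size` and `chunk_overlap` settings above do, here is a simplified character-window splitter. This is a toy illustration, not LangChain's actual implementation (`CharacterTextSplitter` first splits on `separator` and then merges pieces up to `chunk_size`), but the overlap idea is the same:

```python
def chunk(text, chunk_size=400, chunk_overlap=50):
    """Slide a window of `chunk_size` characters, stepping forward by
    chunk_size - chunk_overlap so adjacent chunks share some context."""
    step = chunk_size - chunk_overlap
    return [text[i:i + chunk_size] for i in range(0, len(text), step)]

pieces = chunk("x" * 1000, chunk_size=400, chunk_overlap=50)
print(len(pieces))      # 3 chunks for 1000 characters
print(len(pieces[0]))   # 400
```

Overlapping chunks reduce the chance that a sentence relevant to a query is cut in half at a chunk boundary, at the cost of a slightly larger index.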
Given a path to a document database (see above), we can directly use `SimpleRQA` to 1) create and save an embedding index if `example/demo/index` is empty, 2) plug in an embedding model and a generative model, and 3) run QA!
from local_rqa.pipelines.retrieval_qa import SimpleRQA
from local_rqa.schema.dialogue import DialogueSession
rqa = SimpleRQA.from_scratch(
document_path="example/demo/databricks_web.pkl",
index_path="example/demo/index",
embedding_model_name_or_path="intfloat/e5-base-v2", # embedding model
qa_model_name_or_path="lmsys/vicuna-7b-v1.5" # generative model
)
response = rqa.qa(
batch_questions=['What is DBFS?'],
batch_dialogue_session=[DialogueSession()],
)
print(response.batch_answers[0])
# DBFS stands for Databricks File System, which is a ...
Unlike other frameworks, LocalRQA lets you locally train and test your RQA system using methods curated from the latest research. We provide a large collection of training and (automatic) evaluation methods to help you easily develop new RQA systems. For a list of supported training algorithms, please refer to our documentation website.
As a simple example, the script below uses SFT to fine-tune `mistralai/Mistral-7B-Instruct-v0.2`:
python scripts/train/qa_llm/train_w_gt.py \
--use_flash_attention true \
--per_device_train_batch_size 4 \
--per_device_eval_batch_size 4 \
--deepspeed scripts/train/ds_config.json \
--learning_rate 5e-6 \
--num_train_epochs 2 \
--gradient_accumulation_steps 2 \
--bf16 true \
--model_name_or_path mistralai/Mistral-7B-Instruct-v0.2 \
--assistant_prefix "[/INST]" \
--user_prefix "<s>[INST]" \
--sep_user " " \
--sep_sys "</s>" \
--eval_embedding_model intfloat/e5-base-v2 \
--logging_steps 10 \
--eval_steps 30 \
--save_steps 30 \
--output_dir model_checkpoints/databricks_exp \
--run_group databricks \
--train_file example/databricks/processed/train_w_qa.jsonl \
--eval_file example/databricks/processed/eval_w_qa.jsonl \
--test_file example/databricks/processed/test_w_qa.jsonl \
--full_dataset_file_path example/databricks/database/databricks.pkl \
--full_dataset_index_path example/databricks/database/index
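The `--user_prefix`, `--assistant_prefix`, `--sep_user`, and `--sep_sys` flags control how dialogue turns are serialized into Mistral's `[INST]`-style chat format during training. The sketch below shows how such a template might be assembled from those values; `build_prompt` is a hypothetical helper for illustration, not the trainer's actual code:

```python
def build_prompt(turns, user_prefix="<s>[INST]", assistant_prefix="[/INST]",
                 sep_user=" ", sep_sys="</s>"):
    """Serialize (user, assistant) turns into a Mistral-style training string."""
    out = []
    for user_msg, assistant_msg in turns:
        out.append(f"{user_prefix}{sep_user}{user_msg}{sep_user}{assistant_prefix}")
        out.append(f"{sep_user}{assistant_msg}{sep_sys}")
    return "".join(out)

print(build_prompt([("What is DBFS?", "DBFS is the Databricks File System.")]))
# <s>[INST] What is DBFS? [/INST] DBFS is the Databricks File System.</s>
```

Matching these flags to the base model's expected chat template matters: training with the wrong prefixes or separators can noticeably degrade the fine-tuned model's responses.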
LocalRQA provides two methods to showcase your RQA system to external users: 1) a static evaluation webpage where users can directly assess the system’s performance using a test dataset, or 2) an interactive chat webpage where users can chat with the system and provide feedback for each generated response.
To evaluate the first 50 predictions from a prediction file (e.g., produced by our training/evaluation script), run:
python -m local_rqa.serve.gradio_static_server \
    --file_path <path/to/your/test-predictions.jsonl> \
    --include_idx 1-50
To host your model and launch an interactive chat webpage, you will need to start a model worker (hosting your models), and a model controller (dealing with user requests):
- run the controller:
python -m local_rqa.serve.controller
- launch your customized RQA system(s):
CUDA_VISIBLE_DEVICES=0 python -m local_rqa.serve.model_worker \
    --document_path example/databricks/database/databricks.pkl \
    --index_path example/databricks/database/e5-v2-index \
    --embedding_model_name_or_path intfloat/e5-base-v2 \
    --qa_model_name_or_path lmsys/vicuna-7b-v1.5 \
    --model_id simple_rqa
- To do a quick test that the above is working, try
python -m local_rqa.serve.test_message --model_id simple_rqa
- Launch your demo page!
python -m local_rqa.serve.gradio_web_server \
    --model_id simple_rqa \
    --example "What is DBFS? What can it do?" \
    --example "What is INVALID_ARRAY_INDEX?"
where `--model_id simple_rqa` lets the controller know which model this demo page is for, and each `--example` is an example question shown on the demo page.
For more details on model serving, please refer to our documentation website.