EfficientRAG-official

Official code repo for EfficientRAG: Efficient Retriever for Multi-Hop Question Answering

Primary language: Python. License: MIT.

EfficientRAG is a new framework that trains a Labeler and a Filter to carry out multi-hop retrieval-augmented generation (RAG) without multiple LLM calls.

Updates

  • 2024-09-12: Open-sourced the code.

Setup

1. Installation

Install PyTorch >= 2.1.0 first, then install the remaining Python dependencies by running:

pip install -r requirements.txt

Alternatively, you can create a conda environment with Python >= 3.9:

conda create -n <ENV_NAME> python=3.9 pip
conda activate <ENV_NAME>
pip install -r requirements.txt
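A quick sanity check for the version requirements above could look like the following sketch. It is not part of the repository; version_at_least and check_environment are hypothetical helpers, and pre-release suffixes such as "a0" are deliberately ignored.

```python
import sys
from importlib import import_module


def version_at_least(version: str, minimum: tuple[int, ...]) -> bool:
    """Compare a version string such as '2.1.0+cu121' against a minimum tuple."""
    core = version.split("+")[0]  # drop local build tags like '+cu121'
    parts = []
    for piece in core.split(".")[: len(minimum)]:
        digits = ""
        for ch in piece:  # keep the leading digits only ('0a0' -> '0')
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    parts += [0] * (len(minimum) - len(parts))  # '2.1' is treated as '2.1.0'
    return tuple(parts) >= minimum


def check_environment() -> list[str]:
    """Return a list of human-readable problems; an empty list means OK."""
    problems = []
    if sys.version_info < (3, 9):
        problems.append(f"Python >= 3.9 required, found {sys.version.split()[0]}")
    try:
        torch = import_module("torch")
    except ImportError:
        problems.append("PyTorch is not installed")
    else:
        if not version_at_least(str(torch.__version__), (2, 1, 0)):
            problems.append(f"PyTorch >= 2.1.0 required, found {torch.__version__}")
    return problems


if __name__ == "__main__":
    issues = check_environment()
    print("\n".join(issues) if issues else "Environment looks OK")
```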

Preparation

  1. Download the HotpotQA, 2WikiMQA and MuSiQue datasets, split each into train, dev and test sets, and put them under data/dataset.

  2. Download the Contriever retriever model and the DeBERTa base model, and put them under model_cache.

  3. Prepare the corpus by extracting the documents and building their embeddings:

python src/retrievers/multihop_data_extractor.py --dataset hotpotQA
python src/retrievers/passage_embedder.py \
    --passages data/corpus/hotpotQA/corpus.jsonl \
    --output_dir data/corpus/hotpotQA/contriever \
    --model_type contriever

  4. Deploy LLaMA-3-70B-Instruct with the vLLM framework and configure it in src/language_models/llama.py.
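The README does not describe what passage_embedder.py computes internally. As a hedged illustration only: Contriever is commonly used with mean pooling over the encoder's token outputs and inner-product scoring. The NumPy sketch below shows that pooling and top-k step on toy arrays, leaving out the encoder forward pass; mean_pool and top_k are hypothetical names.

```python
import numpy as np


def mean_pool(token_embeddings: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Average the token vectors of each passage, ignoring padding positions.

    token_embeddings: (batch, seq_len, dim); attention_mask: (batch, seq_len) of 0/1.
    """
    mask = attention_mask[..., None].astype(token_embeddings.dtype)
    summed = (token_embeddings * mask).sum(axis=1)
    counts = np.clip(mask.sum(axis=1), 1e-9, None)  # avoid division by zero
    return summed / counts


def top_k(query_vec: np.ndarray, passage_vecs: np.ndarray, k: int) -> list[int]:
    """Return indices of the k passages with the highest inner-product score."""
    scores = passage_vecs @ query_vec
    k = min(k, len(scores))
    return np.argsort(-scores, kind="stable")[:k].tolist()


# Toy example: the last token of the passage is padding and must be ignored.
tokens = np.array([[[1.0, 0.0], [3.0, 0.0], [9.0, 9.0]]])
mask = np.array([[1, 1, 0]])
print(mean_pool(tokens, mask))  # [[2. 0.]]
```

Inner product (rather than cosine) is used here because that is the usual pairing with mean-pooled Contriever vectors; swap in normalized vectors if your index expects cosine similarity.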

2. Training Data Construction

We use the HotpotQA training set as an example; 2WikiMQA and MuSiQue can be constructed in the same way.

2.1 Query Decomposition

python src/data_synthesize/query_decompose.py \
    --dataset hotpotQA \
    --split train \
    --model llama3
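The output format of query_decompose.py is not documented here. Assuming the LLM is prompted to list one sub-question per numbered or bulleted line, a parser along the lines of this hypothetical parse_sub_questions helper would turn its reply into a list of single-hop questions.

```python
import re

# A line such as "1. Who ...?", "2) Where ...?" or "- Who ...?" (assumed LLM format).
_ITEM = re.compile(r"^\s*(?:\d+[.)]|[-*])\s*(.+?)\s*$")


def parse_sub_questions(llm_output: str) -> list[str]:
    """Extract sub-questions from numbered or bulleted lines; skip anything else."""
    questions = []
    for line in llm_output.splitlines():
        match = _ITEM.match(line)
        if match:
            questions.append(match.group(1))
    return questions


reply = "Sub-questions:\n1. Who directed Inception?\n2) Where was that director born?\n"
print(parse_sub_questions(reply))
```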

2.2 Token Labeling

python src/data_synthesize/token_labeling.py \
    --dataset hotpotQA \
    --split train \
    --model llama3
python src/data_synthesize/token_extraction.py \
    --data_path data/synthesized_token_labeling/hotpotQA/train.jsonl \
    --save_path data/token_extracted/hotpotQA/train.jsonl \
    --verbose
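Judging by the script names, token labeling has the LLM point out the words of a retrieved chunk that help answer the current sub-question, and token extraction turns them into per-word labels. The exact matching rule in token_extraction.py is not described; the sketch below assumes a simple in-order, case-insensitive match on whitespace-separated words, and label_words is a hypothetical helper.

```python
def label_words(chunk: str, useful_words: list[str]) -> list[tuple[str, int]]:
    """Mark each whitespace-separated word of `chunk` as useful (1) or not (0).

    Targets are matched left to right, so a repeated word is only labeled at the
    position where the extracted sequence actually continues.
    """
    words = chunk.split()
    labels = [0] * len(words)
    pos = 0
    for target in useful_words:
        for i in range(pos, len(words)):
            if words[i].strip(".,;:!?\"'").lower() == target.lower():
                labels[i] = 1
                pos = i + 1  # continue searching after this match
                break
    return list(zip(words, labels))


print(label_words("Christopher Nolan was born in London.", ["Nolan", "London"]))
```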

2.3 Next Query Filtering

python src/data_synthesize/next_hop_query_construction.py \
    --dataset hotpotQA \
    --split train \
    --model llama
python src/data_synthesize/next_hop_query_filtering.py \
    --data_path data/synthesized_next_query/hotpotQA/train.jsonl \
    --save_path data/next_query_extracted/hotpotQA/train.jsonl \
    --verbose
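Here next_hop_query_construction.py has the LLM write the next-hop query, and next_hop_query_filtering.py keeps the usable examples. Under our reading of the paper, which is an assumption, the Filter forms the next query by selecting words from the current query plus the labeled words. Its training labels could then be derived as in this hypothetical filter_labels sketch.

```python
def filter_labels(query: str, useful_words: list[str], next_query: str) -> list[tuple[str, int]]:
    """Label each candidate word 1 if it is kept in the next-hop query, else 0."""

    def norm(word: str) -> str:
        return word.strip(".,;:!?\"'").lower()

    candidates = query.split() + useful_words  # Filter input: query words + labeled words
    remaining = [norm(w) for w in next_query.split()]
    labeled = []
    for word in candidates:
        key = norm(word)
        if key in remaining:
            remaining.remove(key)  # each next-query word can be claimed only once
            labeled.append((word, 1))
        else:
            labeled.append((word, 0))
    return labeled


pairs = filter_labels(
    "Where was the director of Inception born?",
    ["Christopher", "Nolan"],
    "Where was Christopher Nolan born?",
)
print([label for _, label in pairs])
```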

2.4 Negative Sampling

python src/data_synthesize/negative_sampling.py \
    --dataset hotpotQA \
    --split train \
    --retriever contriever
python src/data_synthesize/negative_sampling_labeled.py \
    --dataset hotpotQA \
    --split train \
    --model llama
python src/data_synthesize/negative_token_extraction.py \
    --dataset hotpotQA \
    --split train \
    --verbose
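negative_sampling.py draws candidate passages with the chosen retriever (--retriever contriever). A common way to pick hard negatives, shown here as an assumption rather than the script's exact rule, is to keep the top-ranked retrieved passages that are not gold supporting passages. hard_negatives is a hypothetical helper.

```python
def hard_negatives(ranked_ids: list[str], gold_ids: set[str], n: int) -> list[str]:
    """Pick the n highest-ranked retrieved passages that are not gold passages."""
    negatives: list[str] = []
    if n <= 0:
        return negatives
    for pid in ranked_ids:  # ranked_ids is ordered best-first
        if pid in gold_ids or pid in negatives:
            continue  # skip positives and duplicate hits
        negatives.append(pid)
        if len(negatives) == n:
            break
    return negatives


print(hard_negatives(["p3", "p1", "p7", "p3", "p9"], {"p1"}, 2))
```

Taking the highest-scoring non-gold passages makes the negatives look relevant to the retriever, which is what makes them useful for teaching the Labeler to reject them.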

2.5 Training Data

python src/data_synthesize/training_data_synthesize.py \
    --dataset hotpotQA \
    --split train

Training

Training Filter model

python src/efficient_rag/filter_training.py \
    --dataset hotpotQA \
    --save_path saved_models/filter
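Assuming the Filter is trained as a token classifier on the DeBERTa base model downloaded during preparation, word-level labels must be mapped onto DeBERTa's subword tokens. The usual recipe is sketched below; it is not taken from filter_training.py. word_ids has the shape returned by a Hugging Face fast tokenizer's word_ids(), and -100 is the default ignore_index of PyTorch's CrossEntropyLoss.

```python
from typing import List, Optional

IGNORE_INDEX = -100  # positions with this label are skipped by CrossEntropyLoss by default


def align_labels(word_ids: List[Optional[int]], word_labels: List[int]) -> List[int]:
    """Spread word-level labels onto subword tokens.

    word_ids[i] is the index of the word that token i came from, or None for
    special tokens such as [CLS]/[SEP]. Only the first subword of each word
    keeps the word's label; later subwords and special tokens are ignored.
    """
    aligned = []
    previous = None
    for wid in word_ids:
        if wid is None or wid == previous:
            aligned.append(IGNORE_INDEX)
        else:
            aligned.append(word_labels[wid])
        previous = wid
    return aligned


# [CLS] w0 w1a w1b w2 [SEP]  with word labels [1, 0, 1]
print(align_labels([None, 0, 1, 1, 2, None], [1, 0, 1]))
```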

Training Labeler model

python src/efficient_rag/labeler_training.py \
    --dataset hotpotQA \
    --tags 2
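As we understand the paper, the Labeler both marks each token as useful or not and tags the whole chunk (for example Continue vs Terminate). We read --tags 2 as the number of chunk-level tags, but that is an assumption. A joint objective for such a two-headed model could look like this NumPy sketch; labeler_loss and its weight alpha are hypothetical.

```python
import numpy as np


def cross_entropy(logits: np.ndarray, labels: np.ndarray, ignore_index: int = -100) -> float:
    """Mean cross-entropy over rows of `logits`, skipping rows labeled ignore_index."""
    keep = labels != ignore_index
    if not keep.any():
        return 0.0
    logits, labels = logits[keep], labels[keep]
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(-log_probs[np.arange(len(labels)), labels].mean())


def labeler_loss(token_logits: np.ndarray, token_labels: np.ndarray,
                 chunk_logits: np.ndarray, chunk_label: int, alpha: float = 1.0) -> float:
    """Token-level loss plus a weighted chunk-level (tag) loss."""
    token_term = cross_entropy(token_logits, token_labels)            # (seq_len, 2) vs (seq_len,)
    chunk_term = cross_entropy(chunk_logits[None, :], np.array([chunk_label]))  # one chunk
    return token_term + alpha * chunk_term


# Uniform logits give log(2) per term, so the total is 2 * log(2) with alpha = 1.
print(labeler_loss(np.zeros((3, 2)), np.array([1, 0, -100]), np.zeros(2), 0))
```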

Inference

EfficientRAG retrieval procedure

python src/efficientrag_retrieve.py \
    --dataset hotpotQA \
    --retriever contriever \
    --labels 2 \
    --labeler_ckpt <<PATH_TO_LABELER_CKPT>> \
    --filter_ckpt <<PATH_TO_FILTER_CKPT>> \
    --topk 10
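As a rough mental model of this retrieval step, and not a transcription of efficientrag_retrieve.py, the loop can be sketched as follows. Retrieve chunks for each query, let the Labeler tag each chunk and mark its useful words, drop chunks tagged Terminate, and let the Filter turn the useful words into the next-hop query. iterative_retrieve, max_hops and the toy_* stand-ins for Contriever, the Labeler and the Filter are hypothetical. Treating Terminate as "discard this chunk" is also our assumption.

```python
from typing import Callable, Dict, List, Tuple

Retrieve = Callable[[str, int], List[str]]
Label = Callable[[str, str], Tuple[List[str], str]]
MakeQuery = Callable[[str, List[str]], str]


def iterative_retrieve(question: str, retrieve: Retrieve, label: Label,
                       make_next_query: MakeQuery, topk: int = 10,
                       max_hops: int = 4) -> List[str]:
    """Collect useful chunks hop by hop; no LLM is called inside this loop."""
    kept: List[str] = []
    queries = [question]
    for _ in range(max_hops):
        next_queries: List[str] = []
        for query in queries:
            for chunk in retrieve(query, topk):
                useful_words, tag = label(query, chunk)
                if tag == "Terminate":
                    continue  # assumption: a Terminate chunk is discarded, ending this branch
                if chunk not in kept:
                    kept.append(chunk)
                if useful_words:
                    next_query = make_next_query(query, useful_words)
                    if next_query not in next_queries:
                        next_queries.append(next_query)
        if not next_queries:
            break  # no chunk produced a follow-up query
        queries = next_queries
    return kept


# Toy stand-ins for the retriever, Labeler and Filter (all hypothetical).
CORPUS: Dict[str, List[str]] = {
    "Where was the director of Inception born?": [
        "Inception was directed by Christopher Nolan.",
        "Inception is a 2010 film.",
    ],
    "Where was Christopher Nolan born?": ["Christopher Nolan was born in London."],
}


def toy_retrieve(query: str, k: int) -> List[str]:
    return CORPUS.get(query, [])[:k]


def toy_label(query: str, chunk: str) -> Tuple[List[str], str]:
    if "directed by" in chunk:
        return ["Christopher", "Nolan"], "Continue"
    if "born in" in chunk:
        return [], "Continue"  # relevant, but needs no further hop
    return [], "Terminate"


def toy_filter(query: str, words: List[str]) -> str:
    return f"Where was {' '.join(words)} born?"


hits = iterative_retrieve("Where was the director of Inception born?",
                          toy_retrieve, toy_label, toy_filter, topk=10)
print(hits)
```

Because every useful chunk can spawn a follow-up query, the number of queries can grow with each hop; the max_hops cap and query de-duplication keep the sketch bounded.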

Use LLaMA-3-8B-Instruct as the generator

python src/efficientrag_qa.py \
    --fpath <<MODEL_INFERENCE_RESULT>> \
    --model llama-8B \
    --dataset hotpotQA
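efficientrag_qa.py presumably reads the retrieval output passed via --fpath and asks the generator to answer from the kept chunks. The prompt it uses is not shown in the README. The template below is a hypothetical example of packing the retrieved chunks and the question into one prompt for LLaMA-3-8B-Instruct.

```python
def build_qa_prompt(question: str, chunks: list[str], max_chunks: int = 10) -> str:
    """Format retrieved chunks and the question into a single generator prompt."""
    context = "\n".join(
        f"[{i}] {chunk.strip()}" for i, chunk in enumerate(chunks[:max_chunks], start=1)
    )
    return (
        "Answer the question using the passages below. "
        "Reply with a short answer only.\n\n"
        f"Passages:\n{context}\n\n"
        f"Question: {question}\nAnswer:"
    )


print(build_qa_prompt(
    "Where was the director of Inception born?",
    ["Inception was directed by Christopher Nolan.", "Christopher Nolan was born in London."],
))
```

Numbering the passages keeps the prompt easy to inspect when an answer goes wrong; the max_chunks cap mirrors the --topk 10 used at retrieval time but is otherwise arbitrary.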

Citation

If you find this paper or code useful, please cite:

@misc{zhuang2024efficientragefficientretrievermultihop,
      title={EfficientRAG: Efficient Retriever for Multi-Hop Question Answering},
      author={Ziyuan Zhuang and Zhiyang Zhang and Sitao Cheng and Fangkai Yang and Jia Liu and Shujian Huang and Qingwei Lin and Saravan Rajmohan and Dongmei Zhang and Qi Zhang},
      year={2024},
      eprint={2408.04259},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2408.04259},
}