
[Preprint] Learning to Filter Context for Retrieval-Augmented Generation

Primary language: Python. License: Creative Commons Attribution Share Alike 4.0 International (CC-BY-SA-4.0)

FilCo

This repository contains the code and data for the project: Learning to Filter Context for Retrieval-Augmented Generation

Install

Install all required libraries by running

pip install -r requirements.txt

Retrieve the top relevant Wikipedia passages using Dense Passage Retriever (DPR) and store them in the ./datasets/${name} directory. We also provide preprocessed datasets with top-5 retrieved passages (here). In the following example commands, ${name} is a dataset name from ['nq', 'tqa', 'hotpotqa', 'fever', 'wow'].

Measure Retrieved Passages

Before filtering out potentially redundant context, we need to measure the utility scores of individual spans in the retrieved passages. You can use any of the three context filtering strategies: (i) string inclusion (strinc), (ii) lexical overlap (lexical), and (iii) conditional cross-mutual information (cxmi).
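As a rough illustration of the first two measures, the sketch below scores a candidate span by string inclusion (does the span contain the answer?) and by unigram-overlap F1 against a reference text. The function names and the tokenization are assumptions made for illustration; this is not the code in measure_ctxs.py.

```python
import re
from collections import Counter


def tokenize(text):
    """Lowercase and split into word tokens (a simplifying assumption)."""
    return re.findall(r"\w+", text.lower())


def string_inclusion(span, answer):
    """Return 1.0 if the answer occurs in the span (case-insensitive), else 0.0."""
    return float(answer.lower() in span.lower())


def lexical_overlap(span, reference):
    """Unigram F1 between the span and a reference text."""
    span_tokens, ref_tokens = tokenize(span), tokenize(reference)
    if not span_tokens or not ref_tokens:
        return 0.0
    # Multiset intersection counts each shared token at most min(count) times.
    n_common = sum((Counter(span_tokens) & Counter(ref_tokens)).values())
    if n_common == 0:
        return 0.0
    precision = n_common / len(span_tokens)
    recall = n_common / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)
```

String inclusion is a binary signal suited to short extractive answers, while the overlap score gives a graded signal when the reference is longer free-form text.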

Use measure_ctxs.py to measure the utility score of each retrieved passage and of the individual sentences within it. For example:

python measure_ctxs.py \
--dataset_path "./datasets/nq/base/test.json" \
--output_path  "./datasets/nq/scored/test.json" \
--metric_name  "strinc" "lexical" "cxmi" \
--n_contexts 5 \
--prefix "Given the ['context', 'question'], predict the answer to the question:"

If "cxmi" is specified as one of the metric_name values, make sure you specify the Hugging Face model to use in model_name_or_path; otherwise, "google/flan-t5-xl" is used by default.
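For intuition, a CXMI-style score compares how likely the model finds the reference output with and without a given context. The sketch below assumes you already have per-token log-probabilities of the output under both conditions. The function name and the log-difference form are assumptions, and the exact formulation in measure_ctxs.py may differ.

```python
def cxmi_score(logprobs_with_context, logprobs_without_context):
    """Total log-likelihood of the output with context minus without context.

    Both arguments are lists of per-token log-probabilities of the same
    reference output. A positive score means the context made the output
    more likely under the model.
    """
    return sum(logprobs_with_context) - sum(logprobs_without_context)
```

Unlike the two string-based measures, this score depends on the chosen model, which is why model_name_or_path matters when "cxmi" is requested.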

Obtain Training & Testing Data

Use get_inputs.py to create input-output training pairs for both the context filtering model $M_{ctx}$ and generation model $M_{gen}$.

For the context filtering task, the input is all top-K retrieved passages, and the output is the context filtered with one of the three strategies.

python get_inputs.py \
--dataset_path "./datasets/nq/scored/train.json" \
--output_path "./datasets/nq/mctx/em/train_em_top1.json" \
--input_list question passage --output_list filtered \
--n_examples 0 --n_contexts 1 \
--filter_criteria strinc --print_example

Alter the value of n_examples to include more in-context examples. Adjust the value of n_contexts to change the number of retrieved passages involved. filter_criteria specifies which filtering strategy you want to use, among ['strinc', 'lexical', 'cxmi'].
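To make the filtering step concrete, the sketch below picks the highest-scoring sentence under the chosen criterion as the filtered context. The record layout (a list of dicts with a "text" field and one score per metric) is hypothetical and may not match the JSON written by measure_ctxs.py.

```python
def select_filtered(sentences, criteria="strinc"):
    """Return the text of the sentence with the highest score under `criteria`.

    `sentences` is a list of dicts such as
    {"text": "...", "strinc": 1.0, "lexical": 0.4, "cxmi": 0.7}
    (a hypothetical layout). Ties go to the earliest sentence, and an empty
    list yields an empty string.
    """
    if not sentences:
        return ""
    best = max(sentences, key=lambda s: s[criteria])
    return best["text"]
```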

For the generation task, the input is the filtered context, and the output is the annotated answer.

python get_inputs.py \
--dataset_path "./datasets/nq/scored/train.json" \
--output_path "./datasets/nq/mgen/em/train_em_top1.json" \
--input_list question filtered --output_list answer \
--n_examples 0 --n_contexts 1 \
--filter_criteria strinc --print_example

The only changes from the context filtering case are input_list and output_list: the input context switches from entire passages ('passage') to filtered sentences ('filtered'), and the output switches from 'filtered' to 'answer'.
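The relationship between the two tasks can be sketched as follows: the same scored example yields one pair for the context filtering model (question and passage in, filtered context out) and one pair for the generation model (question and filtered context in, answer out). The field names and the plain-text layout are assumptions for illustration; get_inputs.py may format examples differently.

```python
def build_pair(example, input_list, output_list):
    """Join the named fields of `example` into an (input, output) text pair."""
    source = "\n".join(f"{name}: {example[name]}" for name in input_list)
    target = "\n".join(str(example[name]) for name in output_list)
    return source, target


# A made-up example record; real records come from the scored dataset files.
example = {
    "question": "Who wrote Hamlet?",
    "passage": "Hamlet is a tragedy. It was written by William Shakespeare.",
    "filtered": "It was written by William Shakespeare.",
    "answer": "William Shakespeare",
}

# Context filtering task: question + passage -> filtered context.
mctx_pair = build_pair(example, ["question", "passage"], ["filtered"])
# Generation task: question + filtered context -> answer.
mgen_pair = build_pair(example, ["question", "filtered"], ["answer"])
```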

Training A Context Filtering Model

Apply the above processing to the training, validation, and test data, then fine-tune a Flan-T5 (xl) model using train.py, which sets the model_name_or_path argument to "google/flan-t5-xl" by default.

python train.py \
--train_data_path "./datasets/nq/mctx/em/train_em_top1.json" \
--eval_data_path "./datasets/nq/mctx/em/dev_em_top1.json" \
--test_data_path "./datasets/nq/mctx/em/test_em_top1.json" \
--output_dir "./checkpoints/nq-mctx_filco-em" \
--do_train --do_eval --do_predict

After training, load the fine-tuned checkpoint to predict filtered context for testing examples.

python query.py \
--dataset_path "./datasets/nq/mctx/em/test_em_top1.json" \
--output_path "./output/nq/mctx/filco-em_tuned-ft5.json" \
--model_name_or_path "./checkpoints/nq-mctx_filco-em"

Then convert the dataset to the generation example format by running:

python replace_context.py \
--dataset_path "./datasets/nq/base/test.json" \
--predset_path "./output/nq/mctx/filco-em_tuned-ft5.json" \
--output_path "./datasets/nq/mgen/em/test_em_top1_predict-ft5.json" \
--process_dataset nq

To train and query LLaMA models, switch the model name to "meta-llama/Llama-2-7b-hf". Alternatively, to use xTuring, run train_llama.py and query_llama.py with similar arguments, but first transform the examples into instruction style using convert_dataset.py.

Training A Generation Model with Filtered Context

Prepare the training and validation data using the same method, then train Flan-T5 models using train.py and LLaMA models using train_llama.py.

python train.py \
--train_data_path "./datasets/nq/mgen/em/train_em_top1.json" \
--eval_data_path "./datasets/nq/mgen/em/dev_em_top1.json" \
--test_data_path "./datasets/nq/mgen/em/test_em_top1.json" \
--output_dir "./checkpoints/nq-mgen_filco-em" \
--do_train --do_eval --do_predict

To use the tuned model checkpoint for inference, run

python query.py \
--dataset_path "./datasets/nq/mgen/em/test_em_top1.json" \
--output_path "./output/nq/mgen/silver-em_tuned-ft5.json" \
--model_name_or_path "./checkpoints/nq-mgen_filco-em"

To experiment in the FilCo setting, switch the silver filtered context (e.g., "./datasets/nq/mgen/em/train_em_top1.json") to the model-filtered context (e.g., "./output/nq/mctx/filco-em_tuned-ft5.json").

Evaluating Filtering and Generation Models

To evaluate generation performance, use EM (~accuracy) or F1, depending on the task formulation.

python eval.py \
--dataset_path "./datasets/nq/base/test.json" \
--predset_path "./output/nq/mgen/silver-em_tuned-ft5.json" \
--metric_name "em"
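As a reference point, exact match is often computed after light normalization (lowercasing, stripping punctuation and articles), taking the best match over all reference answers. The sketch below follows that common convention. It is an assumption and not necessarily what eval.py implements.

```python
import re
import string


def normalize(text):
    """Lowercase, drop punctuation and English articles, collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction, references):
    """Return 1.0 if the normalized prediction equals any normalized reference."""
    pred = normalize(prediction)
    return float(any(pred == normalize(ref) for ref in references))
```

Averaging exact_match over all test examples gives the dataset-level EM score.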

Reference

If you find our paper or code useful, please cite the paper:

@article{wang2023learning,
  title={Learning to Filter Context for Retrieval-Augmented Generation},
  author={Zhiruo Wang and Jun Araki and Zhengbao Jiang and Md Rizwan Parvez and Graham Neubig},
  journal={arXiv preprint arXiv:2311.08377},
  year={2023}
}