AugTriever: Unsupervised Dense Retrieval by Scalable Data Augmentation

This repository contains the code and models for the paper "AugTriever: Unsupervised Dense Retrieval by Scalable Data Augmentation".

Our code is based on the following repositories:

Data

AugQ-Wiki and AugQ-CC can be downloaded from the Hugging Face Hub.

Checkpoints

Naming corresponds to Table 1 in the paper.

Aug method      Model    MM    BEIR (14 tasks)  Download link
Hybrid-TQGen+   MoCo     24.6  41.1             [download]
Hybrid-All      MoCo     23.5  39.4             [download]
Hybrid-TQGen    MoCo     23.3  39.4             [download]
Doc-Title       MoCo     21.8  38.7             [download]
QExt-PLM        MoCo     20.6  38.2             [download]
TQGen-Topic     MoCo     21.2  38.9             [download]
TQGen-Title     MoCo     21.8  39.3             [download]
TQGen-AbSum     MoCo     23.2  39.6             [download]
TQGen-ExSum     MoCo     23.0  39.4             [download]
TQGen-Topic     InBatch  20.7  39.0             [download]

Run Training

Scripts for launching training are in the folder examples/training. For example:

cd $PATH_TO_REPO
sh examples/training/cc.moco.topic50.bs2048.gpu8.sh
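The checkpoint table lists MoCo and InBatch variants. As a rough illustration only (this is not the repository's implementation), the sketch below shows a MoCo-style InfoNCE loss, where the positive key is the target class and a FIFO queue of past key embeddings supplies the negatives, together with the momentum update of the key encoder. Embeddings are plain Python lists, and the temperature, momentum, and queue size are illustrative assumptions.

```python
import math
from collections import deque


def dot(u, v):
    """Dot-product similarity between two embeddings."""
    return sum(a * b for a, b in zip(u, v))


def info_nce_loss(query, positive_key, negative_keys, temperature=0.05):
    """Cross-entropy with the positive key as the target class and the
    queued keys as negatives (MoCo-style InfoNCE)."""
    logits = [dot(query, positive_key) / temperature]
    logits += [dot(query, neg) / temperature for neg in negative_keys]
    # Numerically stable log-sum-exp over all logits.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[0]


def momentum_update(key_params, query_params, momentum=0.999):
    """Move key-encoder parameters slowly toward the query-encoder parameters."""
    return [momentum * k + (1.0 - momentum) * q for k, q in zip(key_params, query_params)]


class KeyQueue:
    """Fixed-size FIFO queue of past key embeddings used as negatives."""

    def __init__(self, max_size):
        self.keys = deque(maxlen=max_size)

    def enqueue(self, batch_keys):
        # deque(maxlen=...) silently drops the oldest entries when full.
        self.keys.extend(batch_keys)


# Usage: one query/positive pair scored against the current queue (toy values).
queue = KeyQueue(max_size=4)
queue.enqueue([[0.0, 1.0], [-1.0, 0.0]])
loss = info_nce_loss([1.0, 0.0], [0.9, 0.1], list(queue.keys))
print(f"loss = {loss:.6f}")
```

In an InBatch setup, the negatives would instead be the other positives in the same batch, so no queue or momentum encoder is needed.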

Run Evaluation

BEIR

Please refer to BEIR for data download.

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --master_addr=127.0.0.1 --master_port=2255 eval_beir.py --model_name_or_path output_dir/augtriever-release/cc.T03b_title50.moco-2e14.contriever256-special50.bert-base-uncased.avg.dot.q128d256.step100k.bs1024.lr5e5/ --dataset fiqa --metric dot --pooling average --per_gpu_batch_size 128 --beir_data_path data/beir/ --output_dir eval_dir/beir
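The --pooling average and --metric dot flags suggest that a text embedding is the mean of its non-padding token vectors and that relevance is the inner product of the query and document embeddings. A minimal pure-Python sketch of that computation, assuming the token embeddings and attention mask have already been produced by an encoder (the toy values are made up):

```python
def average_pool(token_embeddings, attention_mask):
    """Mean of the token vectors whose mask value is 1 (padding is ignored)."""
    dim = len(token_embeddings[0])
    total = [0.0] * dim
    count = 0
    for vec, keep in zip(token_embeddings, attention_mask):
        if keep:
            total = [t + v for t, v in zip(total, vec)]
            count += 1
    if count == 0:
        raise ValueError("attention mask has no active tokens")
    return [t / count for t in total]


def dot_score(query_emb, doc_emb):
    """Relevance score as the inner product of two embeddings."""
    return sum(q * d for q, d in zip(query_emb, doc_emb))


# Toy example: the third query token is padding and is excluded from the mean.
query = average_pool([[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]], [1, 1, 0])
doc = average_pool([[1.0, 1.0]], [1])
print(query, dot_score(query, doc))  # [2.0, 3.0] 5.0
```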

ODQA

Please refer to Spider for details about QA data download and processing.

export EXP_DIR="output_dir/cc-hybrid.RC20+T0gen80.seed477.moco-2e14.contriever256-special50.bert-base-uncased.avg.dot.q128d256.step200k.bs2048.lr5e5/"
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --master_port=31133 --max_restarts=0 generate_passage_embeddings.py --model_name_or_path $EXP_DIR --output_dir $EXP_DIR/embeddings --passages data/nq/psgs_w100.tsv --per_gpu_batch_size 512
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python eval_qa.py --model_name_or_path facebook/contriever --passages data/nq/psgs_w100.tsv --passages_embeddings "$EXP_DIR/embeddings/*" --qa_file data/nq/qas/*-test.csv,data/nq/qas/entityqs/test/P*.test.json --output_dir $EXP_DIR/qa_output --save_or_load_index
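generate_passage_embeddings.py encodes the passage collection once, and eval_qa.py then retrieves passages for each question, presumably reporting whether an answer appears in the top-k results. The sketch below illustrates that retrieval step as exhaustive inner-product search (what a flat index computes) plus a top-k hit check. It is not the repository's code: the answer match is a simplified case-insensitive substring test, and the passages and embeddings are hypothetical toy data.

```python
import heapq


def top_k_passages(query_emb, passage_embs, passage_ids, k=3):
    """Exhaustive inner-product search: return the k highest (score, passage_id) pairs."""
    scored = (
        (sum(q * p for q, p in zip(query_emb, emb)), pid)
        for emb, pid in zip(passage_embs, passage_ids)
    )
    return heapq.nlargest(k, scored)


def has_answer(passage_text, answers):
    """Simplified answer check: case-insensitive substring match.
    (Assumption: real ODQA evaluation normalizes and tokenizes more carefully.)"""
    text = passage_text.lower()
    return any(answer.lower() in text for answer in answers)


def top_k_hit(retrieved_ids, passages, answers):
    """True if any retrieved passage contains one of the answer strings."""
    return any(has_answer(passages[pid], answers) for pid in retrieved_ids)


# Usage with toy 2-d embeddings (hypothetical data).
passages = {
    "p1": "Paris is the capital of France.",
    "p2": "Berlin is the capital of Germany.",
    "p3": "France borders Spain.",
}
ids = list(passages)
embs = [[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]]
hits = top_k_passages([1.0, 0.0], embs, ids, k=2)
print(hits)  # [(1.0, 'p1'), (0.7, 'p3')]
print(top_k_hit([pid for _, pid in hits], passages, ["Paris"]))  # True
```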

Convert models

Convert to a Hugging Face BERT model

python convert_checkpoint_to_hf_bert.py --ckpt_path output_dir/cc.T03b_topic.inbatch.contriever256-special.bert-base-uncased.avg.dot.q128d256.step100k.bs1024.lr5e5 --output_dir output_dir/cc.T03b_topic.inbatch.contriever256-special.bert-base-uncased.avg.dot.q128d256.step100k.bs1024.lr5e5/hf_ckpt_bert --model_type shared
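Conceptually, such a conversion keeps only the weights of the encoder used at inference time and renames their keys to the layout a plain BERT model expects. A hedged sketch on a plain dict, assuming (hypothetically) that the query-encoder weights are stored under an encoder_q. prefix, as in some MoCo-style training code; inspect the actual key names in your checkpoint before relying on this:

```python
def extract_encoder_state(ckpt_state, prefix="encoder_q."):
    """Keep entries under `prefix` and strip it; everything else (e.g. a momentum
    key encoder or the negative queue) is dropped.
    NOTE: the default prefix is a hypothetical example, not confirmed for this repo."""
    return {
        key[len(prefix):]: value
        for key, value in ckpt_state.items()
        if key.startswith(prefix)
    }


# Usage with placeholder strings standing in for tensors.
ckpt = {
    "encoder_q.embeddings.word_embeddings.weight": "W_q",
    "encoder_k.embeddings.word_embeddings.weight": "W_k",
    "queue": "Q",
}
print(extract_encoder_state(ckpt))  # {'embeddings.word_embeddings.weight': 'W_q'}
```

With a real PyTorch checkpoint, the values would be tensors and the resulting dict could be passed to a model's load_state_dict.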

Convert to a Hugging Face DPR model

python convert_checkpoint_to_hf_dpr.py --ckpt_path output_dir/cc-hybrid.RC20+T0gen80.seed477.moco-2e14.contriever256-special50.bert-base-uncased.avg.dot.q128d256.step200k.bs2048.lr5e5 --output_dir output_dir/cc-hybrid.RC20+T0gen80.seed477.moco-2e14.contriever256-special50.bert-base-uncased.avg.dot.q128d256.step200k.bs2048.lr5e5/hf_ckpt_dpr --model_type shared
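Our reading of --model_type shared is that one encoder serves both questions and passages, so a DPR export would carry the same BERT weights twice. The sketch below illustrates this by copying a BERT-style state dict into the key layouts used by Hugging Face's DPRQuestionEncoder and DPRContextEncoder (question_encoder.bert_model.* and ctx_encoder.bert_model.*). It assumes the source keys carry no extra prefix and is a simplified illustration, not the logic of convert_checkpoint_to_hf_dpr.py.

```python
DPR_PREFIXES = {
    "question": "question_encoder.bert_model.",
    "context": "ctx_encoder.bert_model.",
}


def to_dpr_state_dicts(bert_state):
    """Copy one shared BERT state dict into question- and context-encoder layouts."""
    out = {}
    for role, prefix in DPR_PREFIXES.items():
        out[role] = {
            prefix + key: value
            for key, value in bert_state.items()
            # HF's DPR encoder builds its BertModel without a pooling layer,
            # so pooler weights are skipped here.
            if not key.startswith("pooler.")
        }
    return out


# Usage with placeholder strings standing in for tensors.
converted = to_dpr_state_dicts({"embeddings.word_embeddings.weight": "W", "pooler.dense.weight": "P"})
print(converted["question"])
print(converted["context"])
```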

Export scores

Set the experiment path in gather_score_beir.py, gather_score_qa.py, or gather_score_senteval.py, then run the script. For example:

python gather_score_beir.py
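The BEIR (14 tasks) column in the checkpoint table is an average over per-task scores. A small sketch of that aggregation step, with an optional check that no expected task is missing; the task names and numbers below are placeholders, not results from the paper, and the gather_score_* scripts may compute this differently.

```python
def average_score(scores_by_task, expected_tasks=None):
    """Mean of per-task scores, optionally restricted to (and checked against) a task list."""
    if expected_tasks is not None:
        missing = sorted(set(expected_tasks) - set(scores_by_task))
        if missing:
            raise KeyError(f"missing scores for: {missing}")
        scores_by_task = {task: scores_by_task[task] for task in expected_tasks}
    if not scores_by_task:
        raise ValueError("no scores to average")
    return sum(scores_by_task.values()) / len(scores_by_task)


# Placeholder numbers for illustration only.
print(average_score({"task_a": 40.0, "task_b": 38.0}))  # 39.0
```

Checking against an explicit task list guards against silently averaging over fewer tasks when an evaluation run failed.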

License

AugTriever is licensed under the BSD 3-Clause License.

Evaluation code forked from external repositories is placed in subfolders (e.g., src/beir, src/beireval, src/mteb, src/mtebeval, src/qa, src/senteval). Please refer to the LICENSE file in each subfolder for copyright information.

Citation

If you find the AugTriever code or models useful, please cite the paper using the following BibTeX entry:

@article{meng2022augtriever,
  title={AugTriever: Unsupervised Dense Retrieval by Scalable Data Augmentation},
  author={Meng, Rui and Liu, Ye and Yavuz, Semih and Agarwal, Divyansh and Tu, Lifu and Yu, Ning and Zhang, Jianguo and Bhat, Meghana and Zhou, Yingbo},
  journal={arXiv preprint arXiv:2212.08841},
  year={2022}
}