Source code for our EMNLP 2023 Findings paper *Boot and Switch: Alternating Distillation for Zero-Shot Dense Retrieval*.
pip install -r requirements.txt
- fanjiang98/ABEL-Query-Encoder-Warmup: query encoder after the warm-up stage.
- fanjiang98/ABEL-Passage-Encoder-Warmup: passage encoder after the warm-up stage.
- fanjiang98/ABEL-Query-Encoder: final query encoder.
- fanjiang98/ABEL-Passage-Encoder: final passage encoder.
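The final encoders form a bi-encoder: queries and passages are embedded by separate towers and scored by inner product. Below is a minimal illustrative sketch, assuming the checkpoints load as standard Transformers encoders with mean pooling over token embeddings (matching the `--pooling_mode mean` flag used for evaluation below); the actual pooling and scoring details live in the repo's scripts.

```python
import torch
from transformers import AutoModel, AutoTokenizer

def embed(texts, model_name, max_length):
    """Embed texts as the attention-masked mean of the last hidden states."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).eval()
    batch = tokenizer(texts, padding=True, truncation=True,
                      max_length=max_length, return_tensors="pt")
    with torch.no_grad():
        hidden = model(**batch).last_hidden_state         # [batch, seq, dim]
    mask = batch["attention_mask"].unsqueeze(-1).float()  # [batch, seq, 1]
    return (hidden * mask).sum(1) / mask.sum(1)           # [batch, dim]

queries = embed(["does vitamin D prevent fractures?"],
                "fanjiang98/ABEL-Query-Encoder", max_length=128)
passages = embed(["Vitamin D supplementation has been studied in trials ...",
                  "An unrelated passage about astronomy."],
                 "fanjiang98/ABEL-Passage-Encoder", max_length=512)
scores = queries @ passages.T  # inner-product relevance; higher is better
```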
mkdir -p beir
cd beir
wget https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/scifact.zip
unzip scifact.zip
cd ../
Other datasets can be downloaded from here. Note that BioASQ, TREC-NEWS, Robust04, and Signal-1M are not publicly available; please refer to here for more details.
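If the beir package is installed (add it with `pip install beir` if it is not pulled in by requirements.txt), you can sanity-check a downloaded dataset by loading the standard corpus/queries/qrels triplet:

```python
from beir.datasets.data_loader import GenericDataLoader

# Load SciFact's corpus, test queries, and relevance judgments.
corpus, queries, qrels = GenericDataLoader(data_folder="beir/scifact").load(split="test")
print(len(corpus), len(queries))  # number of passages and of test queries
```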
python eval_beir.py \
--data_dir beir \
--task scifact \
--query_encoder_path fanjiang98/ABEL-Query-Encoder \
--passage_encoder_path fanjiang98/ABEL-Passage-Encoder \
--max_seq_length 512 \
--pooling_mode mean
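eval_beir.py runs retrieval and scoring end to end; it presumably reports the usual BEIR metrics such as nDCG@10. For reference, this is how the beir package computes them from qrels and a run (toy values shown):

```python
from beir.retrieval.evaluation import EvaluateRetrieval

# Toy example: qrels map query ids to relevant doc ids, results map
# query ids to retrieved doc scores.
qrels = {"q1": {"d1": 1}}
results = {"q1": {"d1": 0.9, "d2": 0.4}}
ndcg, _map, recall, precision = EvaluateRetrieval.evaluate(qrels, results, [10])
print(ndcg)  # {'NDCG@10': 1.0}
```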
wget https://homes.cs.washington.edu/~akari/tart/cross_task_cross_domain_final.zip
unzip cross_task_cross_domain_final.zip
python eval_x2.py \
--data_dir cross_task_cross_domain_final \
--task ambig \
--query_encoder_path fanjiang98/ABEL-Query-Encoder \
--passage_encoder_path fanjiang98/ABEL-Passage-Encoder
mkdir -p ODQA
cd ODQA
wget https://dl.fbaipublicfiles.com/dpr/data/retriever/nq-test.qa.csv
wget https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.gz
gunzip psgs_w100.tsv.gz
cd ../
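Both files follow the DPR release format: nq-test.qa.csv is tab-separated with one question and its answer list per line, and psgs_w100.tsv is tab-separated with an id/text/title header followed by roughly 21M 100-word passages. A quick peek:

```python
import csv

with open("ODQA/nq-test.qa.csv") as f:
    question, answers = next(csv.reader(f, delimiter="\t"))
    print(question, answers)  # question and a Python-style list of gold answers

with open("ODQA/psgs_w100.tsv") as f:
    reader = csv.reader(f, delimiter="\t")
    print(next(reader))     # header: ['id', 'text', 'title']
    print(next(reader)[0])  # id of the first passage
```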
Encode Query
MODEL_PATH=fanjiang98/ABEL-Query-Encoder
DATA_PATH=ODQA
mkdir -p ${DATA_PATH}/encoding
python encode.py \
--model_name_or_path ${MODEL_PATH} \
--output_dir ${MODEL_PATH} \
--train_dir ${DATA_PATH} \
--fp16 \
--per_device_eval_batch_size 2048 \
--encode_is_qry \
--shared_encoder False \
--max_query_length 128 \
--query_file nq-test.qa.csv \
--dataloader_num_workers 4 \
--encoded_save_path ${DATA_PATH}/encoding/nq_test_query_embedding.pt
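The .pt extension suggests the embeddings are written with torch.save, so a quick sanity check is possible. Note the payload layout assumed below (an (embeddings, ids) pair, common for encode scripts of this style) is a guess; inspect the object if encode.py stores something else:

```python
import torch

obj = torch.load("ODQA/encoding/nq_test_query_embedding.pt", map_location="cpu")
# Assumed layout: (embeddings, ids); otherwise just report the type.
if isinstance(obj, tuple) and len(obj) == 2:
    embeddings, ids = obj
    print(getattr(embeddings, "shape", None), len(ids))
else:
    print(type(obj))
```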
Encode Corpus
MODEL_PATH=fanjiang98/ABEL-Passage-Encoder
DATA_PATH=ODQA
mkdir -p ${DATA_PATH}/encoding
for i in $(seq -f "%02g" 0 19)
do
python encode.py \
--model_name_or_path ${MODEL_PATH} \
--output_dir ${MODEL_PATH} \
--train_dir ${DATA_PATH} \
--fp16 \
--corpus_file psgs_w100.tsv \
--shared_encoder False \
--max_passage_length 512 \
--per_device_eval_batch_size 128 \
--encode_shard_index $i \
--encode_num_shard 20 \
--dataloader_num_workers 4 \
--encoded_save_path ${DATA_PATH}/encoding/embedding_split${i}.pt
done
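Each iteration encodes one of 20 disjoint slices of the corpus, so the shards are independent and can be spread across GPUs or separate jobs. As a hypothetical illustration (not the actual encode.py logic) of how flags like --encode_shard_index/--encode_num_shard typically partition a corpus:

```python
def shard_bounds(num_examples: int, shard_index: int, num_shards: int):
    """Return the [start, end) slice handled by one shard (ceiling split)."""
    per_shard = (num_examples + num_shards - 1) // num_shards
    start = shard_index * per_shard
    return start, min(start + per_shard, num_examples)

print(shard_bounds(21_000_000, 0, 20))   # (0, 1050000)
print(shard_bounds(21_000_000, 19, 20))  # (19950000, 21000000)
```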
DATA_PATH=ODQA
python retriever.py \
--query_embeddings ${DATA_PATH}/encoding/nq_test_query_embedding.pt \
--passage_embeddings ${DATA_PATH}/encoding/'embedding_split*.pt' \
--depth 100 \
--batch_size -1 \
--save_text \
--save_ranking_to ${DATA_PATH}/nq.rank.txt
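retriever.py searches the precomputed embeddings; --depth 100 keeps the top 100 passages per query, and --batch_size -1 presumably scores all queries in one pass. The core operation is exact inner-product search, shown here in miniature:

```python
import torch

queries = torch.randn(4, 768)        # [num_queries, dim] (toy data)
passages = torch.randn(10_000, 768)  # [num_passages, dim] (toy data)
scores = queries @ passages.T                    # [num_queries, num_passages]
top_scores, top_idx = scores.topk(k=100, dim=1)  # depth = 100, as above
```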
We use pyserini for evaluation:
python convert_result_to_trec.py --input ${DATA_PATH}/nq.rank.txt --output ${DATA_PATH}/nq.rank.txt.trec
python -m pyserini.eval.convert_trec_run_to_dpr_retrieval_run \
--topics dpr-nq-test \
--index wikipedia-dpr \
--input ${DATA_PATH}/nq.rank.txt.trec \
--output ${DATA_PATH}/run.nq.test.json
python -m pyserini.eval.evaluate_dpr_retrieval \
--retrieval ${DATA_PATH}/run.nq.test.json \
--topk 1 5 20 100
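evaluate_dpr_retrieval reports top-k retrieval accuracy, i.e. the fraction of questions for which at least one of the top-k retrieved passages contains a gold answer.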
Please download the training data from OneDrive and put them in the corresponding directories under beir.
Tokenize queries and corpus
python tokenize_query_passage.py
We use Slurm for training on 8 × 80GB A100 GPUs:
bash scripts/beir/train_biencoder_sent_crop_beir.sh