
MINDER

This is the official implementation for the paper "Multiview Identifiers Enhanced Generative Retrieval".
The preprint is available on arXiv.
If you find our paper or code helpful, please consider citing it as follows:

@inproceedings{li-etal-2023-multiview,
    title = "Multiview Identifiers Enhanced Generative Retrieval",
    author = "Li, Yongqi  and Yang, Nan  and Wang, Liang  and Wei, Furu  and Li, Wenjie",
    booktitle = "Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)",
    month = jul,
    year = "2023",
    publisher = "Association for Computational Linguistics",
    pages = "6636--6648",
}

Install

git clone https://github.com/liyongqi67/MINDER.git
cd MINDER
# swig is needed to build the FM-index bindings
sudo apt install swig
# build the bundled sdsl-lite succinct data structure library used by the FM-index
env CFLAGS='-fPIC' CXXFLAGS='-fPIC' res/external/sdsl-lite/install.sh
pip install -r requirements.txt
pip install -e .
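
Optionally, as a quick sanity check, you can confirm that the fairseq command-line tools and the seal package are importable before moving on. This is only a sketch; it assumes that pip install -e . registers the package under the name seal, mirroring the SEAL codebase this repository builds on.

# optional sanity check: the fairseq CLI works and the (assumed) seal package imports
fairseq-train --help > /dev/null && echo "fairseq OK"
python -c "import seal; print('seal OK')"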

Data

Please download all the data into the data folder.

  1. data/NQ folder. Please download the biencoder-nq-dev.json, biencoder-nq-train.json, nq-dev.qa.csv, and nq-test.qa.csv files into the NQ folder from the DPR repository.
  2. data/Trivia folder. Please download the biencoder-trivia-dev.json, biencoder-trivia-train.json, trivia-dev.qa.csv, and trivia-test.qa.csv files into the Trivia folder from the DPR repository.
  3. data/MSMARCO folder. Please download qrels.msmarco-passage.dev-subset.txt from this link.
  4. data/fm_index/ folder. Please download the FM-index files psgs_w100.fm_index.fmi and psgs_w100.fm_index.oth for the Wikipedia corpus, and msmarco-passage-corpus.fm_index.fmi and msmarco-passage-corpus.fm_index.oth for the MSMARCO corpus, from this link.
  5. data/training_data/ folder.
    Download the NQ_title_body_query_generated from this link.
    Download the Trivia_title_body_query_generated from this link.
    Download the MSMARCO_title_body_query3 from this link.
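
The following commands (a minimal sketch) create the folder layout described above before you drop the downloaded files in:

    # create the data folders listed above
    mkdir -p data/NQ data/Trivia data/MSMARCO data/training_data
    # the inference commands further below read the FM-index files from data/fm_index/stable2/
    mkdir -p data/fm_index/stable2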

Model training

We use fairseq to train the BART-large model with the translation task.
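
The training commands below read binarized data from a bin subfolder. If the downloaded training data is not already binarized, it can be converted with fairseq-preprocess. The command below is only a sketch for NQ (the Trivia and MSMARCO folders are handled analogously); it assumes the folder contains train.source/train.target and dev.source/dev.target files, and that the BART dictionary dict.txt sits next to the pretrained checkpoint under /bart.large/, the same path used by --finetune-from-model below.

    # sketch: binarize the (assumed) train/dev source-target pairs with the BART dictionary
    fairseq-preprocess \
        --source-lang source --target-lang target \
        --trainpref data/training_data/NQ_title_body_query_generated/train \
        --validpref data/training_data/NQ_title_body_query_generated/dev \
        --destdir data/training_data/NQ_title_body_query_generated/bin \
        --srcdict /bart.large/dict.txt --tgtdict /bart.large/dict.txt \
        --workers 20
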
The script for training on the NQ dataset is

    - fairseq-train
        data/training_data/NQ_title_body_query_generated/bin 
        --finetune-from-model /bart.large/model.pt 
        --arch bart_large 
        --task translation 
        --criterion label_smoothed_cross_entropy 
        --source-lang source 
        --target-lang target 
        --truncate-source 
        --label-smoothing 0.1 
        --max-tokens 4096 
        --update-freq 1 
        --max-update 800000 
        --required-batch-size-multiple 1
        --validate-interval 1000000
        --save-interval 1000000
        --save-interval-updates 15000 
        --keep-interval-updates 3 
        --dropout 0.1 
        --attention-dropout 0.1 
        --relu-dropout 0.0 
        --weight-decay 0.01 
        --optimizer adam 
        --adam-betas "(0.9, 0.999)" 
        --adam-eps 1e-08 
        --clip-norm 0.1 
        --lr-scheduler polynomial_decay 
        --lr 3e-05 
        --total-num-update 800000 
        --warmup-updates 500 
        --fp16 
        --num-workers 10 
        --no-epoch-checkpoints 
        --share-all-embeddings 
        --layernorm-embedding 
        --share-decoder-input-output-embed 
        --skip-invalid-size-inputs-valid-test 
        --log-format json
        --log-interval 100 
        --patience 5
        --find-unused-parameters
        --save-dir  ./

The script for training on the TriviaQA dataset is

    - fairseq-train
        data/training_data/Trivia_title_body_query_generated/bin 
        --finetune-from-model /bart.large/model.pt 
        --arch bart_large 
        --task translation 
        --criterion label_smoothed_cross_entropy 
        --source-lang source 
        --target-lang target 
        --truncate-source 
        --label-smoothing 0.1 
        --max-tokens 4096 
        --update-freq 1 
        --max-update 800000 
        --required-batch-size-multiple 1
        --validate-interval 1000000
        --save-interval 1000000
        --save-interval-updates 6000 
        --keep-interval-updates 3 
        --dropout 0.1 
        --attention-dropout 0.1 
        --relu-dropout 0.0 
        --weight-decay 0.01 
        --optimizer adam 
        --adam-betas "(0.9, 0.999)" 
        --adam-eps 1e-08 
        --clip-norm 0.1 
        --lr-scheduler polynomial_decay 
        --lr 3e-05 
        --total-num-update 800000 
        --warmup-updates 500 
        --fp16 
        --num-workers 10 
        --no-epoch-checkpoints 
        --share-all-embeddings 
        --layernorm-embedding 
        --share-decoder-input-output-embed 
        --skip-invalid-size-inputs-valid-test 
        --log-format json
        --log-interval 100 
        --patience 5
        --find-unused-parameters
        --save-dir  ./

The script for training on the MSMARCO dataset is

    - fairseq-train
        data/training_data/MSMARCO_title_body_query3/bin 
        --finetune-from-model /bart.large/model.pt 
        --arch bart_large 
        --task translation 
        --criterion label_smoothed_cross_entropy 
        --source-lang source 
        --target-lang target 
        --truncate-source 
        --label-smoothing 0.1 
        --max-tokens 4096 
        --update-freq 1 
        --max-update 100000 
        --required-batch-size-multiple 1
        --validate-interval 1000000
        --save-interval 1000000
        --save-interval-updates 6000 
        --keep-interval-updates 3 
        --dropout 0.1 
        --attention-dropout 0.1 
        --relu-dropout 0.0 
        --weight-decay 0.01 
        --optimizer adam 
        --adam-betas "(0.9, 0.999)" 
        --adam-eps 1e-08 
        --clip-norm 0.1 
        --lr-scheduler polynomial_decay 
        --lr 3e-05 
        --total-num-update 100000 
        --warmup-updates 500 
        --fp16 
        --num-workers 10 
        --no-epoch-checkpoints 
        --share-all-embeddings 
        --layernorm-embedding 
        --share-decoder-input-output-embed 
        --skip-invalid-size-inputs-valid-test 
        --log-format json
        --log-interval 100 
        --patience 3
        --find-unused-parameters
        --save-dir  ./

We trained the models on 8 × 32GB NVIDIA V100 GPUs. Training took about 4d3h24m39s, 1d18h30m47s, and 12h53m50s on NQ, TriviaQA, and MSMARCO, respectively.
We release our trained model checkpoints at this link.

Model inference

Please use the following script to retrieve passages for queries in NQ.

    - TOKENIZERS_PARALLELISM=false python seal/search.py 
      --topics_format dpr_qas --topics data/NQ/nq-test.qa.csv 
      --output_format dpr --output output_test.json 
      --checkpoint checkpoint_NQ.pt 
      --jobs 10 --progress --device cuda:0 --batch_size 20 
      --beam 15
      --decode_query stable
      --fm_index data/fm_index/stable2/psgs_w100.fm_index 

Please use the following script to retrieve passages for queries in TriviaQA.

    - TOKENIZERS_PARALLELISM=false python seal/search.py 
      --topics_format dpr_qas --topics data/Trivia/trivia-test.qa.csv
      --output_format dpr --output output_test.json 
      --checkpoint checkpoint_TriviaQA.pt 
      --jobs 10 --progress --device cuda:0 --batch_size 40 
      --beam 15
      --decode_query stable
      --fm_index data/fm_index/stable2/psgs_w100.fm_index

Please use the following script to retrieve passages for queries in MSMARCO.

    - TOKENIZERS_PARALLELISM=false python seal/search.py 
      --topics_format msmarco --topics Tevatron/msmarco-passage
      --output_format msmarco --output output_test.json 
      --checkpoint checkpoint_MSMARCO.pt 
      --jobs 10 --progress --device cuda:0 --batch_size 10 
      --beam 7
      --decode_query stable
      --fm_index data/fm_index/stable2/msmarco-passage-corpus.fm_index
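
For the NQ and TriviaQA runs, which write DPR-format JSON, a quick look at the output can confirm retrieval worked before running the full evaluation. The one-liner below is only a sketch: it assumes the standard DPR retrieval schema (a list of entries with question, answers, and ctxs fields) and requires jq.

    # peek at the first query's top retrieved passage (assumes DPR-format output and jq)
    jq '.[0] | {question: .question, top_passage_id: .ctxs[0].id}' output_test.json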

Evaluation

Please use the following script to evaluate on NQ and TriviaQA.

    - python3 seal/evaluate_output.py
      --file output_test.json 

Please use the following script to evaluate on MSMARCO.

    - python3 seal/evaluate_output_msmarco.py
      data/MSMARCO/qrels.msmarco-passage.dev-subset.txt output_test.json
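
For reference, the qrels file follows the standard TREC qrels layout, one judgment per line: query id, an unused 0, passage id, and a relevance label.

    # each line: <query_id> 0 <passage_id> <relevance>
    head -n 3 data/MSMARCO/qrels.msmarco-passage.dev-subset.txt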

Acknowledgments

Part of the code is based on SEAL and sdsl-lite.

Contact

If you run into any problems, please email liyongqi0@gmail.com. Do not hesitate to email me directly, as I do not check GitHub issues frequently.