conda create -n t5qr python=3.8
source activate t5qr
pip install -r requirements.txt
We provide example training/dev and test files in the `data` folder.
// training file:
// "cur_utt_text" is the current query, "oracle_utt_text" is the oracle (rewritten) query, "ctx_utts_text" is the previous user query context (i.e., [q_1, q_2, ...]), and "ctx_resps_text" is the previous agent response context (i.e., [r_1, r_2, ...]).
{"sample_id": "QReCC-Train_7625_2", "cur_utt_text": "When was the movie released", "oracle_utt_text": "When was the movie Amadeus released", "ctx_utts_text": ["how does mozart die in the movie amadeus"], "ctx_resps_text": ["Mozart suddenly came down with fever and was wracked with pain. In the following days his health significantly deteriorated. He died on December 5 after lapsing into a coma."]}
// test file:
// The test file format is the same as the training file format, but the "oracle_utt_text" field is not required.
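As a sketch, a data file in this format can be loaded and validated like this (field names are taken from the example above; the helper itself is our own illustration, not part of the repo):

```python
import json

REQUIRED_FIELDS = ["sample_id", "cur_utt_text", "ctx_utts_text", "ctx_resps_text"]

def load_examples(path, require_oracle=True):
    """Read a JSONL file and check that each line has the expected fields.

    Pass require_oracle=False for test files, which may omit "oracle_utt_text".
    """
    examples = []
    with open(path) as f:
        for i, line in enumerate(f):
            ex = json.loads(line)
            missing = [k for k in REQUIRED_FIELDS if k not in ex]
            if require_oracle and "oracle_utt_text" not in ex:
                missing.append("oracle_utt_text")
            if missing:
                raise ValueError(f"line {i}: missing fields {missing}")
            examples.append(ex)
    return examples
```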
bash run_train.sh
We randomly split the QReCC training set into new training (90%) and dev (10%) sets and use them to train the T5 rewriter.
You can set the maximum number of training epochs via num_train_epochs, but early stopping is triggered when the dev loss does not decrease for two consecutive model-saving periods. Therefore, you should use the third-to-last saved model as the final model.
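Since the final model is the third-to-last checkpoint, a small helper like the following can pick it out of a checkpoint directory listing (our own sketch; it assumes checkpoint names end with an increasing step or epoch number, so adjust the pattern to whatever run_train.sh actually writes):

```python
import re

def third_to_last_checkpoint(names):
    """Return the third-to-last checkpoint name by saving order.

    Assumes each name ends with an increasing number, e.g. "model-3".
    """
    def step(name):
        m = re.search(r"(\d+)$", name)
        return int(m.group(1)) if m else -1
    ordered = sorted(names, key=step)
    if len(ordered) < 3:
        raise ValueError("need at least three saved checkpoints")
    return ordered[-3]
```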
bash run_inference.sh
We support DDP inference on multiple GPUs. The rewrites are stored in the "t5_rewrite" field of the output file.
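After inference, the rewrites can be collected from the output file like this (a sketch that assumes the output is JSONL with the fields described above):

```python
import json

def read_rewrites(path):
    """Map sample_id -> t5_rewrite from an inference output file (JSONL assumed)."""
    rewrites = {}
    with open(path) as f:
        for line in f:
            ex = json.loads(line)
            rewrites[ex["sample_id"]] = ex["t5_rewrite"]
    return rewrites
```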
Note that for QReCC you should set:
max_response_length = 100
max_seq_length = 384
For CAsT-19 and CAsT-20, set:
max_response_length=128 # The **last** automatic canonical response in CAsT-20 is a longer passage, whereas the responses in QReCC are shorter text spans.
max_seq_length=256 # CAsT-20 includes only one response and CAsT-19 includes none, so the maximum sequence length can be shorter than for QReCC.
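The two limits interact roughly as follows: each response is clipped to max_response_length tokens first, and the fully concatenated input is then clipped to max_seq_length. A minimal sketch, using whitespace tokens in place of real T5 subword tokens (the repo's actual preprocessing may order and truncate the context differently):

```python
def build_input(cur_utt, ctx_utts, ctx_resps,
                max_response_length=100, max_seq_length=384):
    """Concatenate context turns and the current query, clipping each
    response first and then the whole sequence.

    Whitespace tokens stand in for T5 subword tokens; this is only a
    sketch of the truncation logic, not the repo's exact preprocessing.
    """
    parts = []
    for q, r in zip(ctx_utts, ctx_resps):
        parts.append(q)
        parts.append(" ".join(r.split()[:max_response_length]))
    parts.append(cur_utt)
    tokens = " ".join(parts).split()
    return tokens[:max_seq_length]
```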
We provide a "quick" script (i.e., single_rewrite.py) to perform a single rewrite.
You need to set the rewriter path, cur_utt_text (i.e., the current user query), ctx_utts_text (i.e., the previous user queries), and ctx_resps_text (i.e., the previous agent responses) in the script, and then run:
python single_rewrite.py
After conversational query rewriting, we can evaluate the rewrites by directly comparing them with the manual oracle rewrites or by measuring downstream ranking performance.
Currently, evaluation is mainly conducted on the QReCC, CAsT-19, and CAsT-20 test sets.
We provide evaluation instructions in the evaluation folder.
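As a sketch of the first kind of evaluation (comparing rewrites against the manual oracle rewrites), here is a simple token-level F1; the metrics in the evaluation folder may tokenize and aggregate differently:

```python
from collections import Counter

def token_f1(rewrite, oracle):
    """Token-level F1 between a rewrite and the oracle rewrite.

    Lowercased, whitespace-tokenized; the official evaluation scripts
    may use a different tokenization.
    """
    pred = Counter(rewrite.lower().split())
    gold = Counter(oracle.lower().split())
    overlap = sum((pred & gold).values())
    if overlap == 0:
        return 0.0
    precision = overlap / sum(pred.values())
    recall = overlap / sum(gold.values())
    return 2 * precision * recall / (precision + recall)
```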