T5QR

An Implementation for T5-based Conversational Query Rewriter

Environment

conda create -n t5qr python=3.8
source activate t5qr
pip install -r requirements.txt

Data

Suppose the current query is $q_3$ and the context is $[q_1, r_1, q_2, r_2]$. The input text sequence for T5 is then $q_3$ [SEP] $r_2$ [SEP] $q_2$ [SEP] $r_1$ [SEP] $q_1$, i.e., the current query followed by the context turns in reverse chronological order. At most the last three responses are included. Note that T5Tokenizer additionally appends an eos_token (i.e., </s>) to the end of the input text sequence. The target sequence is the oracle query.
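A minimal sketch of how this input sequence could be assembled from the fields of a data sample. It assumes that `ctx_resps_text[i]` is the response to `ctx_utts_text[i]`, that the separator is the literal string ` [SEP] ` shown above, and that empty or missing responses are skipped. The function name `build_t5_input` is hypothetical and not part of this repository.

```python
from typing import List


def build_t5_input(
    cur_utt_text: str,
    ctx_utts_text: List[str],
    ctx_resps_text: List[str],
    max_responses: int = 3,
    sep: str = " [SEP] ",
) -> str:
    """Hypothetical helper: current query first, then context turns newest-first."""
    parts = [cur_utt_text]
    n_turns = len(ctx_utts_text)
    for i in range(n_turns - 1, -1, -1):
        # 0 for the most recent context turn, 1 for the one before it, ...
        turns_back = n_turns - 1 - i
        # Assumption: ctx_resps_text[i] answers ctx_utts_text[i]; keep only the last `max_responses`.
        if turns_back < max_responses and i < len(ctx_resps_text) and ctx_resps_text[i]:
            parts.append(ctx_resps_text[i])
        parts.append(ctx_utts_text[i])
    return sep.join(parts)


print(build_t5_input("q3", ["q1", "q2"], ["r1", "r2"]))
# q3 [SEP] r2 [SEP] q2 [SEP] r1 [SEP] q1
```

The tokenizer's </s> is deliberately not added here, since T5Tokenizer appends it during encoding.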

Example training/dev and test files are provided in the data folder.

// training file:
// "cur_utt_text" is the current query. "oracle_utt_text" is the oracle query. "ctx_utts_text" is the previous user query context (i.e., [q_1, q_2, ...]) and "ctx_resps_text" is the previous agent responses context (i.e., [r_1, r_2, ...])

{"sample_id": "QReCC-Train_7625_2", "cur_utt_text": "When was the movie released", "oracle_utt_text": "When was the moview Amadeus released", "ctx_utts_text": ["how does mozart die in the movie amadeus"], "ctx_resps_text": ["Mozart suddenly came down with fever and was wracked with pain.In the following days his health significantly deteriorated. He died on December 5 after lapsing into a coma."]}


// test file:
// The test file has the same format as the training file but does not need the "oracle_utt_text" field.
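A hedged sketch for reading these files, assuming one JSON object per line (as in the example above). The helper `load_samples` and the field check it performs are illustrative assumptions, not the repository's data loader.

```python
import json
from typing import Dict, Iterable, List

# Fields every sample needs; training/dev samples additionally need "oracle_utt_text".
BASE_FIELDS = ("sample_id", "cur_utt_text", "ctx_utts_text", "ctx_resps_text")


def load_samples(lines: Iterable[str], is_test: bool = False) -> List[Dict]:
    """Parse one JSON object per line and check that the expected fields are present."""
    required = BASE_FIELDS if is_test else BASE_FIELDS + ("oracle_utt_text",)
    samples = []
    for line in lines:
        line = line.strip()
        if not line:  # tolerate blank lines
            continue
        record = json.loads(line)
        missing = [field for field in required if field not in record]
        if missing:
            raise ValueError(f"sample {record.get('sample_id', '?')} is missing {missing}")
        samples.append(record)
    return samples
```

Usage: `with open(path, encoding="utf-8") as f: samples = load_samples(f)`, where `path` points to one of the files in the data folder, passing `is_test=True` for a test file.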

Training

bash run_train.sh

We randomly split the QReCC training set into new training (90%) and dev (10%) sets and use them to train a T5 rewriter model.
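A minimal sketch of a reproducible 90/10 random split at the sample level. The fixed seed, the sample-level (rather than conversation-level) granularity, and the name `split_train_dev` are assumptions; the README does not state how the authors' split was seeded or grouped.

```python
import random
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_train_dev(samples: Sequence[T], dev_ratio: float = 0.1, seed: int = 42) -> Tuple[List[T], List[T]]:
    """Shuffle with a fixed seed and hold out `dev_ratio` of the samples as a dev set."""
    shuffled = list(samples)
    random.Random(seed).shuffle(shuffled)  # local RNG so the global random state is untouched
    n_dev = round(len(shuffled) * dev_ratio)
    return shuffled[n_dev:], shuffled[:n_dev]  # (train, dev)
```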

The maximum number of training epochs is set by num_train_epochs. However, early stopping is triggered when the dev loss does not decrease for two consecutive model-saving periods. Because the last two saved checkpoints did not improve the dev loss, you should use the third-to-last saved model as the final model.
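A small, hypothetical simulation of this rule shows why the third-to-last checkpoint is the one to keep. It assumes that "does not decrease" means the dev loss fails to beat the best value seen so far and that a checkpoint is saved at every evaluation; `early_stop` is illustrative and is not the repository's training code.

```python
from typing import List, Tuple


def early_stop(dev_losses: List[float], patience: int = 2) -> Tuple[int, int]:
    """Return (index of the last saved checkpoint, index of the best checkpoint)."""
    best_idx = 0
    bad_saves = 0  # consecutive saves without improving on the best dev loss
    for i in range(1, len(dev_losses)):
        if dev_losses[i] < dev_losses[best_idx]:
            best_idx, bad_saves = i, 0
        else:
            bad_saves += 1
            if bad_saves >= patience:
                return i, best_idx  # training stops right after this save
    return len(dev_losses) - 1, best_idx


last, best = early_stop([1.0, 0.8, 0.7, 0.72, 0.75, 0.6])
print(last, best)  # 4 2: checkpoint 2 is the third-to-last of checkpoints 0..4
```

With a patience of two, a run that stops early always ends exactly two saves after its best checkpoint, which is the third-to-last saved model.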

Inference

bash run_inference.sh

Inference supports DDP (PyTorch DistributedDataParallel) on multiple GPUs. Each rewrite is stored in the "t5_rewrite" field of the output file.
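A hedged sketch for collecting the rewrites after inference. It assumes the output file is in the same one-JSON-object-per-line format as the input and that each record keeps its "sample_id"; both are assumptions, and `collect_rewrites` is a hypothetical helper.

```python
import json
from typing import Dict, Iterable


def collect_rewrites(lines: Iterable[str]) -> Dict[str, str]:
    """Map each sample_id to its "t5_rewrite" field, skipping blank lines."""
    rewrites = {}
    for line in lines:
        if line.strip():
            record = json.loads(line)
            rewrites[record["sample_id"]] = record["t5_rewrite"]
    return rewrites
```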

Note: for QReCC, set:

max_response_length = 100
max_seq_length = 384

For CAsT-19 and CAsT-20, set:

max_response_length=128 # The last automatic canonical response in CAsT-20 is a full passage, longer than the responses in QReCC, which are short text spans.
max_seq_length=256  # CAsT-20 includes only one response and CAsT-19 includes none, so the maximum sequence length can be shorter than for QReCC.
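One plausible reading of how the two limits interact, written as a sketch on already-tokenized segments: each response is cut to max_response_length tokens, then the whole input is cut to max_seq_length. The dataset keys, `LENGTH_SETTINGS`, and `apply_length_limits` are hypothetical, and the repository's scripts may apply these limits differently (for example, inside the tokenizer call).

```python
from typing import Dict, List, Sequence

# Values from the settings above; the dictionary keys are hypothetical names.
LENGTH_SETTINGS: Dict[str, Dict[str, int]] = {
    "qrecc": {"max_response_length": 100, "max_seq_length": 384},
    "cast19": {"max_response_length": 128, "max_seq_length": 256},
    "cast20": {"max_response_length": 128, "max_seq_length": 256},
}


def apply_length_limits(segments: Sequence[List[int]], is_response: Sequence[bool], dataset: str) -> List[int]:
    """Truncate each response segment, concatenate all segments, then truncate the whole sequence."""
    limits = LENGTH_SETTINGS[dataset]
    token_ids: List[int] = []
    for segment, response in zip(segments, is_response):
        if response:
            segment = segment[: limits["max_response_length"]]
        token_ids.extend(segment)
    return token_ids[: limits["max_seq_length"]]
```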

Quick Single Rewriting

We provide a "quick" script, single_rewrite.py, that rewrites a single query. Set the rewriter path, cur_utt_text (the current user query), ctx_utts_text (the previous user queries), and ctx_resps_text (the previous agent responses) in the script, and then run:

python single_rewrite.py
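Roughly what a single rewrite involves with Hugging Face Transformers, assuming the saved rewriter is a standard T5 checkpoint loadable with `T5Tokenizer` and `T5ForConditionalGeneration`. The helper names `load_rewriter` and `rewrite`, the `max_new_tokens` value, and the default greedy decoding are assumptions, not taken from single_rewrite.py. `input_text` is the [SEP]-joined sequence described in the Data section.

```python
from typing import Any, Tuple


def load_rewriter(model_path: str) -> Tuple[Any, Any]:
    # Imported lazily; requires the transformers, torch, and sentencepiece packages.
    from transformers import T5ForConditionalGeneration, T5Tokenizer

    tokenizer = T5Tokenizer.from_pretrained(model_path)
    model = T5ForConditionalGeneration.from_pretrained(model_path)
    model.eval()
    return model, tokenizer


def rewrite(model: Any, tokenizer: Any, input_text: str, max_seq_length: int = 384, max_new_tokens: int = 64) -> str:
    # Encode the input; T5Tokenizer appends </s> automatically.
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=max_seq_length)
    # Generate with the model's default (greedy) decoding settings.
    output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
```

Usage: `model, tokenizer = load_rewriter(path)` followed by `rewrite(model, tokenizer, input_text)`, where `path` is the trained rewriter directory.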

Evaluation

After conversational query rewriting, the rewrites can be evaluated either by comparing them directly with the manual oracle rewrites or on downstream ranking tasks. Currently, evaluation is mainly conducted on the QReCC, CAsT-19, and CAsT-20 test sets. Evaluation instructions are provided in the evaluation folder.
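As an illustration of comparing a rewrite directly with its oracle rewrite, here is a simple lowercase whitespace-token F1. This metric is an assumption chosen for brevity; the metrics actually reported are those produced by the scripts in the evaluation folder.

```python
from collections import Counter


def token_f1(prediction: str, reference: str) -> float:
    """Token-overlap F1 between a rewrite and its oracle rewrite (lowercased, whitespace-split)."""
    pred = prediction.lower().split()
    ref = reference.lower().split()
    if not pred or not ref:
        return float(pred == ref)  # 1.0 only if both are empty
    overlap = sum((Counter(pred) & Counter(ref)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred)
    recall = overlap / len(ref)
    return 2 * precision * recall / (precision + recall)
```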