Implements the model described in the paper "Do Response Selection Models Really Know What's Next? Utterance Manipulation Strategies for Multi-turn Response Selection".
@inproceedings{whang2021ums,
title={Do Response Selection Models Really Know What's Next? Utterance Manipulation Strategies for Multi-turn Response Selection},
author={Whang, Taesun and Lee, Dongyub and Oh, Dongsuk and Lee, Chanhee and Han, Kijong and Lee, Dong-hun and Lee, Saebyeok},
booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
year={2021}
}
This code is reimplemented as a fork of huggingface/transformers and taesunwhang/BERT-ResSel.
This code is implemented in PyTorch v1.6.0 and supports CUDA 10.1 and cuDNN 7.6.5 out of the box.
Anaconda or Miniconda is the recommended way to set up this codebase.
Clone this repository and create an environment:
git clone https://www.github.com/taesunwhang/UMS-ResSel
conda create -n ums_ressel python=3.7
# activate the environment and install all dependencies
conda activate ums_ressel
cd UMS-ResSel
# https://pytorch.org
pip install torch==1.6.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
pip install -r requirements.txt
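After installation, it can help to confirm which PyTorch build the environment actually picked up. The following is a minimal sketch and not part of this repository: `describe_torch_environment` is a hypothetical helper name, and it only reports what is installed rather than enforcing the versions listed above.

```python
import importlib.util


def describe_torch_environment():
    """Report the installed PyTorch build, or note that torch is missing.

    Hypothetical helper for illustration; not part of this repository.
    """
    if importlib.util.find_spec("torch") is None:
        return {"installed": False}

    import torch

    return {
        "installed": True,
        "torch_version": torch.__version__,             # this README targets 1.6.0+cu101
        "cuda_build": torch.version.cuda,               # None on CPU-only builds
        "cuda_available": torch.cuda.is_available(),
        "cudnn_version": torch.backends.cudnn.version(),  # integer such as 7605, or None
    }


if __name__ == "__main__":
    for key, value in describe_torch_environment().items():
        print(f"{key}: {value}")
```

If `cuda_available` is reported as false, the `--gpu_ids` argument used in the commands of this README presumably has no GPU to select.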
We provide the following pre-trained and post-trained checkpoints.
- bert-base (english), bert-base-wwm (chinese)
- bert-post (ubuntu, douban, e-commerce)
- electra-base (english), electra-base (chinese)
- electra-post (ubuntu, douban, e-commerce)
sh scripts/download_pretrained_checkpoints.sh
The original version of each dataset is available at Ubuntu Corpus V1, Douban Corpus, and E-Commerce Corpus, respectively.
sh scripts/download_datasets.sh
# Ubuntu Corpus V1
sh scripts/create_bert_post_data_creation_ubuntu.sh
# Douban Corpus
sh scripts/create_bert_post_data_creation_douban.sh
# E-Commerce Corpus
sh scripts/create_bert_post_data_creation_e-commerce.sh
sh scripts/download_electra_post_training_pkl.sh
python3 main.py --model bert_post_training --task_name ubuntu --data_dir data/ubuntu_corpus_v1 --bert_pretrained bert-base-uncased --bert_checkpoint_path bert-base-uncased-pytorch_model.bin --task_type response_selection --gpu_ids "0" --root_dir /path/to/root_dir --training_type post_training
python3 main.py --model electra_post_training --task_name douban --data_dir data/electra_post_training --bert_pretrained electra-base-chinese --bert_checkpoint_path electra-base-chinese-pytorch_model.bin --task_type response_selection --gpu_ids "0" --root_dir /path/to/root_dir --training_type post_training
task_name | data_dir | bert_pretrained | bert_checkpoint_path |
---|---|---|---|
ubuntu | data/ubuntu_corpus_v1 | bert-base-uncased | bert-base-uncased-pytorch_model.bin |
douban | data/douban | bert-base-wwm-chinese | bert-base-wwm-chinese_model.bin |
e-commerce | data/e-commerce | bert-base-wwm-chinese | bert-base-wwm-chinese_model.bin |
task_name | data_dir | bert_pretrained | bert_checkpoint_path |
---|---|---|---|
ubuntu | data/ubuntu_corpus_v1 | bert-post-uncased | bert-post-uncased-pytorch_model.pth |
douban | data/douban | bert-post-douban | bert-post-douban-pytorch_model.pth |
e-commerce | data/e-commerce | bert-post-ecommerce | bert-post-ecommerce-pytorch_model.pth |
task_name | data_dir | bert_pretrained | bert_checkpoint_path |
---|---|---|---|
ubuntu | data/ubuntu_corpus_v1 | electra-base | electra-base-pytorch_model.bin |
douban | data/douban | electra-base-chinese | electra-base-chinese-pytorch_model.bin |
e-commerce | data/e-commerce | electra-base-chinese | electra-base-chinese-pytorch_model.bin |
task_name | data_dir | bert_pretrained | bert_checkpoint_path |
---|---|---|---|
ubuntu | data/ubuntu_corpus_v1 | electra-post | electra-post-pytorch_model.pth |
douban | data/douban | electra-post-douban | electra-post-douban-pytorch_model.pth |
e-commerce | data/e-commerce | electra-post-ecommerce | electra-post-ecommerce-pytorch_model.pth |
python3 main.py --model bert_post --task_name ubuntu --data_dir data/ubuntu_corpus_v1 --bert_pretrained bert-post-uncased --bert_checkpoint_path bert-post-uncased-pytorch_model.pth --task_type response_selection --gpu_ids "0" --root_dir /path/to/root_dir
python3 main.py --model bert_post --task_name douban --data_dir data/douban --bert_pretrained bert-post-douban --bert_checkpoint_path bert-post-douban-pytorch_model.pth --task_type response_selection --gpu_ids "0" --root_dir /path/to/root_dir --multi_task_type "ins,del,srch"
python3 main.py --model electra_base --task_name e-commerce --data_dir data/e-commerce --bert_pretrained electra-base-chinese --bert_checkpoint_path electra-base-chinese-pytorch_model.bin --task_type response_selection --gpu_ids "0" --root_dir /path/to/root_dir --multi_task_type "ins,del,srch"
To evaluate the model, set --evaluate to /path/to/checkpoints:
python3 main.py --model bert_post --task_name ubuntu --data_dir data/ubuntu_corpus_v1 --bert_pretrained bert-post-uncased --bert_checkpoint_path bert-post-uncased-pytorch_model.pth --task_type response_selection --gpu_ids "0" --root_dir /path/to/root_dir --evaluate /path/to/checkpoints --multi_task_type "ins,del,srch"
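Because the commands above differ only in a few arguments, they can be assembled from the argument tables instead of being typed by hand. The sketch below is illustrative and not part of this repository: `CHECKPOINT_ARGS` and `build_command` are hypothetical names, and the pairing of `--model` values with table rows is inferred from the example commands (only `bert_post` and `electra_base` appear there), so treat the other combinations as assumptions.

```python
import shlex

# Argument sets copied from the tables in this README.
# Assumption: the --model value pairs with the checkpoint family as in the
# example commands (bert_post with bert-post-*, electra_base with electra-base*).
CHECKPOINT_ARGS = {
    ("bert_post", "ubuntu"): ("data/ubuntu_corpus_v1", "bert-post-uncased",
                              "bert-post-uncased-pytorch_model.pth"),
    ("bert_post", "douban"): ("data/douban", "bert-post-douban",
                              "bert-post-douban-pytorch_model.pth"),
    ("bert_post", "e-commerce"): ("data/e-commerce", "bert-post-ecommerce",
                                  "bert-post-ecommerce-pytorch_model.pth"),
    ("electra_base", "ubuntu"): ("data/ubuntu_corpus_v1", "electra-base",
                                 "electra-base-pytorch_model.bin"),
    ("electra_base", "douban"): ("data/douban", "electra-base-chinese",
                                 "electra-base-chinese-pytorch_model.bin"),
    ("electra_base", "e-commerce"): ("data/e-commerce", "electra-base-chinese",
                                     "electra-base-chinese-pytorch_model.bin"),
}


def build_command(model, task_name, root_dir, gpu_ids="0",
                  multi_task_type=None, evaluate=None):
    """Return the main.py invocation as a list of arguments."""
    data_dir, pretrained, checkpoint = CHECKPOINT_ARGS[(model, task_name)]
    cmd = [
        "python3", "main.py",
        "--model", model,
        "--task_name", task_name,
        "--data_dir", data_dir,
        "--bert_pretrained", pretrained,
        "--bert_checkpoint_path", checkpoint,
        "--task_type", "response_selection",
        "--gpu_ids", gpu_ids,
        "--root_dir", root_dir,
    ]
    if evaluate is not None:
        cmd += ["--evaluate", evaluate]
    if multi_task_type is not None:
        cmd += ["--multi_task_type", multi_task_type]
    return cmd


if __name__ == "__main__":
    cmd = build_command("bert_post", "douban", "/path/to/root_dir",
                        multi_task_type="ins,del,srch")
    # shlex.quote keeps the printed line safe to paste into a shell (Python 3.7 compatible).
    print(" ".join(shlex.quote(arg) for arg in cmd))
```

Returning a list rather than a string means the result can be passed directly to `subprocess.run` without shell quoting concerns.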
We provide model checkpoints of UMS-BERT+, which achieved new state-of-the-art results, for each dataset.
Ubuntu | R@1 | R@2 | R@5 |
---|---|---|---|
UMS-BERT+ | 0.875 | 0.942 | 0.988 |
Douban | MAP | MRR | P@1 | R@1 | R@2 | R@5 |
---|---|---|---|---|---|---|
UMS-BERT+ | 0.625 | 0.664 | 0.499 | 0.318 | 0.482 | 0.858 |
E-Commerce | R@1 | R@2 | R@5 |
---|---|---|---|
UMS-BERT+ | 0.762 | 0.905 | 0.986 |
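For readers unfamiliar with the column headers, R@k is commonly read as recall at k: the fraction of correct responses that appear among the k highest-scored candidates for a dialogue context. The sketch below shows that common definition together with the reciprocal rank that MRR averages over contexts. It is an assumption-labeled illustration, not the evaluation code of this repository; `recall_at_k` and `reciprocal_rank` are hypothetical names, and the scores are made-up toy values.

```python
def recall_at_k(scores, labels, k):
    """Fraction of positive candidates ranked within the top k by score."""
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    num_positive = sum(labels)
    if num_positive == 0:
        return 0.0
    hits = sum(labels[i] for i in ranked[:k])
    return hits / num_positive


def reciprocal_rank(scores, labels):
    """1 / rank of the highest-ranked positive candidate (0.0 if none)."""
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    for rank, index in enumerate(ranked, start=1):
        if labels[index] == 1:
            return 1.0 / rank
    return 0.0


if __name__ == "__main__":
    # Toy example: four candidate responses, the second one is correct.
    scores = [0.1, 0.7, 0.3, 0.9]
    labels = [0, 1, 0, 0]
    print(recall_at_k(scores, labels, 1))   # 0.0, the correct response is ranked second
    print(recall_at_k(scores, labels, 2))   # 1.0
    print(reciprocal_rank(scores, labels))  # 0.5
```

Dataset-level numbers such as those in the tables are averages of these per-context values over the test set.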