Authors: Chien-Sheng Wu, Steven Hoi, Richard Socher and Caiming Xiong.
EMNLP 2020. Paper: https://arxiv.org/abs/2004.06871
The underlying difference of linguistic patterns between general text and task-oriented dialogue makes existing pre-trained language models less useful in practice. In this work, we unify nine human-human and multi-turn task-oriented dialogue datasets for language modeling. To better model dialogue behavior during pre-training, we incorporate user and system tokens into the masked language modeling. We propose a contrastive objective function to simulate the response selection task. Our pre-trained task-oriented dialogue BERT (TOD-BERT) outperforms strong baselines like BERT on four downstream task-oriented dialogue applications, including intention recognition, dialogue state tracking, dialogue act prediction, and response selection. We also show that TOD-BERT has a stronger few-shot ability that can mitigate the data scarcity problem for task-oriented dialogue.
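The two pre-training objectives described above can be illustrated with a minimal pure-Python sketch. mask_tokens follows the standard BERT masking recipe: select about 15% of tokens, then replace 80% of those with [MASK], 10% with a random token, and leave 10% unchanged. We assume, without confirmation from this README, that role tokens such as [SYS] and [USR] are never masked. response_contrastive_loss reflects our reading of the contrastive objective. Each context vector (for example, a [CLS] embedding) is scored by dot product against every response vector in the same batch, and a softmax cross-entropy rewards picking its own response, so the other in-batch responses act as negatives. Both function names and all parameters are hypothetical and do not come from the repository, and this sketch averages the loss over the batch.

```python
import math
import random
from typing import List, Optional, Sequence, Set, Tuple


def mask_tokens(tokens: Sequence[str], vocab: Sequence[str], protected: Set[str],
                mask_prob: float = 0.15, rng: Optional[random.Random] = None
                ) -> Tuple[List[str], List[Optional[str]]]:
    """Hypothetical BERT-style masking. Returns (corrupted tokens, labels); labels hold
    the original token at positions the model must predict and None elsewhere."""
    rng = rng or random.Random()
    corrupted: List[str] = []
    labels: List[Optional[str]] = []
    for tok in tokens:
        # Assumption: special/role tokens such as [CLS], [SYS], [USR] are never masked
        if tok in protected or rng.random() >= mask_prob:
            corrupted.append(tok)
            labels.append(None)
            continue
        labels.append(tok)
        roll = rng.random()
        if roll < 0.8:
            corrupted.append("[MASK]")           # 80%: replace with the mask token
        elif roll < 0.9:
            corrupted.append(rng.choice(vocab))  # 10%: replace with a random token
        else:
            corrupted.append(tok)                # 10%: keep the original token
    return corrupted, labels


def response_contrastive_loss(contexts: Sequence[Sequence[float]],
                              responses: Sequence[Sequence[float]]) -> float:
    """Hypothetical in-batch contrastive loss: context i should score its own response i
    above every other response in the batch (softmax cross-entropy, batch mean)."""
    if len(contexts) != len(responses) or not contexts:
        raise ValueError("need equally sized, non-empty batches")
    total = 0.0
    for i, c in enumerate(contexts):
        scores = [sum(a * b for a, b in zip(c, r)) for r in responses]  # row i of C R^T
        m = max(scores)  # subtract the max for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(s - m) for s in scores))
        total += log_sum_exp - scores[i]  # -log softmax(scores)[i]
    return total / len(contexts)
```

For example, with aligned vectors contexts = responses = [[1, 0], [0, 1]], the loss is log(1 + e^-1) ≈ 0.313, and swapping the two responses raises it to log(1 + e) ≈ 1.313.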
If you use any source code, pre-trained models, or datasets included in this repo in your work, please cite the following paper. The BibTeX entry is listed below:
```bibtex
@article{wu2020tod,
  title={ToD-BERT: Pre-trained Natural Language Understanding for Task-Oriented Dialogues},
  author={Wu, Chien-Sheng and Hoi, Steven and Socher, Richard and Xiong, Caiming},
  journal={arXiv preprint arXiv:2004.06871},
  year={2020}
}
```
- (2020.10.01) Added more training and inference information. Released TOD-DistilBERT.
- (2020.07.10) Loading the models from Hugging Face is now supported.
- (2020.04.26) Pre-trained models are available.
You can easily load the pre-trained models with the Hugging Face Transformers library using the AutoModel class. Several pre-trained versions are supported:
- TODBERT/TOD-BERT-MLM-V1: TOD-BERT pre-trained only using the MLM objective
- TODBERT/TOD-BERT-JNT-V1: TOD-BERT pre-trained using both the MLM and RCL (response contrastive loss) objectives
- TODBERT/TOD-DistilBERT-JNT-V1: TOD-DistilBERT pre-trained using both the MLM and RCL objectives
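```python
import torch
from transformers import AutoModel, AutoTokenizer

# Download (on first use) and load the tokenizer and model from the Hugging Face Hub
tokenizer = AutoTokenizer.from_pretrained("TODBERT/TOD-BERT-JNT-V1")
tod_bert = AutoModel.from_pretrained("TODBERT/TOD-BERT-JNT-V1")
```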
You can also download the pre-trained models from the following links:
```python
from transformers import BertConfig, BertModel, BertTokenizer

model_name_or_path = "<path_to_the_downloaded_tod-bert>"  # replace with your local model directory
model_class, tokenizer_class, config_class = BertModel, BertTokenizer, BertConfig
tokenizer = tokenizer_class.from_pretrained(model_name_or_path)
tod_bert = model_class.from_pretrained(model_name_or_path)
```
Please refer to the following guide on how to use our pre-trained ToD-BERT models. Our model is built on top of the PyTorch and Hugging Face Transformers libraries. Here is a very quick overview of the model architecture and code; more details on the model architecture can be found in the paper.
```python
# Encode text: [SYS] and [USR] mark system and user turns
input_text = "[CLS] [SYS] Hello, what can I help with you today? [USR] Find me a cheap restaurant nearby the north town."
input_tokens = tokenizer.tokenize(input_text)
story = torch.tensor(tokenizer.convert_tokens_to_ids(input_tokens), dtype=torch.long)

if story.dim() == 1:
    story = story.unsqueeze(0)  # add the batch-size dimension

if torch.cuda.is_available():
    tod_bert = tod_bert.cuda()
    story = story.cuda()

with torch.no_grad():
    # Token id 0 is [PAD] in the BERT vocabulary, so every nonzero id is attended to
    input_context = {"input_ids": story, "attention_mask": (story > 0).long()}
    hiddens = tod_bert(**input_context)[0]  # last-layer hidden states: (batch, seq_len, hidden_size)
```
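The input string above can also be built programmatically. The sketch below assumes each turn is a (speaker, utterance) pair whose speaker is "system" or "user"; those labels and the helper names format_dialogue and pad_batch are hypothetical and not part of the repository. It reproduces the example input_text and shows one way to pad several encoded dialogues into a batch. Here the attention mask is built from sequence lengths rather than from (story > 0), so it does not depend on the padding id being 0.

```python
from typing import List, Sequence, Tuple

# Hypothetical mapping from speaker labels to the role tokens used in the example input
ROLE_TOKENS = {"system": "[SYS]", "user": "[USR]"}


def format_dialogue(turns: Sequence[Tuple[str, str]]) -> str:
    """Flatten (speaker, utterance) turns into a single TOD-BERT-style input string."""
    parts = ["[CLS]"]
    for speaker, utterance in turns:
        parts.append(ROLE_TOKENS[speaker])  # prefix each turn with its role token
        parts.append(utterance)
    return " ".join(parts)


def pad_batch(batch: List[List[int]], pad_id: int = 0) -> Tuple[List[List[int]], List[List[int]]]:
    """Right-pad token-id lists to equal length and build matching attention masks."""
    if not batch:
        return [], []
    max_len = max(len(ids) for ids in batch)
    input_ids, attention_mask = [], []
    for ids in batch:
        n_pad = max_len - len(ids)
        input_ids.append(ids + [pad_id] * n_pad)
        attention_mask.append([1] * len(ids) + [0] * n_pad)  # 1 = real token, 0 = padding
    return input_ids, attention_mask
```

The resulting lists can be converted with torch.tensor(input_ids) and torch.tensor(attention_mask) before being passed to tod_bert.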
If you would like to train the model yourself, you can download the datasets from their original papers or sources. You can also download a zip file of the datasets directly here.
The repository is currently structured as follows:
```
.
└── image
    └── ...
└── models
    └── multi_class_classifier.py
    └── multi_label_classifier.py
    └── BERT_DST_Picklist.py
    └── dual_encoder_ranking.py
└── utils.py
    └── multiwoz
        └── ...
    └── metrics
        └── ...
    └── loss_function
        └── ...
    └── dataloader_nlu.py
    └── dataloader_dst.py
    └── dataloader_dm.py
    └── dataloader_nlg.py
    └── dataloader_usdl.py
    └── ...
└── README.md
└── evaluation_pipeline.sh
└── evaluation_ratio_pipeline.sh
└── run_tod_lm_pretraining.sh
└── main.py
└── my_tod_pretraining.py
```
- Run Pretraining
❱❱❱ ./run_tod_lm_pretraining.sh 0 bert bert-base-uncased save/pretrain/ToD-BERT-MLM --only_last_turn
❱❱❱ ./run_tod_lm_pretraining.sh 0 bert bert-base-uncased save/pretrain/ToD-BERT-JNT --only_last_turn --add_rs_loss
- Run Fine-tuning
❱❱❱ ./evaluation_pipeline.sh 0 bert bert-base-uncased save/BERT
- Run Fine-tuning (Few-Shot)
❱❱❱ ./evaluation_ratio_pipeline.sh 0 bert bert-base-uncased save/BERT --nb_runs=3
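The few-shot setting fine-tunes on only a fraction of each training set and repeats the experiment over several runs (--nb_runs=3 above), presumably to average out sampling noise. The README does not describe how the script draws these subsets, so the sketch below is a hypothetical illustration of the protocol, not the repository's code; few_shot_splits and its per-run seeding scheme are assumptions.

```python
import random
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def few_shot_splits(train: Sequence[T], ratio: float, nb_runs: int,
                    base_seed: int = 0) -> List[List[T]]:
    """Hypothetical helper: draw nb_runs random subsets, each holding `ratio` of the data."""
    if not 0.0 < ratio <= 1.0:
        raise ValueError("ratio must be in (0, 1]")
    k = max(1, int(len(train) * ratio))  # keep at least one example per subset
    splits = []
    for run in range(nb_runs):
        rng = random.Random(base_seed + run)  # a different, reproducible seed per run
        splits.append(rng.sample(list(train), k))  # sample without replacement
    return splits
```

Under this sketch, fine-tuning would run once per subset and the reported score would be averaged over the runs.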
Feel free to create an issue or send an email to the first author at cswu@salesforce.com.