This is a multi-turn chatbot project built on the pre-trained GPT-2[1], following the approach introduced in How to build a State-of-the-Art Conversational AI with Transfer Learning[2].

Specifically, this repository uses the GPT-2 Language Modeling Head model, which adds one linear layer on top of GPT-2 to perform language modeling. The model is trained to generate an appropriate next response conditioned on the dialogue context.

Unlike the original version, I did not include persona information.
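As a rough illustration of how an LM-head model can be conditioned on the dialogue context, the sketch below lays out one training example. It starts with a BOS token, prefixes each utterance with an alternating speaker token, and uses only the response tokens as language-modeling targets. This is a minimal sketch under stated assumptions, not this repository's code. build_example is a hypothetical helper. Whitespace splitting stands in for the GPT-2 tokenizer, and the toy vocab dict stands in for real token IDs. The exact layout (speaker assignment, which positions are masked, and how max_turns is counted) is also assumed. The token strings match the default bos_token, sp1_token, and sp2_token values from the training arguments.

```python
from typing import Dict, List, Tuple

# Default ignore_index of torch.nn.CrossEntropyLoss: positions with this label add no loss.
IGNORE_INDEX = -100


def build_example(
    history: List[str],
    response: str,
    vocab: Dict[str, int],
    bos: str = "<bos>",
    sp1: str = "<sp1>",
    sp2: str = "<sp2>",
    max_turns: int = 5,
) -> Tuple[List[str], List[int], List[int]]:
    """Hypothetical helper: build (tokens, input_ids, labels) for one response.

    Assumptions: whitespace splitting replaces the GPT-2 tokenizer, vocab assigns
    toy IDs on first sight, and max_turns caps the number of previous utterances
    kept as context.
    """
    context = history[-max_turns:] if max_turns > 0 else []
    turns = context + [response]
    n = len(turns)

    tokens: List[str] = [bos]
    is_target: List[bool] = [False]
    for i, utterance in enumerate(turns):
        # Alternate speakers counting back from the response, which always gets sp2.
        speaker = sp2 if (n - 1 - i) % 2 == 0 else sp1
        words = utterance.split()
        tokens.append(speaker)
        is_target.append(False)  # speaker tokens are inputs, not targets
        tokens.extend(words)
        is_target.extend([i == n - 1] * len(words))  # only response words are targets

    input_ids = [vocab.setdefault(tok, len(vocab)) for tok in tokens]
    labels = [tid if tgt else IGNORE_INDEX for tid, tgt in zip(input_ids, is_target)]
    return tokens, input_ids, labels


vocab: Dict[str, int] = {}
tokens, input_ids, labels = build_example(
    ["hi there", "hello how are you"], "fine thanks", vocab, max_turns=1
)
print(tokens)  # ['<bos>', '<sp1>', 'hello', 'how', 'are', 'you', '<sp2>', 'fine', 'thanks']
print(labels)  # [-100, -100, -100, -100, -100, -100, -100, 7, 8]
```

Masked positions use -100, which is the default ignore_index of PyTorch's cross-entropy loss. Hugging Face's GPT2LMHeadModel shifts labels internally, so here the labels are aligned one-to-one with input_ids.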


Arguments for data loading

| Argument | Type | Description | Default |
| --- | --- | --- | --- |
| data_dir | str | The name of the parent directory where data files are stored. | "data" |
| train_prefix | str | The filename prefix of the training data files. | "train" |
| valid_prefix | str | The filename prefix of the validation data files. | "valid" |
| train_frac | float | The fraction of conversations to include in the training set. | 0.85 |
| model_type | str | The GPT-2 model type ("gpt2", "gpt2-medium", "gpt2-large", or "gpt2-xl"). | "gpt2" |
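Because train_frac is a ratio of conversations rather than of utterances, one natural reading is to shuffle whole dialogues and cut the list at that fraction. The sketch below does this. It is an assumption about the preprocessing, not the repository's code. split_conversations is a hypothetical helper, and its seed parameter does not appear in the data-loading arguments.

```python
import random
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_conversations(
    conversations: Sequence[T], train_frac: float = 0.85, seed: int = 0
) -> Tuple[List[T], List[T]]:
    """Hypothetical helper: shuffle whole conversations, then cut at train_frac.

    Splitting by conversation (not by utterance) keeps every turn of a dialogue
    in the same set, so no dialogue is divided between train and valid.
    """
    if not 0.0 <= train_frac <= 1.0:
        raise ValueError("train_frac must be between 0 and 1")
    shuffled = list(conversations)
    random.Random(seed).shuffle(shuffled)  # seeded, so the split is reproducible
    n_train = int(len(shuffled) * train_frac)
    return shuffled[:n_train], shuffled[n_train:]


convs = [[f"utterance {i}-{j}" for j in range(3)] for i in range(100)]
train, valid = split_conversations(convs)
print(len(train), len(valid))  # 85 15
```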

Arguments for training

| Argument | Type | Description | Default |
| --- | --- | --- | --- |
| seed | int | The random seed. | 0 |
| data_dir | str | The name of the parent directory where data files are stored. | "data" |
| train_prefix | str | The filename prefix of the training data files. | "train" |
| valid_prefix | str | The filename prefix of the validation data files. | "valid" |
| model_type | str | The GPT-2 model type ("gpt2", "gpt2-medium", "gpt2-large", or "gpt2-xl"). | "gpt2" |
| bos_token | str | The BOS token. | "<bos>" |
| sp1_token | str | The speaker 1 token. | "<sp1>" |
| sp2_token | str | The speaker 2 token. | "<sp2>" |
| gpu | str | The index of the GPU to use. | "0" |
| lr | float | The learning rate. | 2e-5 |
| warmup_ratio | float | The ratio of warmup steps to the total number of training steps. | 0.1 |
| batch_size | int | The batch size. | 8 |
| num_workers | int | The number of workers for data loading. | 0 |
| num_epochs | int | The total number of epochs. | 10 |
| max_len | int | The maximum length of the input sequence. | 1024 |
| max_turns | int | The maximum number of dialogue history turns to include. | 5 |
| ckpt_dir | str | The directory where checkpoints are saved. | "saved_models" |
| ckpt_name | str | The name of the trained model checkpoint, without the file extension. | YOU MIGHT SPECIFY |
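The lr and warmup_ratio arguments imply that the learning rate ramps up over the first 10% of training steps. The sketch below assumes linear warmup to lr followed by linear decay to zero, the same shape as Hugging Face's get_linear_schedule_with_warmup. The repository may use a different decay. It also assumes the last partial batch of each epoch counts as a step, and num_examples is a hypothetical input.

```python
def total_training_steps(num_examples: int, batch_size: int = 8, num_epochs: int = 10) -> int:
    """Optimizer steps over the whole run, counting the last partial batch of each epoch."""
    steps_per_epoch = -(-num_examples // batch_size)  # integer ceiling division
    return steps_per_epoch * num_epochs


def lr_at(step: int, total_steps: int, base_lr: float = 2e-5, warmup_ratio: float = 0.1) -> float:
    """Assumed schedule: linear warmup from 0 to base_lr, then linear decay back to 0."""
    warmup_steps = int(total_steps * warmup_ratio)
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    return base_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


total = total_training_steps(num_examples=1000)
print(total)  # 1250
for step in (0, 60, 125, 700, 1250):
    print(step, lr_at(step, total))
```

Scaling the warmup by a ratio rather than a fixed step count keeps the schedule's shape the same when batch_size, num_epochs, or the dataset size changes.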

Arguments for inference

| Argument | Type | Description | Default |
| --- | --- | --- | --- |
| seed | int | The random seed. | 0 |
| data_dir | str | The name of the parent directory where data files are stored. | "data" |
| model_type | str | The GPT-2 model type ("gpt2", "gpt2-medium", "gpt2-large", or "gpt2-xl"). | "gpt2" |
| bos_token | str | The BOS token. | "<bos>" |
| sp1_token | str | The speaker 1 token. | "<sp1>" |
| sp2_token | str | The speaker 2 token. | "<sp2>" |
| gpu | str | The index of the GPU to use. | "0" |
| max_len | int | The maximum length of the input sequence. | 1024 |
| max_turns | int | The maximum number of dialogue history turns to include. | 5 |
| top_p | float | The top-p value for nucleus sampling. | 0.8 |
| ckpt_dir | str | The directory where checkpoints are saved. | "saved_models" |
| ckpt_name | str | The name of the trained model checkpoint, without the file extension. | YOU SHOULD SPECIFY |
| end_command | str | The command that ends the conversation during inference. | "Abort!" |
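The top_p argument controls nucleus sampling. At each decoding step, only the smallest set of most-likely next tokens whose probabilities sum to at least top_p is kept, and the next token is sampled from that renormalized set. The sketch below shows the idea on a toy distribution over words. It is an illustration, not this repository's decoding code. In practice the probabilities come from a softmax over GPT-2's vocabulary logits. top_p_filter and sample_next are hypothetical helpers.

```python
import random
from typing import Dict


def top_p_filter(probs: Dict[str, float], top_p: float = 0.8) -> Dict[str, float]:
    """Keep the smallest set of most-likely tokens whose total probability reaches top_p."""
    ranked = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)
    kept = []
    cumulative = 0.0
    for token, p in ranked:
        kept.append((token, p))
        cumulative += p
        if cumulative >= top_p:
            break  # the nucleus is complete; the low-probability tail is dropped
    total = sum(p for _, p in kept)
    return {token: p / total for token, p in kept}  # renormalize over the nucleus


def sample_next(probs: Dict[str, float], top_p: float, rng: random.Random) -> str:
    """Draw one token from the nucleus in proportion to its renormalized probability."""
    nucleus = top_p_filter(probs, top_p)
    tokens = list(nucleus)
    return rng.choices(tokens, weights=[nucleus[t] for t in tokens], k=1)[0]


next_token_probs = {"good": 0.5, "fine": 0.3, "okay": 0.15, "bad": 0.05}
nucleus = top_p_filter(next_token_probs, top_p=0.75)
print({t: round(p, 3) for t, p in nucleus.items()})  # {'good': 0.625, 'fine': 0.375}
print(sample_next(next_token_probs, 0.75, random.Random(0)))
```

A lower top_p shrinks the nucleus toward greedy decoding, while a value near 1.0 approaches sampling from the full distribution.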


By default, I provide the code for downloading and preprocessing the datasets.

There are four default datasets:

  • DailyDialog[3]
  • EmpatheticDialogues[4]
  • Persona-Chat[5]
  • BlendedSkillTalk[6]

How to run

  1. Install all required packages.

    pip install -r requirements.txt

  2. Download and preprocess all datasets.

    sh exec_load_data.sh

    After running it, you will have the following data directory structure if you use the default argument settings.


  3. Run the following command to train the model.

    To resume training from a specific checkpoint, add the ckpt_name argument and set it to the name of that checkpoint.

    sh exec_train.sh

  4. Run the command below to perform inference with the trained model.

    This time, you must specify ckpt_name.

    sh exec_infer.sh
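Step 4 starts an interactive conversation. The sketch below shows one way such a loop could use the end_command and max_turns arguments. It stops when the user types the end command and passes only the most recent utterances to the model. It is a hypothetical outline, not the code behind exec_infer.sh. chat is a hypothetical helper, and reply_fn is a placeholder for GPT-2 generation with top-p sampling.

```python
from typing import Callable, Iterable, List


def chat(
    user_inputs: Iterable[str],
    reply_fn: Callable[[List[str]], str],
    end_command: str = "Abort!",
    max_turns: int = 5,
) -> List[str]:
    """Hypothetical chat loop: stop on end_command, pass only recent turns to the model."""
    history: List[str] = []
    replies: List[str] = []
    for text in user_inputs:
        text = text.strip()
        if text == end_command:
            break  # the end command finishes the conversation
        history.append(text)
        # Keep at most max_turns recent utterances (user and bot) as context.
        context = history[-max_turns:] if max_turns > 0 else []
        reply = reply_fn(context)  # placeholder for model generation
        history.append(reply)
        replies.append(reply)
    return replies


def fake_reply(context: List[str]) -> str:
    # Stand-in for the trained model: reports how much context it received.
    return f"(context size: {len(context)})"


print(chat(["Hi!", "How are you?", "Abort!", "ignored"], fake_reply, max_turns=2))
# ['(context size: 1)', '(context size: 2)']
```

For a live session, the inputs could come from iter(input, None), which keeps calling input() until the end command is typed.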


