

GMN Chatbot


This repository contains a chatbot based on the Graph Matching Network (GMN). It is our final project for the course New Frontier Artificial Intelligence II at U-Tokyo.


We trained our chatbot on the following English dialogues between patients and doctors.



  • Prepare the data by downloading the text files from the Google Drive link above. The following files should then be in the data directory:
  • To convert the text files to CSV and generate the tokenized dataset files, run preprocessing/reformat_text_data.py:
python preprocessing/reformat_text_data.py
  • To get the PyTorch datasets, import get_training_dev_test_dataset from the following module. Tokenization uses the Hugging Face bert-base-uncased tokenizer, and the default maximum length is 256.
from preprocessing.get_en_dataloader import get_training_dev_test_dataset

train_dataset, dev_dataset, test_dataset = get_training_dev_test_dataset(debugging=False, max_length=256)


The returned datasets can be used directly with the Hugging Face Trainer. For manual training, wrap them in a PyTorch DataLoader:

from torch.utils.data import DataLoader

# Training loader: reshuffle the training set every epoch
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

Default Data split

The default split uses the following files:

    - data/healthcaremagic_splitted_idname_1.csv
    - data/healthcaremagic_splitted_idname_2.csv
    - data/healthcaremagic_splitted_idname_3.csv
    - data/healthcaremagic_splitted_idname_4.csv
    - data/icliniq_splitted_idname.csv
  • For each index, the dataset returns a dictionary with the following fields:
    'input_ids': tensor of shape (2, max_length),
    'token_type_ids': same shape as input_ids,
    'attention_mask': same shape as input_ids,
    'doctor_input_ids': tensor of shape (#negative_sample + 1, max_length),
    'doctor_token_type_ids': same shape as doctor_input_ids,
    'doctor_attention_mask': same shape as doctor_input_ids,
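Every field above has a trailing max_length dimension, which suggests that each turn is padded or truncated to a fixed length. The toy sketch that follows, in plain Python, illustrates how one row per turn could be built in the usual BERT layout ([CLS] tokens [SEP], then [PAD] up to max_length). It is an assumption about the layout, not this repository's code: encode_turn and encode_turns are hypothetical helpers, the content token ids in any example are made up, and only the special-token ids (101, 102, 0) come from the bert-base-uncased vocabulary. Real ids come from the Hugging Face tokenizer.

```python
# Toy illustration (hypothetical helpers, not this repo's code): real token ids
# come from the Hugging Face bert-base-uncased tokenizer.
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0  # bert-base-uncased special-token ids


def encode_turn(content_ids, max_length=256):
    """Wrap one turn as [CLS] ... [SEP], truncate, and pad to max_length."""
    # Leave room for [CLS] and [SEP] when truncating the content tokens.
    body = list(content_ids)[: max_length - 2]
    ids = [CLS_ID] + body + [SEP_ID]
    n_real = len(ids)
    input_ids = ids + [PAD_ID] * (max_length - n_real)
    # 1 marks real tokens, 0 marks padding the model should ignore.
    attention_mask = [1] * n_real + [0] * (max_length - n_real)
    # Assumption: each turn is encoded as a single segment, so all zeros.
    token_type_ids = [0] * max_length
    return input_ids, token_type_ids, attention_mask


def encode_turns(turns, max_length=256):
    """Stack turns into rows, giving (len(turns), max_length) for each key."""
    rows = [encode_turn(turn, max_length) for turn in turns]
    return {
        "input_ids": [r[0] for r in rows],
        "token_type_ids": [r[1] for r in rows],
        "attention_mask": [r[2] for r in rows],
    }
```

With the description and the patient response as the two turns, encode_turns returns two rows per key, matching the (2, max_length) shape of input_ids. Applied to the correct doctor response plus #negative_sample wrong responses, it would give the (#negative_sample + 1, max_length) layout of the doctor_* fields.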


There are two things to note here:

  • Every sample contains a description, a patient response, and a doctor response, i.e. a 3-turn dialogue. We treat the description and the patient response as the first two turns and ask the model to output the probability of the third turn.
  • Negative samples are drawn at random from the responses of other conversations. The correct response is always at index 0, followed by #negative_sample wrong responses.
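The two notes above can be sketched in plain Python. build_candidates draws negatives only from other conversations and puts the correct doctor response first. candidate_probabilities and loss_for_sample show one common way (an assumption here, not confirmed as this repository's training objective) to turn per-candidate model scores into probabilities and a cross-entropy loss whose target is always index 0. The conversations structure with a "doctor" key and all three function names are hypothetical.

```python
import math
import random


def build_candidates(conversations, index, num_negative, rng=random):
    """Return [correct response] + num_negative responses from other conversations.

    `conversations` is a hypothetical list of dicts with a "doctor" key.
    """
    correct = conversations[index]["doctor"]
    # Sample negatives only from other conversations, never the current one.
    other_indices = [i for i in range(len(conversations)) if i != index]
    negative_indices = rng.sample(other_indices, num_negative)  # no replacement
    negatives = [conversations[i]["doctor"] for i in negative_indices]
    return [correct] + negatives  # the correct response is always at index 0


def candidate_probabilities(scores):
    """Numerically stable softmax over per-candidate scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def loss_for_sample(scores):
    """Cross-entropy with target index 0, since the correct response is first."""
    return -math.log(candidate_probabilities(scores)[0])
```

Because the correct response always sits at index 0, this sketch needs no separate label: the target index is fixed for every sample.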

BERT Baseline
