
A Chinese chitchat model based on GPT-2 and DialoGPT that supports multi-round chitchat

Primary Language: Python



  • Intern Recruitment Test

    • A Chinese chitchat demo based on DialoGPT that supports multi-round chitchat
    • Time Limit: one week
  • Obstacles: no NLP background, only basic knowledge of deep learning

  • Achievements

    • Outlined the evolution and structure of NLP models, and took a deep dive into CDial-GPT and the MMI model.

      RNN/LSTM/GRU -> Seq2Seq -> Seq2Seq with Attention -> Transformer -> GPT2

    • Learned Hugging Face for fine-tuning pretrained models

    • Developed the entire pipeline for this NLP task, including training and inference.


  • Target: A multi-round chitchat chatbot that takes dialogue history into account

    • Inputs: combination of dialogue history and current round input from users
    • Output: candidate response and corresponding loss
  • KnowledgeMap

  • Final Results

    Although it does not always pick the best candidate, the chatbot can already generate responses that relate to the dialogue history.



  • Model

    • For candidate generation: CDial-GPT2_LCCC-base


    • For penalizing safe answers: MMI model from DialoGPT

      The more specific an answer is, the more weight it gets.

  • Dataset: STC dataset


  • Optimizer: AdamW

  • WarmUp: linear schedule

  • DecodingStrategy: temperature + top sampling

    Candidate responses are reranked by MMI score.
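
The decoding strategy above can be sketched in plain Python. This is a minimal illustration rather than the project's code: it assumes that "top sampling" means top-k sampling (the notes do not say whether top-k or top-p was used), and the name `sample_top_k` and its parameters are hypothetical.

```python
import math
import random


def sample_top_k(logits, temperature=1.0, top_k=8, rng=random):
    """Sample one token id from `logits` using temperature and top-k filtering."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    # Temperature scaling: values < 1 sharpen the distribution, > 1 flatten it.
    scaled = [x / temperature for x in logits]
    # Keep only the k highest-scoring token ids (assumption: top-k, not top-p).
    k = max(1, min(top_k, len(scaled)))
    kept = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)[:k]
    # Softmax weights over the kept logits (subtract the max for stability).
    peak = max(scaled[i] for i in kept)
    weights = [math.exp(scaled[i] - peak) for i in kept]
    # Draw one token id in proportion to its weight.
    return rng.choices(kept, weights=weights, k=1)[0]
```

Calling this once per generation step, and repeating the whole generation several times, would give the pool of candidate responses that the MMI model then reranks.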

Training Process

  • DatasetSize: The large dataset could not be run to completion in the time available, so only a 2.98M dataset was used for fine-tuning

  • Epochs: 3 for both CDial-GPT2 and MMI model

  • Accuracy: CDial-GPT2 50%~60%, MMI model 50%~65%

    Due to time constraints, most of the time went into getting the training and inference pipeline to run successfully rather than into improving accuracy.
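
The linear warm-up schedule used with AdamW can be sketched as a learning-rate multiplier. This is an illustration under assumptions: the function name `linear_warmup_factor` and the step counts are hypothetical, and the linear decay to zero after warm-up mirrors what `get_linear_schedule_with_warmup` in Hugging Face `transformers` does, which may or may not be the exact scheduler used here.

```python
def linear_warmup_factor(step, warmup_steps, total_steps):
    """Multiplier applied to the base learning rate at a given optimizer step."""
    if step < warmup_steps:
        # Warm-up phase: ramp linearly from 0 up to 1.
        return step / max(1, warmup_steps)
    # Decay phase: fall linearly from 1 down to 0 at total_steps.
    remaining = total_steps - step
    return max(0.0, remaining / max(1, total_steps - warmup_steps))
```

The effective learning rate at each step is the base rate times this factor, so early updates are small while the optimizer statistics are still unreliable.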

Program Design


Step1.1 DataLoader

  • Pipeline

  • Learning: how to load a large dataset

    import json

    # Read the whole JSON file into memory in one pass
    with open(path, 'r', encoding='utf-8') as f:
        dataset = json.load(f)

Step1.2 DatasetPreprocessor

  • pipeline

  • inputs for CDial-GPT2

    • input_ids: [CLS] question1 [SEP] answer1 [SEP] question2 [SEP] answer2 [SEP]
    • token_type: [speaker1] for questions, [speaker2] for answers
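
The input layout above can be sketched as follows. This is a minimal illustration, not the project's preprocessor: the function name `build_inputs` and the id arguments are hypothetical, and the token types given to `[CLS]` and `[SEP]` are assumptions because the notes only specify the types for questions and answers.

```python
def build_inputs(utterances, cls_id, sep_id, speaker1_id, speaker2_id):
    """Flatten a dialogue into input_ids and token_type_ids.

    `utterances` is a list of token-id lists, alternating question, answer, ...
    """
    input_ids = [cls_id]
    token_type_ids = [speaker1_id]  # assumption: [CLS] is tagged as speaker1
    for turn, tokens in enumerate(utterances):
        # Even turns are questions (speaker1), odd turns are answers (speaker2).
        speaker = speaker1_id if turn % 2 == 0 else speaker2_id
        input_ids.extend(tokens)
        input_ids.append(sep_id)
        # Assumption: the closing [SEP] shares the type of its utterance.
        token_type_ids.extend([speaker] * (len(tokens) + 1))
    return input_ids, token_type_ids
```

Both lists always have the same length, which is what the model expects when the two are passed in together.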

Step2 TrainingLoop

  • loss calculation: the output at position N-1 is used to predict the token at position N, so logits and labels are shifted by one position
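
The shifted loss can be sketched in plain Python. This is an illustration of the general next-token cross-entropy, not the project's training code: the name `shifted_lm_loss` is hypothetical, positions are 0-based, and `logits[t]` stands for the model's vocabulary scores at position `t`.

```python
import math


def shifted_lm_loss(logits, input_ids):
    """Mean cross-entropy where logits[t] is scored against input_ids[t + 1]."""
    total = 0.0
    # The last position has no next token, so only len - 1 predictions count.
    for t in range(len(input_ids) - 1):
        row = logits[t]
        target = input_ids[t + 1]
        # Negative log-softmax of the target entry, using the max trick.
        peak = max(row)
        log_norm = peak + math.log(sum(math.exp(x - peak) for x in row))
        total += log_norm - row[target]
    return total / max(1, len(input_ids) - 1)
```

With uniform logits over a vocabulary of size V the loss equals log V, which is a handy sanity check at the start of training.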

Step3 Evaluation

  • In model.eval() mode, evaluate the model on the test set


  1. Get candidate_response by CDial-GPT2

    • Difficulty: candidate responses vary in length, so they need to be batch-processed carefully.

  2. Get the most specific and relevant answer by MMI

    • MMI inputs: the candidate response concatenated with the dialogue history in reverse order

    • Output: the response with minimum loss

    • Action: output the answer and add it to the history
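
The two inference steps above can be sketched as follows. This is a minimal illustration under assumptions: all function names are hypothetical, `mmi_loss_fn` stands in for a call to the MMI model that returns a loss for a list of token ids, and right-padding with a mask is only one possible way to batch candidates of different lengths.

```python
def pad_batch(sequences, pad_id):
    """Right-pad variable-length sequences to one length and build a mask."""
    width = max(len(seq) for seq in sequences)
    padded = [seq + [pad_id] * (width - len(seq)) for seq in sequences]
    mask = [[1] * len(seq) + [0] * (width - len(seq)) for seq in sequences]
    return padded, mask


def build_mmi_input(history, response, cls_id, sep_id):
    """Concatenate the response and the history in reverse dialogue order."""
    ids = [cls_id]
    # The candidate response comes first, then the history, newest to oldest.
    for tokens in [response] + history[::-1]:
        ids.extend(tokens)
        ids.append(sep_id)
    return ids


def rerank_by_mmi(history, candidates, mmi_loss_fn, cls_id, sep_id):
    """Return the candidate whose reversed input gets the lowest MMI loss."""
    scored = []
    for response in candidates:
        mmi_input = build_mmi_input(history, response, cls_id, sep_id)
        scored.append((mmi_loss_fn(mmi_input), response))
    best_loss, best = min(scored, key=lambda pair: pair[0])
    return best, best_loss
```

After reranking, appending the chosen response to `history` keeps the next round conditioned on the full conversation.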


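  1. Starting from the existing checkpoints, train both models on the full training set (with more epochs)

  2. Only multi-round dialogue is implemented so far; coreference across the context is not handled. A possible next step is a dynamic neural network (transitive reasoning, to resolve coreference)

  3. Change the encoding scheme

    • Replace the current model's fixed-length position encoding with NEZHA's relative position encoding, so that longer input sentences can be accepted
    • UNLM model: change the mask so that the question part is not predicted and only the answer part is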
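
The mask change proposed in the last bullet can be sketched as follows. This is an illustration of one possible reading, not an implemented part of the project: it assumes a sequence-to-sequence style mask in which question tokens see the whole question and answer tokens see the question plus earlier answer tokens, and it assumes the loss skips question positions through the label value -100, which PyTorch's cross-entropy loss ignores by default. Both function names are hypothetical.

```python
IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy skips by default


def answer_only_labels(input_ids, question_len):
    """Copy input_ids as labels, but ignore every question position."""
    return [IGNORE_INDEX] * question_len + list(input_ids[question_len:])


def seq2seq_attention_mask(question_len, answer_len):
    """Build a mask where mask[i][j] == 1 lets position i attend to position j."""
    total = question_len + answer_len
    mask = [[0] * total for _ in range(total)]
    for i in range(total):
        for j in range(total):
            if j < question_len:
                # Every position may look at the whole question.
                mask[i][j] = 1
            elif i >= question_len and j <= i:
                # Answer positions may also look at earlier answer positions.
                mask[i][j] = 1
    return mask
```

Under this layout the question is encoded bidirectionally, while the answer is still generated left to right and is the only part that contributes to the loss.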


P.S. This was not a full-time project; it was implemented after daytime courses.

Theory Learning | 11.21-11.22



Study of existing research | 11.23

Programming | 11.24-11.26

  • Studied source code of GPT2 for Chinese chitchat
  • 11.24
    • Training loop
    • Dataset Loader
  • 11.25
    • Training loop
    • Customized dataset preprocessing, improved the token2id process
  • 11.26
    • Inference