/Poetry_Generate_PyTorch

学习用PyTorch创作唐诗

Primary LanguagePython

Poetry_Generate_PyTorch

学习用PyTorch创作唐诗

项目简介

最近在学习PyTorch和BERT,当然用唐诗生成作为NLP新手任务再好不过了~之前写了TensorFlow和Keras版本的,代码写的比较烂,借此机会也提高下自己的工程能力。

模块结构

结构比较简单,包括数据、预处理方法、网络模型、训练代码和测试代码:

  • 数据:在data文件夹中,解压Tang_Poetry.zip得到.txt文档,来源https://github.com/todototry/AncientChinesePoemsDB
  • 预处理:ProcessData.py,用于合并、清洗每个txt文档,文本转编码、填充切片
  • 网络结构:在文件夹net中,mynet.py、dataset.py分别是网络结构和pytorch的dataset读取方式,包含了2层lstm、BERT和BERT+2层lstm三种方式。
  • 训练模型:train.py,训练网络,'RNN', 'BERT', 'BERT_RNN'三种方式
  • 模型存储:model文件夹,保存训练的模型
  • 生成唐诗:test.py,'RNN', 'BERT', 'BERT_RNN'三种方式,生成诗歌

其他说明

  • 按照传统的Embedding+LSTM+Softmax,效果还可以。
  • BERT是通过flair模块调用的,训练过程因为会逐个计算BERT输出,速度非常慢,内存足够的话建议先批量转为[样本数, 句子长度, BERT输出维度]的tensor。
  • 使用BERT发现效果非常差,学不到唐诗的结构。原因猜测是训练和推断都是从左向右,应当采用Transformer中Decoder部分的下三角矩阵Multi-headed attention,而BERT采用的是Transformer中Encoder部分。在推断过程中,由于不知道未生成部分的文本,应该采用Future blinding的方式,而不是生成一个字就全局重新计算权重。flair官方推荐的是用BERT后四层输出拼接成[batchsize, len, 786*4]的方式,实际使用发现只用最后一层的效果略好。

成果展示