/dssm

A BiGRU-Attention DSSM implementation with tensorflow estimator.

Primary LanguagePython

dssm

A BiGRU-Attention DSSM implementation with tensorflow estimator.

对应博客:https://blog.csdn.net/cdj0311/article/details/107634795

之前使用Keras和paddlepaddle实现过DSSM文本表示模型,(https://github.com/cdj0311/keras_bert_classification/blob/master/bert_dssm.py, https://github.com/cdj0311/paddledssm) 由于Keras做分布式计算比较麻烦,而paddlepaddle早已弃用,现在用tensorflow的高级API tf.estimator重写一遍,其中表示层使用双向GRU+Attention,最终输出为64维的向量。

python == 3.6

tensorflow == 1.13.1

训练步骤如下:

  1. 将文本数据转换为tfrecord格式:

    python convert_data.py

    data目录的data.txt中包含了10000条训练数据,数据为某新闻网站上的标题和对应的内容,格式为:title\tcontent,train.tfrecord是转换完成的tfrecord数据。

  2. 模型训练:

    sh train_local.sh

    模型训练完后会分别导出query和doc的pb格式模型,可根据需要进行选择。

  3. 模型预测:

    python predict.py

    给定一个句子得到向量,并获取最相似的N个句子,例如:

    输入: 赵丽颖冯绍峰在拍女儿国的时候真的超级甜了

    输出:

       0.801103	女神赵丽颖李沁都爱穿黄毛衣,但差距真的蛮大的
       0.744942	街拍:喜欢第二位俏皮可爱的小姐姐,和她在一起不会觉得无聊!
       0.722599	杜江霍思燕夫妇甜蜜现身 牵手依偎恩爱甜到发腻
       0.719018	还在情侣穿搭烦恼,看街拍情侣都是怎么搭配的
       0.707306	赵丽颖,应是绿肥红瘦,剧照
       0.701783	她的闺蜜则穿了一件白色的蕾丝连衣裙,尽显女人味
       0.70024	国民妖精十元女神可爱撩人瞬间合集!出色的不只是时尚穿搭
       0.691073	图集:#杨幂#赵丽颖暗斗时尚穿同款婚纱谁更美
       0.687201	赵丽颖 路人抓拍下的颖宝,这颜值可以说是完美的纯天然美女了~
    

    输入: 祝考研的女士们先生们都顺利考进自己理想的学校

    输出:

     0.890815	祝考研的女士们先生们都顺利考进自己理想的学校!实在考不上就滚tm的,当代...
     0.758741	硕士研究生招生考试22日开考
     0.701588	加油高考!祝你们顺利考上心仪的大学!
     0.660756	中考,你准备好了吗?
     0.654576	这些考研复试面试小技巧收好,导师的心就抓住了!
     0.63505	高考生作弊被抓飞踹监考老师:你知道我爸是谁?
     0.626651	高考倒计时30天,祝所有今年参加高考的小伙伴们心想事成,高考必胜
     0.590912	各位同学请注意,第一季期末考试现在开始~请认真阅读仔细答题
     0.585147	航班延误艺考生妈妈痛哭 浙传:可提供证明安排考试
     0.575564	当女儿带男同学回家写作业的时候,爸爸都在想什么
    
  4. 分布式训练

    设置run_on_cluster=True, 提交到job中即可训练,由于每个公司的分布式训练提交命令不一样,这里就不贴出来了。

该项目是基于字符做Embedding,实际使用中我们一般会将字和词同时作为输入进行训练。