/BertClassifier

基于PyTorch的BERT中文文本分类模型(BERT Chinese text classification model implemented by PyTorch)

Primary LanguagePythonMIT LicenseMIT

BertClassifier

(求Star⭐)本项目仅仅提供了最基础的BERT文本分类模型,代码是作者在入门NLP时自己写的,对于初学者还算比较好理解,细节上有不足的地方,大家可以自行修改。读者在使用的时候有任何问题和建议都可以通过邮件联系我。


本文利用了transformers中的BertModel,对部分cnews数据集进行了文本分类,在验证集上的最优Acc达到了0.92,拿来对BERT模型练手还是不错的。

数据描述

数据集是从清华大学的THUCNews中提取出来的部分数据。

训练集中有5万条数据,分成了10类,每类5000条数据。

{"体育": 5000, "娱乐": 5000, "家居": 5000, "房产": 5000, "教育": 5000, "时尚": 5000, "时政": 5000, "游戏": 5000, "科技": 5000, "财经": 5000}

验证集中有5000条数据,每类500条数据。

{"体育": 500, "娱乐": 500, "家居": 500, "房产": 500, "教育": 500, "时尚": 500, "时政": 500, "游戏": 500, "科技": 500, "财经": 500}

如果需要数据集,请与我联系.

数据集放在了百度网盘上:链接: https://pan.baidu.com/s/1FVV8fq7vSuGSiOVnE4E_Ag 提取码: bbwv

模型描述

整个分类模型首先把句子输入到Bert预训练模型,然后将句子的embedding(CLS位置的Pooled output)输入给一个Linear,最后把Linear的输出输入到softmax中。

Figure 1: Model

环境

硬件 环境
GPU GTX1080
RAM 64G
软件 环境
OS Ubuntu 18.04 LTS
CUDA 10.2
PyTorch 1.6.0
transformers 3.2.0

结果

分类报告:

* Classification Report:                                                                            
              precision    recall  f1-score   support                                               
                                                                                                    
          体育       1.00      0.99      0.99       500                                             
          娱乐       0.99      0.92      0.96       500                                             
          家居       0.96      0.73      0.83       500                                             
          房产       0.83      0.94      0.88       500                                             
          教育       0.94      0.75      0.84       500                                             
          时尚       0.89      0.99      0.94       500                                             
          时政       0.91      0.96      0.93       500                                             
          游戏       0.93      0.98      0.96       500                                             
          科技       0.91      0.96      0.93       500                                             
          财经       0.87      0.98      0.92       500                                             
                                                                                                    
    accuracy                           0.92      5000                                               
   macro avg       0.92      0.92      0.92      5000                                               
weighted avg       0.92      0.92      0.92      5000

使用方法:

创建data文件夹,把下载好的cnews数据集放在data文件夹下。

创建models文件夹,用来保存模型

安装相应依赖库: pip install -r requirements.txt

训练: python train.py

预测: python predict.py

Star History

Star History Chart