CLUE_NER_PyTorch
CLUENER 细粒度命名实体识别
数据介绍
数据详细描述: https://www.cluebenchmarks.com/introduce.html
代码目录说明
├── callback # 自定义常用callback
| └── lr_scheduler.py
| └── ...
├── losses # 自定义loss函数
| └── focal_loss.py
| └── ...
├── CLUEdatasets # 存放数据
| └── cluener
├── metrics # metric计算
| └── ner_metrics.py
├── outputs # 模型输出保存
| └── cluener_output
├── prev_trained_model # 预训练模型
| └── albert_base
| └── bert-wwm
| └── ...
├── processors # 数据处理
| └── ner_seq.py
| └── ...
├── tools # 通用脚本
| └── common.py
| └── download_clue_data.py
| └── ...
├── models # 主模型
| └── transformers
| └── bert_for_ner.py
| └── ...
├── run_ner_span.py # 主程序
├── run_ner_span.sh # 任务运行脚本
依赖模块
- pytorch=1.1.0
- boto3=1.9
- regex
- sacremoses
- sentencepiece
- python3.7+
运行方式
1. 下载预训练模型
链接:https://pan.baidu.com/s/1b7a-btBIHgaPBv3mJRkKnQ 密码:apqa
然后将bert-base.zip解压到prev_trained_model文件夹
2. 预训练模型文件格式,比如:
├── prev_trained_model # 预训练模型
| └── bert-base
| | └── vocab.txt
| | └── config.json
| | └── pytorch_model.bin
3. 训练:
直接执行对应shell脚本,如:
sh run_ner_span.sh
4. 预测
当前默认使用最后一个checkpoint模型作为预测模型,你也可以指定--predict_checkpoints参数进行对应的checkpoint进行预测,比如:
CURRENT_DIR=`pwd`
export BERT_BASE_DIR=$CURRENT_DIR/prev_trained_model/bert-base
export GLUE_DIR=$CURRENT_DIR/CLUEdatasets
export OUTPUR_DIR=$CURRENT_DIR/outputs
TASK_NAME="cluener"
python run_ner_span.py \
--model_type=bert \
--model_name_or_path=$BERT_BASE_DIR \
--task_name=$TASK_NAME \
--do_predict \
--predict_checkpoints=100 \
--do_lower_case \
--loss_type=ce \
...
模型列表
model_type目前支持bert和albert
注意: bert ernie bert_wwm bert_wwwm_ext等模型只是权重不一样,而模型本身主体一样,因此参数model_type=bert其余同理。
输入编码方式
目前默认为BIOS编码方式,比如:
美 B-LOC
国 I-LOC
的 O
华 B-PER
莱 I-PER
士 I-PER
我 O
跟 O
他 O
谈 O
笑 O
风 O
生 O
结果
以下为模型在 dev上的测试结果:
Accuracy (entity) | Recall (entity) | F1 score (entity) | ||
---|---|---|---|---|
BERT+Softmax | 0.7916 | 0.7962 | 0.7939 | train_max_length=128 eval_max_length=512 epoch=4 lr=3e-5 batch_size=24 |
BERT+CRF | 0.7877 | 0.8008 | 0.7942 | train_max_length=128 eval_max_length=512 epoch=5 lr=3e-5 batch_size=24 |
BERT+Span | 0.8132 | 0.8092 | 0.8112 | train_max_length=128 eval_max_length=512 epoch=4 lr=3e-5 batch_size=24 |
BERT+Span+focal_loss | 0.8121 | 0.8008 | 0.8064 | train_max_length=128 eval_max_length=512 epoch=4 lr=3e-5 batch_size=24 loss_type=focal |
BERT+Span+label_smoothing | 0.8235 | 0.7946 | 0.8088 | train_max_length=128 eval_max_length=512 epoch=4 lr=3e-5 batch_size=24 loss_type=lsr |