Fine-tuning Chinese BERT model based with transformers Trainer api
目标是实现 sentence-Transformer 论文中的模型(BertForSiameseNetwork)结果,并进行训练。
参考Transformer的Trainer的基本用法Training and fine-tuning 中naive PyTorch的基本流程。构造 model, dataset 等对象即可。
目标:实现基于BERT的SiameseNetwork的Model以及Dataset。
- 读取所有样构造成 InputExample,并保存为List 对象
- 实现Dataset的类:SiameseDataset,输入:examples, tokenizer,实现方法:getitem:实现分词并返回token_id 等序列。
- 实现collate_fn(整理样本的函数): 将样本的List对象初始化为DataLoader(回调),对于每一个batch的数据进行处理(如:query/candidates 长度对齐,并构造自定义Model的输入:)
- 实现自定义Model: BertForSiameseNetWork, 需要继承类 BertPreTrainedModel, 类BertPreTrainedModel,使用BertModel 作为encoder;实现forward的函数:自定义输入特征函数,使用encoder 获得 query/candidates 的句向量,计算余弦相似度作为loss 返回。
- 实现 compute_metric, 进行指标计算
- 初始化Trainer 实例,输入BertForSiameseNetwork, SiameseDataset, collate_fn以及compute_metrics。