bert4keras是我最喜欢的库之一,但在现在来说其后端tf1.15显得有点落后,因此本库目的在实现对其的升级
兼容keras3及其对应后端 目前已经成功实现了bert4keras所支持的所有预训练模型的兼容
bert4keras实现的优化器目前暂时不做兼容,除开优化器部分外,如何使用请参考bert4keras的example,本仓库的example只提供了如何把模型load出来的测试
请参考api说明
因为我是个人开发,连草台班子都不是,经常会发布修改bug的版本,所以建议安装最新版本
pip install --upgrade bert4keras3
如果你用不是tensorflow后端,我建议安装一个tensorflow-cpu==2.10
当然如果你不需要加载旧版bert4keras对应的权重的话(对应下面模型权重的第一个表),那其实tf的cpu也不用安装。
pip3 install tensorflow-cpu==2.10
pip3 install --upgrade keras
如果你用torch后端,直接安装最新的torch就行了。但是我个人建议torch后端只用来调试
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
pip3 install keras
如果你需要使用tensorflow后端,那我建议你安装tensorflow的2.15
pip3 install tensorflow[and-cuda]==2.15
pip3 install --upgrade keras
当然你想安装最新的也可以,但是问题就是加载苏神的权重会有点问题。谷歌的尿性你们懂的
还有就是cuda版本要大于12.2,你的服务器不一定能同步。可以看tensorflow的cuda、cudnn版本对应
如果你想使用jax后端,jax安装建议看keras官方文档的jax-cuda要求
比如在keras3.3.3的情况下,官方推荐的版本是jax 0.4.23,那安装可以这么写
#cuda12
pip install -U "jax[cuda12]==0.4.23" --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
#cuda11
pip install -U "jax[cuda118]==0.4.23" --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip3 install --upgrade keras
jax和tensorflow后端只能在linux使用cuda
初始版本与bert4keras基本相同,可以参考https://github.com/bojone/bert4keras
但需要注意的是,如果bert4keras的example中必须要要tf.keras的,在本库中依然需要
如果你需要使用tf之外的其他后端,需要修改bert4keras中的tf api
由于优化器部分维护工作量过大,本库放弃了对器优化器的维护。并且以后如果推出优化器功能,只keras3版本
目前keras3支持原生梯度累积、ema,AdamW等,如果需要什么keras不支持的功能欢迎提issue
除此之外重计算/gradient_checkpoint功能目前依然不支持keras3
如果你只是想兼容torch、tf和jax,那么我建议你使用纯keras的api实现,参考keras.io。对于精细的算子可以使用keras.ops,如果keras实在没有算子,那你只能提供一个api的三后端实现了
如果你想兼容keras2和tf.api,因为在keras3中增加了ops系列并且删除了绝大部分keras.backend中的算子操作。因此如果你需要兼容tf2是有一定困难的。
为了解决这个问题,bert4keras3.ops手动对齐了keras3中的ops,api。所以如果你想要兼容keras2和tf.keras,那么在编写代码时请from bert4keras3 import ops,在keras2中使用的是我们对齐的api,而在keras3中使用的是keras.ops。通过这种方法,你可以很容易地实现更好的兼容性
模型分类 | 模型名称 | 权重链接 | 支持kv-cache 生成 |
---|---|---|---|
bert/roberta | Google原版bert | github | √ |
brightmart版roberta | github | √ | |
哈工大版roberta | github | √ | |
追一开源bert | github | √ | |
LaBSE(多国语言BERT) | github | √ | |
albert | 谷歌albert | github | x |
brightmart版albert | github | x | |
苏神转换后的albert | github | x | |
NEZHA | 双向NEZHA | github | x |
单向NEZHA | github | x | |
T5 | 谷歌T5 | github | √ |
MT5 | github | √ | |
苏神T5-pegasus | github | √ | |
T5.1.1 | github | √ | |
ELECTRA | Google原版ELECTRA | github | x |
哈工大版ELECTRA | github | x | |
CLUE版ELECTRA | github | x | |
GPT-oai | GPT_OpenAI | github | x |
GPT2-ML | GPT2-ML | github | x |
GAU | GAU-ALPHA | github | x |
Roformer | 苏神原版roformer | github | √ |
roformer-sim | github | √ | |
Roformerv2 | 苏神原版roformer-v2 | github | √ |
模型分类 | 模型名称 | 权重链接 | 支持kv-cache 生成 | 数据类型 | 分词器 |
---|---|---|---|---|---|
RoformerV2 | RoformerV2-Small-CN | ModelScope | √ | FP32 | Tokenizer |
RoformerV2-Base-CN | ModelScope | √ | FP32 | Tokenizer | |
RoformerV2-Large-CN | ModelScope | √ | FP32 | Tokenizer | |
RoformerSim | RoformerSim-Small | ModelScope | √ | FP32 | Tokenizer |
RoformerSim-Ft-Small | ModelScope | √ | FP32 | Tokenizer | |
RoformerSim-Base | ModelScope | √ | FP32 | Tokenizer | |
RoformerSim-Ft_Base | ModelScope | √ | FP32 | Tokenizer | |
Derberta | Deberta_v3_Small_En | ModelScope | x | FP32 | SpTokenizer |
Deberta_v3_Base_En | ModelScope | x | FP32 | SpTokenizer | |
Deberta_v3_Large_En | ModelScope | x | FP32 | SpTokenizer | |
Deberta_v3_Base_Multi | ModelScope | x | FP32 | SpTokenizer |
模型分类 | 模型名称 | 权重链接 | 数据类型 | 分词器 |
---|---|---|---|---|
T5.1.1 | ChatYuan | ModelScope | FP32 | SpTokenizer |
Flan-T5-small | ModelScope | FP32 | SpTokenizer | |
Flan-T5-base | ModelScope | FP32 | SpTokenizer | |
Flan-T5-large | ModelScope | FP32 | SpTokenizer | |
Flan-T5-xl | ModelScope | FP32 | SpTokenizer | |
MT5-large | ModelScope | FP32 | SpTokenizer | |
UMT5-small | ModelScope | FP32 | SpTokenizer | |
UMT5-base | ModelScope | FP32 | SpTokenizer | |
UMT5-xl | ModelScope | FP32 | SpTokenizer | |
Gemma | Gemma-2b | ModelScope | BF16 | SpTokenizer |
Gemma-2b-Code | ModelScope | BF16 | SpTokenizer | |
Gemma-2b-it | ModelScope | BF16 | SpTokenizer | |
Gemma1.1-2b-it | ModelScope | BF16 | SpTokenizer | |
Gemma-7b | ModelScope | BF16 | SpTokenizer | |
Gemma-7b-Code | ModelScope | BF16 | SpTokenizer | |
Gemma-7b-it | ModelScope | BF16 | SpTokenizer | |
Gemma1.1-7b-it | ModelScope | BF16 | SpTokenizer | |
Gemma-7b-it-Code | ModelScope | BF16 | SpTokenizer | |
Gemma2 | Gemma2-2b | ModelScope | BF16 | AutoTokenizer |
Gemma2-2b-it | ModelScope | BF16 | AutoTokenizer | |
Gemma2-9b | ModelScope | BF16 | AutoTokenizer | |
Gemma2-9b-it | ModelScope | BF16 | AutoTokenizer | |
Gemma2-27b | 百度网盘 | BF16 | AutoTokenizer | |
Gemma2-27b-it | 百度网盘 | BF16 | AutoTokenizer | |
Yi | Yi-6B | ModelScope | BF16 | AutoTokenizer |
Yi-6B-it | ModelScope | BF16 | AutoTokenizer | |
Yi-9B | ModelScope | BF16 | AutoTokenizer | |
Yi-1.5-6B | ModelScope | BF16 | AutoTokenizer | |
Yi-1.5-9B | ModelScope | BF16 | AutoTokenizer | |
Yi-1.5-34B | 百度网盘 | BF16 | AutoTokenizer | |
Yi-1.5-34B-it | 百度网盘 | BF16 | AutoTokenizer | |
Llama | Llama3-8B | ModelScope | BF16 | AutoTokenizer |
Llama3-8B-it | ModelScope | BF16 | AutoTokenizer | |
Llama3.1-8B | ModelScope | BF16 | AutoTokenizer | |
Llama3.1-8B-it | ModelScope | BF16 | AutoTokenizer | |
千问 | Qwen-0.5B | ModelScope | BF16 | AutoTokenizer |
Qwen-0.5B-it | ModelScope | BF16 | AutoTokenizer | |
Qwen-1.8B | ModelScope | BF16 | AutoTokenizer | |
Qwen-1.8B-it | ModelScope | BF16 | AutoTokenizer | |
Qwen-4B | ModelScope | BF16 | AutoTokenizer | |
Qwen-4B-it | ModelScope | BF16 | AutoTokenizer | |
Qwen-7B | ModelScope | BF16 | AutoTokenizer | |
Qwen-7B-it | ModelScope | BF16 | AutoTokenizer | |
Qwen-14B | ModelScope | BF16 | AutoTokenizer | |
千问2 | Qwen2-0.5B | ModelScope | BF16 | AutoTokenizer |
Qwen2-0.5B-it | ModelScope | BF16 | AutoTokenizer | |
Qwen2-1.5B | ModelScope | BF16 | AutoTokenizer | |
Qwen2-1.5B-it | ModelScope | BF16 | AutoTokenizer | |
Qwen2-7B | ModelScope | BF16 | AutoTokenizer | |
Qwen2-7B-it | ModelScope | BF16 | AutoTokenizer | |
千问2.5 | Qwen2.5-0.5B | ModelScope | BF16 | AutoTokenizer |
Qwen2.5-0.5B-it | ModelScope | BF16 | AutoTokenizer | |
Qwen2.5-1.5B | ModelScope | BF16 | AutoTokenizer | |
Qwen2.5-1.5B-it | ModelScope | BF16 | AutoTokenizer | |
Qwen2.5-3B | ModelScope | BF16 | AutoTokenizer | |
Qwen2.5-3B-it | ModelScope | BF16 | AutoTokenizer | |
Qwen2.5-7B | ModelScope | BF16 | AutoTokenizer | |
Qwen2.5-7B-it | ModelScope | BF16 | AutoTokenizer | |
Qwen2.5-14B | ModelScope | BF16 | AutoTokenizer | |
Qwen2.5-14B-it | ModelScope | BF16 | AutoTokenizer | |
RWKV6 | RWKV6-1.6B | ModelScope | BF16 | RWKV_TOKENIZER |
RWKV6-3B | ModelScope | BF16 | RWKV_TOKENIZER | |
RWKV6-7B | ModelScope | BF16 | RWKV_TOKENIZER | |
RWKV6-12B-it | ModelScope | BF16 | RWKV_TOKENIZER | |
RWKV6-14B | ModelScope | BF16 | RWKV_TOKENIZER |
注意事项
- 注1:brightmart版albert的开源时间早于Google版albert,这导致早期brightmart版albert的权重与Google版的不完全一致,换言之两者不能直接相互替换。为了减少代码冗余,bert4keras的0.2.4及后续版本均只支持加载Google版以brightmart版中带Google字眼的权重。如果要加载早期版本的权重,请用0.2.3版本,或者考虑作者转换过的albert_zh。(苏神注)
- 注2:下载下来的ELECTRA权重,如果没有json配置文件的话,参考这里自己改一个(需要加上
type_vocab_size
字段)。(苏神注) - 注3: 模型分类这里会跳转到使用的example
- 注4:SpTokenizer和RWKV_TOKENIZER来自bert4keras3.tokenizers,AutoTokenizer指的是transformers的分词器。用法不同需要注意
- 注5:因为不能转换全部的权重,所以我提供了转化权重的脚本,有需要自己去转。
- 注6:bert4keras3的新增加的模型权重均支持kv-cache生成
- 注7: it模型指的是instruct模型,也就是我们俗话说的chat模型
对bert4keras除优化器部分外的升级,实现对tensorflow,jax,torch的多后端兼容
转换了chatyuan模型权重(基于t5模型)
更新了支持批量运算的t5-cache推理版本,详细使用参考t5-cache的使用example 。里面较为详细地列出了cache模型要如何使用。
除了T5,还增加了bert和
roformer/roformer-v2的cache支持,用法和t5一样,example里只是测试一下与greedy是否一致
增加了对weights.h5的读取支持
增加了lora支持,可以通过设置os.environ["ENABLE_LORA"]='1' 启动lora训练,注意的是除了lora之外的参数全部会被冻结
增加了flash-attention支持,可以通过设置os.environ["FLASH_ATTN"]='1'使用flash-attention
但是需要注意的是,tensorflow不支持。而jax在https://github.com/nshepperd/flash_attn_jax/releases 下载,torch则是 https://github.com/Dao-AILab/flash-attention
重新整理了苏神的代码,更新了对 Gemma,Qwen,和llama系列模型(llama3和Yi)的支持,转换了UMT5,FlanT5的权重,并且提供了转换脚本,大家可以自行转换权重
新版本可以在build_transformer_model时添加penalty,penalty_window ,max_penalty_range,temperature四个参数。
详情可以参考文档。
增加了RWKV6的层及其模型的支持,关于层的详细介绍可以查看文档RWKV_layer.
对于RWKV6更详细的介绍,我们单独创建了一个RWKV6仓库,在这里你可以看到关于本库对RWKV6的详细介绍
增加了llama3.1和gemma2的支持
从keras nlp那里韩了deberta的支持,同时增加了qwen2.5的支持。并且把roformer v2转为了keras3的权重。