lonePatient/NeZha_Chinese_PyTorch

咨询下finetune的模型大小

renjunxiang opened this issue · 6 comments

你好,我运行HuggingFace的run_mlm.py脚本做finetune,nezha的base模型保存pytorch_model.bin有1G多,请问这是什么原因?
我对比运行了华为官方的modeling_nezha,发现你们的区别是在attention层,你的多了relative_positions_encoding,是因为这个吗?

@renjunxiang 不是这个,bin文件应该没这么大,保存的模型目录可能有这么大,应该huggingFace保存了optimizer数据

@renjunxiang 不是这个,bin文件应该没这么大,保存的模型目录可能有这么大,应该huggingFace保存了optimizer数据

我们跑了BERT、RoBERTa和NeZha。
BERT、RoBERTa的pytorch_model.bin都是400M左右,optimizer.pt是800M左右。
NeZha的pytorch_model.bin是1.2G左右,optimizer.pt是800M左右。
另外为了验证模型,分别尝试了你的和官方的,载入后再保存。

# https://github.com/lonePatient/NeZha_Chinese_PyTorch/
from nezha.modeling_nezha import NeZhaModel
BERT = NeZhaModel.from_pretrained('./nezha-cn-base')
torch.save(BERT.state_dict(),'./nezha.pth')

保存后的模型大小未改变,还是1.2G左右。

# 官方https://github.com/huawei-noah/Pretrained-Language-Model
from nezha.modeling_nezha import BertModel
BERT = BertModel.from_pretrained('./nezha-cn-base')
torch.save(BERT.state_dict(),'./nezha.pth')

保存后的模型大小变为400M左右。

不知道是什么原因,两个载入方式推断的结果是一致的,是不是你的保存了更多的?或者你的pytorch_model.bin本身就集成了optimizer?我和其他朋友确认过,pytorch_model.bin都是1G以上。还是说,我们运行HuggingFace的run_mlm.py脚本修改的不对?

from transformers import (
    CONFIG_MAPPING,
    MODEL_FOR_MASKED_LM_MAPPING,
    AutoConfig,
    AutoModelForMaskedLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    set_seed,
)

修改为

from transformers import (
    CONFIG_MAPPING,
    MODEL_FOR_MASKED_LM_MAPPING,
    BertTokenizer,
    DataCollatorForLanguageModeling,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    set_seed,
)
from nezha.configuration_nezha import NeZhaConfig
from nezha.modeling_nezha import NeZhaForMaskedLM

感谢您的解答~

@renjunxiang 非常感谢,我的问题,主要是保存了relative_positions_encoding数据造成的。

@renjunxiang 非常感谢,我的问题,主要是保存了relative_positions_encoding数据造成的。

感谢您的解答!经过验证,RelativePositionsEncoding不用nn.Module的方式即可,保存的模型为400M。

@renjunxiang 非常感谢,我的问题,主要是保存了relative_positions_encoding数据造成的。

感谢您的解答!经过验证,RelativePositionsEncoding不用nn.Module的方式即可,保存的模型为400M。

请问,不用nn.Module,用什么呢?

@liucongg 用函数形式实现,代码已经更新。