Problem with convert_deploy for a QAT-trained model
wqn628 opened this issue · 2 comments
wqn628 commented
Here is my model definition:
import torch


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # layer definitions omitted in the original post

    def forward(self, input, input_length):
        x = self.input_layer(input)
        x = self.tdnn(x)
        chain_out = self.chain(x)  # (B, F, T', 1), T' = (T - 8) // 3
        chain_out = chain_out.transpose(1, 2).squeeze(3)  # (B, T', F)
        out_lengths = self.output_lengths(input_length)
        return chain_out, out_lengths

    def inference(self, input):
        x = self.input_layer(input)
        x = self.tdnn(x)
        chain_out = self.chain(x)
        return chain_out  # (B, F, T', 1), T' = (T - 8) // 3
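A minimal sketch of the output-length arithmetic, assuming T' = (T - 8) // 3 (the relation written in the forward comment) is the intended one. The body of output_lengths is not shown in the issue, so output_length below is a hypothetical stand-in. The parentheses matter: // binds tighter than -, so T - 8 // 3 computes something else.

```python
def output_length(t: int, context: int = 8, stride: int = 3) -> int:
    """Hypothetical helper: frames left after the TDNN/chain stack.

    Assumes T' = (T - context) // stride, matching the shape comment in forward().
    """
    if t < context:
        raise ValueError(f"need at least {context} input frames, got {t}")
    return (t - context) // stride


# The input_shape used for export has T = 29 frames.
print(output_length(29))  # 7
# Without the parentheses, // is applied first: 29 - (8 // 3) = 27.
print(29 - 8 // 3)  # 27
```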
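Training uses the forward function, but at deployment time I want to run inference through the inference function. So I made the following changes to the training code: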
from mqbench.prepare_by_platform import prepare_by_platform, BackendType
from mqbench.utils.state import enable_calibration
from mqbench.convert_deploy import convert_deploy

backend = BackendType.SNPE
model = prepare_by_platform(model, backend, prepare_custom_config_dict={"preserve_attr": {"": ["inference"]}})
enable_calibration(model)
model = model.to(device)
for epoch in range(10):
    train_one_epoch()
    model_name = 'mqbench_qmodel_{}'.format(epoch)
    # input_shape = {'input': [1, 1, 29, 80], "length": [1]}
    input_shape = {'input': [1, 1, 29, 80]}
    # model.forward = model.inference
    convert_deploy(model.eval(), backend, input_shape, output_path=model_dir, model_name=model_name)
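The commented-out line model.forward = model.inference swaps the method permanently. If the swap is only wanted around an export call, a context manager can restore the original even when the export raises. This is a minimal plain-Python sketch: swapped_forward is a hypothetical helper, not part of MQBench, and it only changes which Python method model.forward resolves to. It does not change model.graph on an fx-traced GraphModule.

```python
import contextlib

_MISSING = object()


@contextlib.contextmanager
def swapped_forward(model, method_name="inference"):
    """Temporarily route model.forward to another method, then undo it."""
    # An instance-level forward override, if one was set earlier.
    previous = model.__dict__.get("forward", _MISSING)
    model.forward = getattr(model, method_name)
    try:
        yield model
    finally:
        if previous is _MISSING:
            # Drop the instance attribute so lookup falls back to the class's forward.
            del model.forward
        else:
            model.forward = previous


class Net:
    """Toy stand-in for the real nn.Module."""

    def forward(self, x, length):
        return ("forward", x, length)

    def inference(self, x):
        return ("inference", x)


model = Net()
with swapped_forward(model):
    print(model.forward(1))  # ('inference', 1)
print(model.forward(1, 29))  # ('forward', 1, 29)
```

Deleting the instance attribute afterwards, rather than reassigning a saved bound method, leaves no method bound to the model sitting in its instance __dict__.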
However, when execution reaches convert_deploy(model.eval(), backend, input_shape, output_path=model_dir, model_name=model_name), the following error is raised:
  File "xpspeech/cloud/bin/train_mmi-ctdnn.mq.py", line 334, in main
    convert_deploy(model.eval(), backend, input_shape, output_path=model_dir, model_name=model_name)
  File "/dataset/workspace/wangqingnan/asr_tool/cfm-mmi/mqbench/convert_deploy.py", line 192, in convert_deploy
    deploy_model = deepcopy_graphmodule(model)
  File "/dataset/workspace/wangqingnan/asr_tool/cfm-mmi/mqbench/utils/utils.py", line 73, in deepcopy_graphmodule
    copied_gm = copy.deepcopy(gm)
  File "/dataset/workspace/miniconda3/envs/cfm-mq/lib/python3.8/copy.py", line 153, in deepcopy
    y = copier(memo)
  File "/dataset/workspace/miniconda3/envs/cfm-mq/lib/python3.8/site-packages/torch/quantization/fx/graph_module.py", line 20, in __deepcopy__
    fake_mod.__dict__ = copy.deepcopy(self.__dict__)
  File "/dataset/workspace/miniconda3/envs/cfm-mq/lib/python3.8/copy.py", line 270, in _reconstruct
    state = deepcopy(state, memo)
  File "/dataset/workspace/miniconda3/envs/cfm-mq/lib/python3.8/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/dataset/workspace/miniconda3/envs/cfm-mq/lib/python3.8/copy.py", line 230, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/dataset/workspace/miniconda3/envs/cfm-mq/lib/python3.8/copy.py", line 161, in deepcopy
    rv = reductor(4)
  File "/dataset/workspace/miniconda3/envs/cfm-mq/lib/python3.8/site-packages/torch/utils/hooks.py", line 26, in __getstate__
    return (self.hooks_dict_ref(), self.id)
RecursionError: maximum recursion depth exceeded while calling a Python object
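One plausible reading of this traceback, not confirmed in the thread: the frame at graph_module.py line 20 calls copy.deepcopy(self.__dict__) without passing memo. If that __dict__ holds a method bound to the module itself (for example, if preserve_attr stores inference on the prepared module as a bound method, which is not verified here), deep-copying the bound method deep-copies its __self__, which re-enters __deepcopy__ with a fresh memo, and the cycle never ends. The stdlib-only sketch below reproduces that mechanism with hypothetical toy classes; it does not use torch or MQBench.

```python
import copy
import types


class NoMemoModule:
    """Toy class whose __deepcopy__ ignores memo, like the frame in the traceback."""

    def inference(self):
        return "inference"

    def __deepcopy__(self, memo):
        clone = object.__new__(type(self))
        # No memo is passed, so every nested deepcopy starts from scratch.
        clone.__dict__ = copy.deepcopy(self.__dict__)
        return clone


class MemoModule(NoMemoModule):
    """Same class, but registers the clone in memo before copying __dict__."""

    def __deepcopy__(self, memo):
        clone = object.__new__(type(self))
        memo[id(self)] = clone  # a later reference back to self reuses clone
        clone.__dict__ = copy.deepcopy(self.__dict__, memo)
        return clone


def with_bound_method(cls):
    obj = cls()
    # Store a method bound to obj in obj.__dict__, creating a reference cycle.
    obj.inference = types.MethodType(cls.inference, obj)
    return obj


def try_deepcopy(obj):
    try:
        copy.deepcopy(obj)
        return "ok"
    except RecursionError:
        return "RecursionError"


print(try_deepcopy(with_bound_method(NoMemoModule)))  # RecursionError
print(try_deepcopy(with_bound_method(MemoModule)))    # ok
```

If this is what happens here, keeping inference out of the prepared model's instance __dict__ before calling convert_deploy would be the first thing to try; that is untested against MQBench.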
convert_deploy is defined as follows:
def convert_deploy(model: GraphModule, backend_type: BackendType,
                   input_shape_dict=None, dummy_input=None, output_path='./',
                   model_name='mqbench_qmodel', deploy_to_qlinear=False, **extra_kwargs):
    r"""Convert model to onnx model and quantization params depends on backend.

    Args:
        model (GraphModule): GraphModule prepared qat module.
        backend_type (BackendType): specific which backend should be converted to.
        input_shape_dict (dict): keys are model input name(should be forward function
                                 params name, values are list of tensor dims)
        output_path (str, optional): path to save convert results. Defaults to './'.
        model_name (str, optional): name of converted onnx model. Defaults to 'mqbench_qmodel'.

    >>> note on input_shape_dict:
        example: {'input_0': [1, 3, 224, 224],
                  'input_1': [1, 3, 112, 112]
                 }
        while **_forward function_** signature is like:
                def forward(self, input_0, input_1):
                    pass
    """
    kwargs = {
        'input_shape_dict': input_shape_dict,
        'dummy_input': dummy_input,
        'output_path': output_path,
        'model_name': model_name,
        'onnx_model_path': osp.join(output_path, '{}.onnx'.format(model_name)),
        'deploy_to_qlinear': deploy_to_qlinear
    }
    kwargs.update(extra_kwargs)
    deploy_model = deepcopy_graphmodule(model)
    for convert_function in BACKEND_DEPLOY_FUNCTION[backend_type]:
        convert_function(deploy_model, **kwargs)
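Because the docstring asks for input_shape_dict keys to match the parameter names of forward, a quick pre-flight check can catch a mismatch before export. The helper below is hypothetical (not an MQBench API) and relies only on inspect.signature; for the Net.forward(self, input, input_length) from the question it reports that input_length has no shape entry when only {'input': [1, 1, 29, 80]} is passed.

```python
import inspect
from typing import Callable, Dict, List


def missing_shape_keys(forward: Callable, input_shape_dict: Dict[str, List[int]]) -> List[str]:
    """Return required positional parameters of `forward` absent from input_shape_dict."""
    positional = (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
    required = [
        name
        for name, param in inspect.signature(forward).parameters.items()
        if name != "self"
        and param.kind in positional
        and param.default is inspect.Parameter.empty
    ]
    return [name for name in required if name not in input_shape_dict]


class Net:
    """Signatures copied from the question; bodies omitted."""

    def forward(self, input, input_length):
        ...

    def inference(self, input):
        ...


shapes = {"input": [1, 1, 29, 80]}
print(missing_shape_keys(Net.forward, shapes))    # ['input_length']
print(missing_shape_keys(Net.inference, shapes))  # []
```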
I'd like to ask: if deployment inference calls some function other than forward, how should I handle this?
Thanks a lot, and sorry for the trouble.
www516717402 commented
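- Custom operation (outside of QAT):
    model = Net()
    src_forward_func = model.forward
    model.forward = model.inference
    # convert operation
    # ... run the conversion here ...
    # restore the previous function
    model.forward = src_forward_func
- QAT
  - In QAT the approach above does not work, because fx has already converted the model Python-to-Python (py2py). You can only modify the model produced by the fx conversion, which is quite troublesome.
  - I suggest merging the parts that are identical between training and inference and converting them with fx, then expressing the parts that differ afterwards as two branches.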
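The second suggestion can be sketched as follows: keep only the part shared by training and deployment (input_layer, tdnn, chain) in the module that gets traced and quantized, and do the training-only steps (the transpose/squeeze and the length computation) outside it, so deployment calls the same forward. This is a framework-free sketch of the structure only; Trunk, train_outputs and deploy_outputs are hypothetical names, plain callables stand in for the layers, and the length rule assumes T' = (T - 8) // 3. In real code Trunk would be the torch.nn.Module passed to prepare_by_platform.

```python
from typing import Callable, List, Sequence, Tuple


class Trunk:
    """Shared part of the network: the only piece that would be fx-traced and quantized."""

    def __init__(self, *layers: Callable):
        self.layers = layers

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


def output_lengths(input_lengths: Sequence[int]) -> List[int]:
    # Training-only bookkeeping, assumed to follow T' = (T - 8) // 3.
    return [(t - 8) // 3 for t in input_lengths]


def train_outputs(trunk: Trunk, x, input_lengths: Sequence[int], post: Callable) -> Tuple[object, List[int]]:
    # Branch 1 (training): shared trunk, then post-processing and lengths outside the module.
    return post(trunk.forward(x)), output_lengths(input_lengths)


def deploy_outputs(trunk: Trunk, x):
    # Branch 2 (deployment): the shared trunk alone, so its forward is what gets exported.
    return trunk.forward(x)


# Toy layers acting on numbers, standing in for input_layer, tdnn and chain.
trunk = Trunk(lambda v: v + 1, lambda v: v * 2, lambda v: v - 3)
print(deploy_outputs(trunk, 5))                          # 9
print(train_outputs(trunk, 5, [29], post=lambda v: -v))  # (-9, [7])
```

With this split there is no second method to preserve or swap: the exported forward is exactly the quantized trunk.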
github-actions commented
This issue has not received any updates in 120 days. Please reply to this issue if this still unresolved!