DayBreak-u/centernet_mobilenetv2_ncnn

Problem converting the model to ONNX

lqian opened this issue · 0 comments


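Hello, and thank you very much for sharing this project. It has been very helpful for learning how to deploy CenterNet. Under PyTorch 1.2.0, I trained a model with the mobilenetv2_10 backbone and it ran successfully. When I then tried to convert the model to ONNX format, I got an error. I'm new to PyTorch and would appreciate any guidance!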
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
import torch.onnx as torch_onnx

import _init_paths

from detectors.detector_factory import detector_factory
from models.model import create_model, load_model
from opts import opts

heads = {'hm': 2, 'wh': 2, 'hps': 8, 'hm_hp': 4}
model = create_model('mobilenetv2_10', heads, 64)
model = load_model(model, '../model_best.pth')
torch.save(model, 'mobilnetv2_10.pth')

print(model)

input_shape = (3, 512, 384)
dump_input = Variable(torch.randn(1, *input_shape))
output = torch_onnx.export(model, dump_input, 'mobilnetv2_10.onnx', verbose=False)
print('Export of torch model completed!')
```

Error message:

```
Traceback (most recent call last):
  File "/train-data/CenterNet/src/tools/export_onnx.py", line 24, in <module>
    output = torch_onnx.export(model, dump_input, 'model.onnx', verbose=False)
  File "/home/link/.conda/envs/CenterNet/lib/python3.6/site-packages/torch/onnx/__init__.py", line 132, in export
    strip_doc_string, dynamic_axes)
  File "/home/link/.conda/envs/CenterNet/lib/python3.6/site-packages/torch/onnx/utils.py", line 64, in export
    example_outputs=example_outputs, strip_doc_string=strip_doc_string, dynamic_axes=dynamic_axes)
  File "/home/link/.conda/envs/CenterNet/lib/python3.6/site-packages/torch/onnx/utils.py", line 329, in _export
    _retain_param_name, do_constant_folding)
  File "/home/link/.conda/envs/CenterNet/lib/python3.6/site-packages/torch/onnx/utils.py", line 213, in _model_to_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args, training)
  File "/home/link/.conda/envs/CenterNet/lib/python3.6/site-packages/torch/onnx/utils.py", line 171, in _trace_and_get_graph_from_model
    trace, torch_out = torch.jit.get_trace_graph(model, args, _force_outplace=True)
  File "/home/link/.conda/envs/CenterNet/lib/python3.6/site-packages/torch/jit/__init__.py", line 256, in get_trace_graph
    return LegacyTracedModule(f, _force_outplace, return_inputs)(*args, **kwargs)
  File "/home/link/.conda/envs/CenterNet/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/link/.conda/envs/CenterNet/lib/python3.6/site-packages/torch/jit/__init__.py", line 324, in forward
    out_vars, _ = _flatten(out)
RuntimeError: Only tuples, lists and Variables supported as JIT inputs, but got dict
```
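The last line says the tracer used by `torch.onnx.export` got a `dict` where it only accepts tuples, lists and tensors. A common workaround is to wrap the model so that `forward()` returns a flat tuple of tensors in a fixed head order. The sketch below assumes, as in the upstream CenterNet code, that this model's `forward()` returns a list holding one dict of head outputs keyed by the names in `heads`. This has not been verified for this repository's `mobilenetv2_10` model. `DictOutputToTuple`, `ToyCenterNet` and `export_onnx` are hypothetical names, and `ToyCenterNet` is only a tiny stand-in that mimics the `[dict]` output format. It is not the real network.

```python
import torch
import torch.nn as nn


class DictOutputToTuple(nn.Module):
    """Wrap a model whose forward() returns a dict (or a list of dicts) of
    head outputs so that tracing and ONNX export see a flat tuple of tensors."""

    def __init__(self, model, head_names):
        super().__init__()
        self.model = model
        # The order of head_names fixes the order of the exported outputs.
        self.head_names = list(head_names)

    def forward(self, x):
        out = self.model(x)
        # Assumption: CenterNet-style models return [dict], one dict per stack;
        # keep only the last stack's outputs.
        if isinstance(out, (list, tuple)):
            out = out[-1]
        return tuple(out[name] for name in self.head_names)


class ToyCenterNet(nn.Module):
    """Hypothetical stand-in that only mimics the [dict] output format."""

    def __init__(self, heads):
        super().__init__()
        self.backbone = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.heads = nn.ModuleDict(
            {name: nn.Conv2d(8, ch, kernel_size=1) for name, ch in heads.items()})

    def forward(self, x):
        feat = torch.relu(self.backbone(x))
        return [{name: head(feat) for name, head in self.heads.items()}]


def export_onnx(model, heads, path, input_shape=(3, 512, 384)):
    """Hypothetical helper: export `model` with one named ONNX output per head."""
    # eval() also switches the wrapped model's BatchNorm/Dropout to inference mode.
    wrapper = DictOutputToTuple(model, heads.keys()).eval()
    dummy = torch.randn(1, *input_shape)
    torch.onnx.export(wrapper, dummy, path,
                      input_names=['input'],
                      output_names=list(heads.keys()))
```

If the assumption about the output format holds, the script above would call `export_onnx(model, heads, 'mobilnetv2_10.onnx')` in place of its final `torch_onnx.export(...)` line. Naming each ONNX output after its head keeps the outputs identifiable after conversion (for example to ncnn), because the dict keys themselves are lost in the tuple.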