Convert to ONNX model
yanqi1811 opened this issue · 10 comments
Your work is very helpful for me, thank you! But when I try to convert this PyTorch model to an ONNX file, I get some errors. Have you tried this conversion? Thanks!
Yes I tried to trace the model before. At the moment it seems like the `timm` module is not 100% compatible.
Will look into it in the future.
Thank you for your reply!
Hi, did you trace the whole model (encoder + decoder)? And what was your problem? Maybe we can have a discussion.
In addition, I have a question and hope you can share some ideas: [x_transformers] provides both an encoder and a decoder, so why did you use the encoder from [timm] and the decoder from [x_transformers]? Is there any special reason?
- The main problem is that the image input size can be dynamic, but that doesn't play well with the tracing/scripting methods; it is not necessarily the fault of the `timm` module. If you have experience in this area, I'd love to hear some tips. Feel free to open a discussion. (A rough sketch of declaring dynamic image axes at export time is included right after this comment.)
- Initially, I was using both encoder and decoder from the `x-transformers` package, but the performance was not very good. I used a pure ViT at the time (6ecc3f4). `timm` offered some pre-built encoders with CNN backbones, which increased the performance significantly.
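For context, here is a rough sketch of how dynamic image axes can be declared with `torch.onnx.export`. The model, file name, and axis names below are illustrative, not the ones used in this repo, and dynamic axes alone do not solve data-dependent logic inside the model, which tracing bakes in.

```python
import torch
import torch.nn as nn

# Illustrative stand-in model; in practice this would be the LaTeX-OCR encoder.
model = nn.Conv2d(1, 8, kernel_size=3, padding=1).eval()

dummy = torch.randn(1, 1, 64, 256)  # [batch, channel, height, width]
torch.onnx.export(
    model,
    dummy,
    "encoder_dynamic.onnx",
    opset_version=16,
    input_names=["input_image"],
    output_names=["features"],
    dynamic_axes={
        # Axis labels are free-form; the dict keys must match input/output names.
        "input_image": {0: "batch", 2: "height", 3: "width"},
        "features": {0: "batch", 2: "height", 3: "width"},
    },
)
```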
It isn't a problem with `timm`; it's this part of the encoder:

```python
pos_emb_ind = repeat(torch.arange(h)*(self.width//self.patch_size-w), 'h -> (h w)', w=w)+torch.arange(h*w)
pos_emb_ind = torch.cat((torch.zeros(1), pos_emb_ind+1), dim=0).long()
```

There are two problems here. First, the `torch.cat` mixes `torch.zeros(1)` (float) with `pos_emb_ind` (long). Second, there is a bug with `arange`, like #1708. Someone said he has fixed it (you can see here) and his branch has been merged, but I still have this problem with the latest version of PyTorch.
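If the dtype mismatch in the `torch.cat` is indeed what breaks the export, one possible fix (a sketch, not verified against the repo) is to build the leading zero as a long tensor instead of casting after the concatenation; the indexing logic stays the same:

```python
import torch
from einops import repeat

def build_pos_emb_ind(h: int, w: int, max_w: int) -> torch.Tensor:
    """Sketch of the positional-embedding index computation with explicit long
    dtypes, so torch.cat never mixes float and long tensors during tracing.
    `max_w` stands in for self.width // self.patch_size in the encoder."""
    pos_emb_ind = repeat(torch.arange(h) * (max_w - w), 'h -> (h w)', w=w) + torch.arange(h * w)
    return torch.cat((torch.zeros(1, dtype=torch.long), pos_emb_ind + 1), dim=0)
```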
Hi, I've traced the model for my deployment exercise.
First of all, the encoder:
```python
import cv2
import torch

# encoder and test_transform come from the existing pix2tex setup
encoder.eval()
img = cv2.imread('path_to_my_example_image')  # image size is 464 x 112
dummy_img = test_transform(image=img)['image'][:1].unsqueeze(0)  # shape is now [1, 1, 112, 464]

with torch.no_grad():
    torch.onnx.export(
        encoder,
        dummy_img,
        f="encoder.onnx",
        opset_version=16,
        input_names=['input_image'],
        output_names=['output_context'],
        dynamic_axes={
            'context': {0: 'batch_size', 1: 'channel', 2: 'height', 3: 'width'},
            'output': {0: 'batch_size', 1: 'output_context'}
        },
        export_params=True,
        verbose=True
    )
```
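A quick sanity check of the exported file could compare the ONNX output against the PyTorch encoder. This is only a sketch; it assumes `onnxruntime` is installed and reuses `encoder` and `dummy_img` from the snippet above.

```python
import numpy as np
import onnxruntime as ort

# Run the exported encoder in onnxruntime and compare against PyTorch.
sess = ort.InferenceSession("encoder.onnx", providers=["CPUExecutionProvider"])
ort_out = sess.run(None, {"input_image": dummy_img.numpy()})[0]
with torch.no_grad():
    torch_out = encoder(dummy_img).numpy()
print("max abs diff:", np.abs(ort_out - torch_out).max())
```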
Everything is ok for the encoder. But when it comes to the decoder, here's the code:
```python
decoder.eval()
dummy_context = encoder(dummy_img)
dummy_tgt_seq = torch.rand(1, 512)
dummy_input = {
    'context': dummy_context,
    'tgt_seq': dummy_tgt_seq
}

with torch.no_grad():
    torch.onnx.export(
        decoder,
        args=(
            dummy_input["tgt_seq"],
            dummy_input["context"]
        ),
        f="decoder.onnx",
        opset_version=16,
        input_names=['input_seq', 'input_context'],
        output_names=['output_seq'],
        dynamic_axes={
            'input_context': {0: 'batch_size', 1: 'sequence', 2: 'encoded_context'},
            'output_seq': {0: 'batch_size', 1: 'output_seq'}
        },
        # export_params=True,
        verbose=True
    )
```
I get an error like this:
```
Traceback (most recent call last):
  File "d:\service-ml-api-server\flask_app\Im2Tex\pix2tex\convert_onnxmodel.py", line 105, in <module>
    torch.onnx.export(
  File "C:\Users\Admins\AppData\Local\Programs\Python\Python38\lib\site-packages\torch-1.12.1-py3.8-win-amd64.egg\torch\onnx\__init__.py", line 350, in export
    return utils.export(
  File "C:\Users\Admins\AppData\Local\Programs\Python\Python38\lib\site-packages\torch-1.12.1-py3.8-win-amd64.egg\torch\onnx\utils.py", line 163, in export
    _export(
  File "C:\Users\Admins\AppData\Local\Programs\Python\Python38\lib\site-packages\torch-1.12.1-py3.8-win-amd64.egg\torch\onnx\utils.py", line 1074, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "C:\Users\Admins\AppData\Local\Programs\Python\Python38\lib\site-packages\torch-1.12.1-py3.8-win-amd64.egg\torch\onnx\utils.py", line 727, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "C:\Users\Admins\AppData\Local\Programs\Python\Python38\lib\site-packages\torch-1.12.1-py3.8-win-amd64.egg\torch\onnx\utils.py", line 602, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "C:\Users\Admins\AppData\Local\Programs\Python\Python38\lib\site-packages\torch-1.12.1-py3.8-win-amd64.egg\torch\onnx\utils.py", line 517, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
  File "C:\Users\Admins\AppData\Local\Programs\Python\Python38\lib\site-packages\torch-1.12.1-py3.8-win-amd64.egg\torch\jit\_trace.py", line 1175, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "C:\Users\Admins\AppData\Local\Programs\Python\Python38\lib\site-packages\torch-1.12.1-py3.8-win-amd64.egg\torch\nn\modules\module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Admins\AppData\Local\Programs\Python\Python38\lib\site-packages\torch-1.12.1-py3.8-win-amd64.egg\torch\jit\_trace.py", line 127, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "C:\Users\Admins\AppData\Local\Programs\Python\Python38\lib\site-packages\torch-1.12.1-py3.8-win-amd64.egg\torch\jit\_trace.py", line 118, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "C:\Users\Admins\AppData\Local\Programs\Python\Python38\lib\site-packages\torch-1.12.1-py3.8-win-amd64.egg\torch\nn\modules\module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Admins\AppData\Local\Programs\Python\Python38\lib\site-packages\torch-1.12.1-py3.8-win-amd64.egg\torch\nn\modules\module.py", line 1118, in _slow_forward
    result = self.forward(*input, **kwargs)
TypeError: forward() takes 2 positional arguments but 3 were given
```
Did I do something wrong, or do I have to do this another way? Thanks.
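For anyone hitting the same `TypeError`: it suggests the decoder's `forward` only accepts one positional tensor and expects the encoder output as a keyword argument. One possible workaround (a sketch, not verified against this repo) is to export a thin wrapper with a purely positional signature:

```python
import torch
import torch.nn as nn

class DecoderExportWrapper(nn.Module):
    """Hypothetical wrapper exposing a positional (tgt_seq, context) signature
    for tracing, assuming the wrapped decoder takes the encoder output as a
    keyword argument, e.g. forward(x, context=...)."""
    def __init__(self, decoder: nn.Module):
        super().__init__()
        self.decoder = decoder

    def forward(self, tgt_seq: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        return self.decoder(tgt_seq, context=context)

# The wrapper, not the raw decoder, would then be passed to torch.onnx.export
# with the same args tuple as above.
```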
I can't convert to an ONNX model either. Has this issue been solved?
I have converted `image_resizer.pth` and `weights.pth` to ONNX format successfully, and I am organizing the inference code. Please keep an eye on the RapidLatexOCR repo.
Hi all,
I have successfully converted the LaTeX-OCR model to ONNX.
The model's encoder and decoder are converted separately.
Details about the code are here: Code
If it's useful for your work, please ⭐ my repo.
@tranngocduvnvp Coincidentally, I also wrote conversion code a while ago, but there is currently an issue where dynamic dimensions cannot be inferred.
My code repo is ConvertLaTeXOCRToONNX.
You're welcome to get in touch and discuss.