lukas-blecher/LaTeX-OCR

Convert to onnx model

yanqi1811 opened this issue · 10 comments

Your work is very helpful to me, thank you! But when I try to convert this PyTorch model to an ONNX file, I run into some errors. Have you tried this conversion? Thanks!

Yes, I tried to trace the model before. At the moment it seems like the timm module is not 100% compatible.
Will look into it in the future.

Thank you for your reply!

Hi, did you trace the whole model (encoder + decoder)? What problem did you run into? Maybe we can discuss it.
In addition, I have a question and hope you can share your thoughts. [x_transformers] provides both an encoder and a decoder, so why did you use the encoder from [timm] and the decoder from [x_transformers]? Is there a special reason?

  1. The main problem is that the image input size can be dynamic, which doesn't play well with the tracing/scripting methods. It is not necessarily the fault of the timm module. If you have experience in this area, I'd love to hear some tips. Feel free to open a discussion.
  2. Initially, I was using both the encoder and the decoder from the x-transformers package, but the performance was not very good. I used a pure ViT at the time (6ecc3f4). timm offered some pre-built encoders with CNN backbones, which increased the performance significantly.

It isn't a problem with timm. It's in the encoder part:

        pos_emb_ind = repeat(torch.arange(h)*(self.width//self.patch_size-w), 'h -> (h w)', w=w)+torch.arange(h*w)
        pos_emb_ind = torch.cat((torch.zeros(1), pos_emb_ind+1), dim=0).long()

There are two problems here. First, torch.cat is concatenating torch.zeros(1) (float) with pos_emb_ind (long). Second, there is a bug in arange, like the one in #1708. Someone said he has fixed it (you can see here), and his branch has been merged, but I still have this problem with the latest version of PyTorch.
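To make the index arithmetic in those two lines easier to follow, here is a minimal plain-Python sketch of what they compute. It is not the project's code: `pos_emb_indices` and `full_width` are hypothetical names, with `full_width` standing in for `self.width//self.patch_size`. Each patch at row `r`, column `c` ends up at index `r*full_width + c + 1`, and index 0 is prepended. In the tensor version, one way to avoid the float/long mix would be to create the leading zero with `torch.zeros(1, dtype=torch.long)`; whether that alone is enough for the export is an assumption.

```python
def pos_emb_indices(h, w, full_width):
    """Indices into a flattened position table for an h x w patch grid.

    full_width is the number of patch columns of the largest supported image
    (self.width // self.patch_size in the snippet above).
    """
    # Mirrors repeat(arange(h) * (full_width - w), 'h -> (h w)', w=w)
    row_offsets = [r * (full_width - w) for r in range(h) for _ in range(w)]
    # Mirrors "+ torch.arange(h * w)"
    flat = [offset + k for offset, k in zip(row_offsets, range(h * w))]
    # Prepend index 0 and shift the patch indices by 1.
    # Everything stays an integer, so there is no float/long mix.
    return [0] + [i + 1 for i in flat]


print(pos_emb_indices(2, 3, 5))  # prints [0, 1, 2, 3, 6, 7, 8]
```

For a 2 x 3 grid inside a table that is 5 patches wide, the second row starts at 6 rather than 4, because two unused columns are skipped at the end of the first row.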

Hi, I've traced the model for my deployment exercise.
First, the encoder:

import cv2
import torch

# encoder and test_transform are assumed to be loaded from the project already
encoder.eval()
img = cv2.imread('path_to_my_example_image')  # image size is 464 x 112
dummy_img = test_transform(image=img)['image'][:1].unsqueeze(0)  # shape is now [1, 1, 112, 464]

with torch.no_grad():
    torch.onnx.export(
        encoder,
        dummy_img,
        f="encoder.onnx",
        opset_version=16,
        input_names=['input_image'],
        output_names=['output_context'],
        # keys must match the names in input_names / output_names
        dynamic_axes={
            'input_image': {0: 'batch_size', 1: 'channel', 2: 'height', 3: 'width'},
            'output_context': {0: 'batch_size', 1: 'output_context'}
        },
        export_params=True,
        verbose=True
    )

Everything is ok for the encoder. But when it comes to the decoder, here's the code:

decoder.eval()

dummy_context = encoder(dummy_img)
dummy_tgt_seq = torch.rand(1, 512)

dummy_input = {
    'context': dummy_context,
    'tgt_seq': dummy_tgt_seq
}

with torch.no_grad():
    torch.onnx.export(
        decoder,
        args = (
            dummy_input["tgt_seq"],
            dummy_input["context"]
        ),
        f = "decoder.onnx",
        opset_version=16,
        input_names=['input_seq', 'input_context'],
        output_names=['output_seq'],
        dynamic_axes={
            'input_context': {0: 'batch_size', 1: 'sequence', 2: 'encoded_context'},
            'output_seq': {0: 'batch_size', 1: 'output_seq'}
            },
        # export_params=True,
        verbose=True
    )

I get an error like this:

Traceback (most recent call last):
  File "d:\service-ml-api-server\flask_app\Im2Tex\pix2tex\convert_onnxmodel.py", line 105, in <module>
    torch.onnx.export(
  File "C:\Users\Admins\AppData\Local\Programs\Python\Python38\lib\site-packages\torch-1.12.1-py3.8-win-amd64.egg\torch\onnx\__init__.py", line 350, in export
    return utils.export(
  File "C:\Users\Admins\AppData\Local\Programs\Python\Python38\lib\site-packages\torch-1.12.1-py3.8-win-amd64.egg\torch\onnx\utils.py", line 163, in export
    _export(
  File "C:\Users\Admins\AppData\Local\Programs\Python\Python38\lib\site-packages\torch-1.12.1-py3.8-win-amd64.egg\torch\onnx\utils.py", line 1074, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "C:\Users\Admins\AppData\Local\Programs\Python\Python38\lib\site-packages\torch-1.12.1-py3.8-win-amd64.egg\torch\onnx\utils.py", line 727, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "C:\Users\Admins\AppData\Local\Programs\Python\Python38\lib\site-packages\torch-1.12.1-py3.8-win-amd64.egg\torch\onnx\utils.py", line 602, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "C:\Users\Admins\AppData\Local\Programs\Python\Python38\lib\site-packages\torch-1.12.1-py3.8-win-amd64.egg\torch\onnx\utils.py", line 517, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
  File "C:\Users\Admins\AppData\Local\Programs\Python\Python38\lib\site-packages\torch-1.12.1-py3.8-win-amd64.egg\torch\jit\_trace.py", line 1175, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "C:\Users\Admins\AppData\Local\Programs\Python\Python38\lib\site-packages\torch-1.12.1-py3.8-win-amd64.egg\torch\nn\modules\module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Admins\AppData\Local\Programs\Python\Python38\lib\site-packages\torch-1.12.1-py3.8-win-amd64.egg\torch\jit\_trace.py", line 127, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "C:\Users\Admins\AppData\Local\Programs\Python\Python38\lib\site-packages\torch-1.12.1-py3.8-win-amd64.egg\torch\jit\_trace.py", line 118, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "C:\Users\Admins\AppData\Local\Programs\Python\Python38\lib\site-packages\torch-1.12.1-py3.8-win-amd64.egg\torch\nn\modules\module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Admins\AppData\Local\Programs\Python\Python38\lib\site-packages\torch-1.12.1-py3.8-win-amd64.egg\torch\nn\modules\module.py", line 1118, in _slow_forward
    result = self.forward(*input, **kwargs)

TypeError: forward() takes 2 positional arguments but 3 were given

Is there something wrong with my approach, or do I have to do this another way? Thanks.
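The TypeError above is what Python raises when a `forward(self, x, **kwargs)` method receives two positional inputs: `self` plus the two tensors makes three. The plain-Python sketch below reproduces that, under the assumption that the decoder's `forward` is declared this way and expects the context as a keyword argument; this has not been verified against the project's decoder. `ToyDecoder` and `PositionalAdapter` are hypothetical stand-ins, not classes from pix2tex or x_transformers.

```python
class ToyDecoder:
    """Stand-in for a decoder whose forward takes one positional input plus keywords."""

    def forward(self, x, **kwargs):
        context = kwargs.get("context")
        return (x, context)

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


class PositionalAdapter:
    """Hypothetical wrapper that exposes (tgt_seq, context) as two positional inputs."""

    def __init__(self, decoder):
        self.decoder = decoder

    def forward(self, tgt_seq, context):
        # Pass the context by keyword, which is what the inner forward accepts.
        return self.decoder(tgt_seq, context=context)

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


decoder = ToyDecoder()

try:
    decoder("tgt_seq", "context")  # two positional inputs, like args=(tgt_seq, context)
    error_message = None
except TypeError as exc:
    error_message = str(exc)

print(error_message)

adapted = PositionalAdapter(decoder)
print(adapted("tgt_seq", "context"))  # prints ('tgt_seq', 'context')
```

With torch, the same idea would be a small `nn.Module` subclass wrapping the decoder and passed to `torch.onnx.export` in its place. Whether the export then succeeds end to end is a separate question that this sketch does not answer.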

I can't convert to an ONNX model either. Has this issue been solved?

SWHL commented

I have successfully converted image_resizer.pth and weights.pth to ONNX format and am now organizing the inference code. Please keep an eye on the RapidLatexOCR repo.

Hi all,
I have now successfully converted the LaTeX-OCR model to ONNX.
The model's encoder and decoder are converted separately.
See Code for the details.
If it's useful for your work, please ⭐ my repo.

SWHL commented

@tranngocduvnvp Coincidentally, I also put together some conversion code earlier, but it currently has an issue where dynamic dimensions cannot be inferred.

My code repo is ConvertLaTeXOCRToONNX.

Feel free to get in touch and discuss.
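For anyone debugging dynamic dimensions in exports like the ones in this thread, one cheap thing to rule out is a malformed `dynamic_axes` argument. As far as the `torch.onnx.export` documentation describes it, each key should be a name listed in `input_names` or `output_names`, and each axis should be given as an integer index. The helper below is a hypothetical sanity check in plain Python, not part of PyTorch, and the example values are made up to trigger both checks.

```python
def check_dynamic_axes(input_names, output_names, dynamic_axes):
    """Return a list of human-readable problems found in a dynamic_axes mapping."""
    known = set(input_names) | set(output_names)
    problems = []
    for name, axes in dynamic_axes.items():
        if name not in known:
            problems.append(f"{name!r} is not in input_names/output_names")
        # axes may be a dict {axis_index: axis_name} or a list of axis indices
        indices = axes.keys() if isinstance(axes, dict) else axes
        for idx in indices:
            if not isinstance(idx, int):
                problems.append(f"axis {idx!r} of {name!r} is not an integer")
    return problems


# Made-up example values, chosen only to trigger both checks.
problems = check_dynamic_axes(
    input_names=["input_image"],
    output_names=["output_context"],
    dynamic_axes={
        "image": {0: "batch_size", "2": "height"},
        "output_context": {0: "batch_size"},
    },
)
for problem in problems:
    print(problem)
```

An empty list only means the mapping is well formed. It says nothing about whether the exported graph actually keeps those dimensions dynamic.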