onnx/onnx-tensorflow

PyTorch -> ONNX -> TF

xkasberg opened this issue · 3 comments

Hello, I am converting a boilerplate text classification model from PyTorch to ONNX to TensorFlow on Google Colab.

The conversion from PyTorch to ONNX works fine, but I get this traceback when converting to TensorFlow:

import os
import onnx
import torch
from onnx_tf.backend import prepare

model = torch.load('text_classification.pt')
onnx_model_path = 'text_classification.onnx'

torch.onnx.export(
    model, 
    args=(sample_input, torch.tensor([0])), 
    f=onnx_model_path,      
    opset_version=12,      
    input_names=['input1', 'input2'],  
    output_names=['output'],
    dynamic_axes={'input1': {0: 'batch'}},
)

model = onnx.load("text_classification.onnx")
onnx.checker.check_model(model)
print(onnx.helper.printable_graph(model.graph))

tf_rep = prepare(model)
tf_rep.export_graph('tf_model')

Traceback:

  File "/opt/homebrew/Cellar/python@3.10/3.10.8/Frameworks/Python.framework/Versions/3.10/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/opt/homebrew/Cellar/python@3.10/3.10.8/Frameworks/Python.framework/Versions/3.10/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/Users/kristian/origami/model_conversion/convert_model.py", line 11, in <module>
    tf_rep.export_graph("tf_text_classification")
  File "/Users/kristian/origami/model_conversion/onnx-tensorflow/onnx_tf/backend_rep.py", line 143, in export_graph
    signatures=self.tf_module.__call__.get_concrete_function(
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/eager/def_function.py", line 1239, in get_concrete_function
    concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/eager/def_function.py", line 1219, in _get_concrete_function_garbage_collected
    self._initialize(args, kwargs, add_initializers_to=initializers)
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/eager/def_function.py", line 785, in _initialize
    self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/eager/function.py", line 2523, in _get_concrete_function_internal_garbage_collected
    graph_function, _ = self._maybe_define_function(args, kwargs)
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/eager/function.py", line 2760, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/eager/function.py", line 2670, in _create_graph_function
    func_graph_module.func_graph_from_py_func(
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/framework/func_graph.py", line 1247, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/eager/def_function.py", line 677, in wrapped_fn
    out = weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/eager/function.py", line 3317, in bound_method_wrapper
    return wrapped_fn(*args, **kwargs)
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/framework/func_graph.py", line 1233, in autograph_handler
    raise e.ag_error_metadata.to_exception(e)
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/framework/func_graph.py", line 1222, in autograph_handler
    return autograph.converted_call(
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/impl/api.py", line 439, in converted_call
    result = converted_f(*effective_args, **kwargs)
  File "/var/folders/_l/1kjtk38d4t98jp53gv3rhzdm0000gn/T/__autograph_generated_filejcsv_ou1.py", line 30, in tf____call__
    ag__.for_stmt(ag__.ld(self).graph_def.node, None, loop_body, get_state, set_state, (), {'iterate_names': 'node'})
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 463, in for_stmt
    _py_for_stmt(iter_, extra_test, body, None, None)
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 512, in _py_for_stmt
    body(target)
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 478, in protected_body
    original_body(protected_iter)
  File "/var/folders/_l/1kjtk38d4t98jp53gv3rhzdm0000gn/T/__autograph_generated_filejcsv_ou1.py", line 23, in loop_body
    output_ops = ag__.converted_call(ag__.ld(self).backend._onnx_node_to_tensorflow_op, (ag__.ld(onnx_node), ag__.ld(tensor_dict), ag__.ld(self).handlers), dict(opset=ag__.ld(self).opset, strict=ag__.ld(self).strict), fscope)
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/impl/api.py", line 439, in converted_call
    result = converted_f(*effective_args, **kwargs)
  File "/var/folders/_l/1kjtk38d4t98jp53gv3rhzdm0000gn/T/__autograph_generated_filep_m8s0gh.py", line 62, in tf___onnx_node_to_tensorflow_op
    ag__.if_stmt(ag__.ld(handlers), if_body_1, else_body_1, get_state_1, set_state_1, ('do_return', 'retval_'), 2)
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 1363, in if_stmt
    _py_if_stmt(cond, body, orelse)
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 1416, in _py_if_stmt
    return body() if cond else orelse()
  File "/var/folders/_l/1kjtk38d4t98jp53gv3rhzdm0000gn/T/__autograph_generated_filep_m8s0gh.py", line 56, in if_body_1
    ag__.if_stmt(ag__.ld(handler), if_body, else_body, get_state, set_state, ('do_return', 'retval_'), 2)
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 1363, in if_stmt
    _py_if_stmt(cond, body, orelse)
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 1416, in _py_if_stmt
    return body() if cond else orelse()
  File "/var/folders/_l/1kjtk38d4t98jp53gv3rhzdm0000gn/T/__autograph_generated_filep_m8s0gh.py", line 48, in if_body
    retval_ = ag__.converted_call(ag__.ld(handler).handle, (ag__.ld(node),), dict(tensor_dict=ag__.ld(tensor_dict), strict=ag__.ld(strict)), fscope)
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/impl/api.py", line 439, in converted_call
    result = converted_f(*effective_args, **kwargs)
  File "/var/folders/_l/1kjtk38d4t98jp53gv3rhzdm0000gn/T/__autograph_generated_file_9iyzg_3.py", line 41, in tf__handle
    ag__.if_stmt(ag__.ld(ver_handle), if_body, else_body, get_state, set_state, ('do_return', 'retval_'), 2)
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 1363, in if_stmt
    _py_if_stmt(cond, body, orelse)
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 1416, in _py_if_stmt
    return body() if cond else orelse()
  File "/var/folders/_l/1kjtk38d4t98jp53gv3rhzdm0000gn/T/__autograph_generated_file_9iyzg_3.py", line 33, in if_body
    retval_ = ag__.converted_call(ag__.ld(ver_handle), (ag__.ld(node),), dict(**ag__.ld(kwargs)), fscope)
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/impl/api.py", line 439, in converted_call
    result = converted_f(*effective_args, **kwargs)
  File "/var/folders/_l/1kjtk38d4t98jp53gv3rhzdm0000gn/T/__autograph_generated_fileap8kh2xb.py", line 12, in tf__version
    retval_ = ag__.converted_call(ag__.ld(cls)._common, (ag__.ld(node),), dict(**ag__.ld(kwargs)), fscope)
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/impl/api.py", line 439, in converted_call
    result = converted_f(*effective_args, **kwargs)
  File "/var/folders/_l/1kjtk38d4t98jp53gv3rhzdm0000gn/T/__autograph_generated_filer79hpz_8.py", line 222, in tf___common
    ag__.if_stmt(ag__.ld(scan_outputs_start_index) == ag__.converted_call(ag__.ld(len), (ag__.ld(body).output,), None, fscope), if_body_3, else_body_3, get_state_7, set_state_7, ('do_return', 'retval_'), 2)
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 1363, in if_stmt
    _py_if_stmt(cond, body, orelse)
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 1416, in _py_if_stmt
    return body() if cond else orelse()
  File "/var/folders/_l/1kjtk38d4t98jp53gv3rhzdm0000gn/T/__autograph_generated_filer79hpz_8.py", line 210, in else_body_3
    scan_out_final = ag__.converted_call(ag__.ld(tf).cond, (ag__.converted_call(ag__.ld(tf).greater, (ag__.ld(iter_cnt_final), 0), None, fscope), ag__.ld(true_fn), ag__.ld(false_fn)), None, fscope)
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/impl/api.py", line 377, in converted_call
    return _call_unconverted(f, args, kwargs, options)
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/impl/api.py", line 459, in _call_unconverted
    return f(*args)
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/util/traceback_utils.py", line 153, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/ops/cond_v2.py", line 379, in _get_compatible_spec
    raise TypeError(f"No common supertype of {spec1} and {spec2}.")
TypeError: in user code:

    File "/Users/kristian/origami/model_conversion/onnx-tensorflow/onnx_tf/backend_tf_module.py", line 99, in __call__  *
        output_ops = self.backend._onnx_node_to_tensorflow_op(onnx_node,
    File "/Users/kristian/origami/model_conversion/onnx-tensorflow/onnx_tf/backend.py", line 347, in _onnx_node_to_tensorflow_op  *
        return handler.handle(node, tensor_dict=tensor_dict, strict=strict)
    File "/Users/kristian/origami/model_conversion/onnx-tensorflow/onnx_tf/handlers/handler.py", line 59, in handle  *
        return ver_handle(node, **kwargs)
    File "/Users/kristian/origami/model_conversion/onnx-tensorflow/onnx_tf/handlers/backend/loop.py", line 149, in version_11  *
        return cls._common(node, **kwargs)
    File "/Users/kristian/origami/model_conversion/onnx-tensorflow/onnx_tf/handlers/backend/loop.py", line 139, in _common  *
        scan_out_final = tf.cond(tf.greater(iter_cnt_final, 0), true_fn, false_fn)

    TypeError: No common supertype of TensorArraySpec(TensorShape([64]), tf.float32, True, True) and TensorArraySpec(TensorShape([64]), tf.float32, None, True).

Link to ONNX file: https://drive.google.com/file/d/1RxFM8igCEmN0lxolUd-DmHsmRFZZWsFQ/view?usp=sharing

I am on TensorFlow 2.9.

I was able to get past this issue with TensorFlow 2.8.3, but I cannot run inference on the TensorFlow model:

tf_rep = prepare(model)
tf_rep.export_graph('tf_model')
model = tf.saved_model.load('tf_model')
input_tensor = tf.TensorSpec.from_tensor(tf.constant(text_pipeline(ex_text_str), dtype=tf.int64), name='input1')
offset_tensor = tf.TensorSpec.from_tensor(tf.constant([0], dtype=tf.int64), name='input2')
out = model(**{'input1': input_tensor,'input2': offset_tensor})

Traceback

WARNING:absl:Found untraced functions such as gen_tensor_dict while saving (showing 1 of 1). These functions will not be directly callable after loading.
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
[<ipython-input-49-9e8ae2cb5c8d>](https://vyfdxty3b2s-496ff2e9c6d22116-0-colab.googleusercontent.com/outputframe.html?vrz=colab-20221013-060049-RC00_480849705#) in <module>
      4 input_tensor = tf.TensorSpec.from_tensor(tf.constant(text_pipeline(ex_text_str), dtype=tf.int64), name='input1')
      5 offset_tensor = tf.TensorSpec.from_tensor(tf.constant([0], dtype=tf.int64), name='input2')
----> 6 out = model(**{'input1': input_tensor,'input2': offset_tensor})

2 frames
[/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/function.py](https://vyfdxty3b2s-496ff2e9c6d22116-0-colab.googleusercontent.com/outputframe.html?vrz=colab-20221013-060049-RC00_480849705#) in call(self, ctx, args, cancellation_manager)
    474     if len(args) != len(self.signature.input_arg):
    475       raise ValueError(
--> 476           f"Signature specifies {len(list(self.signature.input_arg))} "
    477           f"arguments, got: {len(args)}.")
    478 

ValueError: Signature specifies 6 arguments, got: 4.

Downgrading to TensorFlow 2.8.3 and removing the TensorSpec wrapper solved this issue.