onnx/onnx-tensorflow

PyTorch -> ONNX -> TF

xkasberg opened this issue · 3 comments

Hello, I am converting a boilerplate text-classification model from PyTorch to ONNX to TensorFlow on Google Colab.

The conversion from PyTorch to ONNX works fine, but I get this traceback when converting to TensorFlow:

# Convert a PyTorch text-classification model to ONNX (opset 12), check the
# ONNX graph, then convert it to a TensorFlow SavedModel with onnx-tf.
# NOTE(review): `os` is imported but not used in this snippet.
import os
import onnx
import torch
from onnx_tf.backend import prepare

# Load the PyTorch model object that is exported to ONNX below.
model = torch.load('text_classification.pt')
onnx_model_path = 'text_classification.onnx'

# NOTE(review): `sample_input` is not defined in this snippet; it must be
# created earlier (presumably an int64 tensor of token ids) — confirm.
# The second positional input is an example offsets tensor, torch.tensor([0]).
torch.onnx.export(
    model, 
    args=(sample_input, torch.tensor([0])), 
    f=onnx_model_path,      
    opset_version=12,      
    input_names=['input1', 'input2'],  
    output_names=['output'],
    # Only axis 0 of input1 is declared dynamic; input2 is presumably traced
    # with the fixed shape of its example tensor — confirm if variable-length
    # offsets are needed.
    dynamic_axes={'input1': {0: 'batch'}},
)

# Rebinds `model` from the PyTorch module to the loaded ONNX model.
model = onnx.load("text_classification.onnx")
onnx.checker.check_model(model)
print(onnx.helper.printable_graph(model.graph))

# On TensorFlow 2.9 this export fails inside onnx-tf's Loop handler with
# "TypeError: No common supertype of TensorArraySpec(...)", per the
# traceback in this report.
tf_rep = prepare(model)
tf_rep.export_graph('tf_model')

Traceback:

  File "/opt/homebrew/Cellar/python@3.10/3.10.8/Frameworks/Python.framework/Versions/3.10/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/opt/homebrew/Cellar/python@3.10/3.10.8/Frameworks/Python.framework/Versions/3.10/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/Users/kristian/origami/model_conversion/convert_model.py", line 11, in <module>
    tf_rep.export_graph("tf_text_classification")
  File "/Users/kristian/origami/model_conversion/onnx-tensorflow/onnx_tf/backend_rep.py", line 143, in export_graph
    signatures=self.tf_module.__call__.get_concrete_function(
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/eager/def_function.py", line 1239, in get_concrete_function
    concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/eager/def_function.py", line 1219, in _get_concrete_function_garbage_collected
    self._initialize(args, kwargs, add_initializers_to=initializers)
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/eager/def_function.py", line 785, in _initialize
    self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/eager/function.py", line 2523, in _get_concrete_function_internal_garbage_collected
    graph_function, _ = self._maybe_define_function(args, kwargs)
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/eager/function.py", line 2760, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/eager/function.py", line 2670, in _create_graph_function
    func_graph_module.func_graph_from_py_func(
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/framework/func_graph.py", line 1247, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/eager/def_function.py", line 677, in wrapped_fn
    out = weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/eager/function.py", line 3317, in bound_method_wrapper
    return wrapped_fn(*args, **kwargs)
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/framework/func_graph.py", line 1233, in autograph_handler
    raise e.ag_error_metadata.to_exception(e)
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/framework/func_graph.py", line 1222, in autograph_handler
    return autograph.converted_call(
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/impl/api.py", line 439, in converted_call
    result = converted_f(*effective_args, **kwargs)
  File "/var/folders/_l/1kjtk38d4t98jp53gv3rhzdm0000gn/T/__autograph_generated_filejcsv_ou1.py", line 30, in tf____call__
    ag__.for_stmt(ag__.ld(self).graph_def.node, None, loop_body, get_state, set_state, (), {'iterate_names': 'node'})
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 463, in for_stmt
    _py_for_stmt(iter_, extra_test, body, None, None)
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 512, in _py_for_stmt
    body(target)
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 478, in protected_body
    original_body(protected_iter)
  File "/var/folders/_l/1kjtk38d4t98jp53gv3rhzdm0000gn/T/__autograph_generated_filejcsv_ou1.py", line 23, in loop_body
    output_ops = ag__.converted_call(ag__.ld(self).backend._onnx_node_to_tensorflow_op, (ag__.ld(onnx_node), ag__.ld(tensor_dict), ag__.ld(self).handlers), dict(opset=ag__.ld(self).opset, strict=ag__.ld(self).strict), fscope)
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/impl/api.py", line 439, in converted_call
    result = converted_f(*effective_args, **kwargs)
  File "/var/folders/_l/1kjtk38d4t98jp53gv3rhzdm0000gn/T/__autograph_generated_filep_m8s0gh.py", line 62, in tf___onnx_node_to_tensorflow_op
    ag__.if_stmt(ag__.ld(handlers), if_body_1, else_body_1, get_state_1, set_state_1, ('do_return', 'retval_'), 2)
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 1363, in if_stmt
    _py_if_stmt(cond, body, orelse)
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 1416, in _py_if_stmt
    return body() if cond else orelse()
  File "/var/folders/_l/1kjtk38d4t98jp53gv3rhzdm0000gn/T/__autograph_generated_filep_m8s0gh.py", line 56, in if_body_1
    ag__.if_stmt(ag__.ld(handler), if_body, else_body, get_state, set_state, ('do_return', 'retval_'), 2)
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 1363, in if_stmt
    _py_if_stmt(cond, body, orelse)
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 1416, in _py_if_stmt
    return body() if cond else orelse()
  File "/var/folders/_l/1kjtk38d4t98jp53gv3rhzdm0000gn/T/__autograph_generated_filep_m8s0gh.py", line 48, in if_body
    retval_ = ag__.converted_call(ag__.ld(handler).handle, (ag__.ld(node),), dict(tensor_dict=ag__.ld(tensor_dict), strict=ag__.ld(strict)), fscope)
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/impl/api.py", line 439, in converted_call
    result = converted_f(*effective_args, **kwargs)
  File "/var/folders/_l/1kjtk38d4t98jp53gv3rhzdm0000gn/T/__autograph_generated_file_9iyzg_3.py", line 41, in tf__handle
    ag__.if_stmt(ag__.ld(ver_handle), if_body, else_body, get_state, set_state, ('do_return', 'retval_'), 2)
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 1363, in if_stmt
    _py_if_stmt(cond, body, orelse)
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 1416, in _py_if_stmt
    return body() if cond else orelse()
  File "/var/folders/_l/1kjtk38d4t98jp53gv3rhzdm0000gn/T/__autograph_generated_file_9iyzg_3.py", line 33, in if_body
    retval_ = ag__.converted_call(ag__.ld(ver_handle), (ag__.ld(node),), dict(**ag__.ld(kwargs)), fscope)
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/impl/api.py", line 439, in converted_call
    result = converted_f(*effective_args, **kwargs)
  File "/var/folders/_l/1kjtk38d4t98jp53gv3rhzdm0000gn/T/__autograph_generated_fileap8kh2xb.py", line 12, in tf__version
    retval_ = ag__.converted_call(ag__.ld(cls)._common, (ag__.ld(node),), dict(**ag__.ld(kwargs)), fscope)
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/impl/api.py", line 439, in converted_call
    result = converted_f(*effective_args, **kwargs)
  File "/var/folders/_l/1kjtk38d4t98jp53gv3rhzdm0000gn/T/__autograph_generated_filer79hpz_8.py", line 222, in tf___common
    ag__.if_stmt(ag__.ld(scan_outputs_start_index) == ag__.converted_call(ag__.ld(len), (ag__.ld(body).output,), None, fscope), if_body_3, else_body_3, get_state_7, set_state_7, ('do_return', 'retval_'), 2)
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 1363, in if_stmt
    _py_if_stmt(cond, body, orelse)
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 1416, in _py_if_stmt
    return body() if cond else orelse()
  File "/var/folders/_l/1kjtk38d4t98jp53gv3rhzdm0000gn/T/__autograph_generated_filer79hpz_8.py", line 210, in else_body_3
    scan_out_final = ag__.converted_call(ag__.ld(tf).cond, (ag__.converted_call(ag__.ld(tf).greater, (ag__.ld(iter_cnt_final), 0), None, fscope), ag__.ld(true_fn), ag__.ld(false_fn)), None, fscope)
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/impl/api.py", line 377, in converted_call
    return _call_unconverted(f, args, kwargs, options)
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/impl/api.py", line 459, in _call_unconverted
    return f(*args)
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/util/traceback_utils.py", line 153, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/ops/cond_v2.py", line 379, in _get_compatible_spec
    raise TypeError(f"No common supertype of {spec1} and {spec2}.")
TypeError: in user code:

    File "/Users/kristian/origami/model_conversion/onnx-tensorflow/onnx_tf/backend_tf_module.py", line 99, in __call__  *
        output_ops = self.backend._onnx_node_to_tensorflow_op(onnx_node,
    File "/Users/kristian/origami/model_conversion/onnx-tensorflow/onnx_tf/backend.py", line 347, in _onnx_node_to_tensorflow_op  *
        return handler.handle(node, tensor_dict=tensor_dict, strict=strict)
    File "/Users/kristian/origami/model_conversion/onnx-tensorflow/onnx_tf/handlers/handler.py", line 59, in handle  *
        return ver_handle(node, **kwargs)
    File "/Users/kristian/origami/model_conversion/onnx-tensorflow/onnx_tf/handlers/backend/loop.py", line 149, in version_11  *
        return cls._common(node, **kwargs)
    File "/Users/kristian/origami/model_conversion/onnx-tensorflow/onnx_tf/handlers/backend/loop.py", line 139, in _common  *
        scan_out_final = tf.cond(tf.greater(iter_cnt_final, 0), true_fn, false_fn)

    TypeError: No common supertype of TensorArraySpec(TensorShape([64]), tf.float32, True, True) and TensorArraySpec(TensorShape([64]), tf.float32, None, True).

Link to ONNX file: https://drive.google.com/file/d/1RxFM8igCEmN0lxolUd-DmHsmRFZZWsFQ/view?usp=sharing

I am on TensorFlow 2.9.

I was able to get past this issue with TensorFlow 2.8.3, but I cannot run inference on the exported TensorFlow model:

# Export the onnx-tf representation, reload it as a SavedModel, and call it
# on one example.
# NOTE(review): `tf`, `text_pipeline` and `ex_text_str` are not defined in
# this snippet; they are assumed to come from earlier notebook cells.
tf_rep = prepare(model)
tf_rep.export_graph('tf_model')
model = tf.saved_model.load('tf_model')
# NOTE(review): tf.TensorSpec.from_tensor wraps each constant in a spec
# instead of passing the tensor itself. With these arguments the call below
# fails with "Signature specifies 6 arguments, got: 4". Per the follow-up in
# this thread, passing the tf.constant(...) tensors directly (no TensorSpec
# wrapper) on TensorFlow 2.8.3 resolves it.
input_tensor = tf.TensorSpec.from_tensor(tf.constant(text_pipeline(ex_text_str), dtype=tf.int64), name='input1')
offset_tensor = tf.TensorSpec.from_tensor(tf.constant([0], dtype=tf.int64), name='input2')
out = model(**{'input1': input_tensor,'input2': offset_tensor})

Traceback

WARNING:absl:Found untraced functions such as gen_tensor_dict while saving (showing 1 of 1). These functions will not be directly callable after loading.
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
[<ipython-input-49-9e8ae2cb5c8d>](https://vyfdxty3b2s-496ff2e9c6d22116-0-colab.googleusercontent.com/outputframe.html?vrz=colab-20221013-060049-RC00_480849705#) in <module>
      4 input_tensor = tf.TensorSpec.from_tensor(tf.constant(text_pipeline(ex_text_str), dtype=tf.int64), name='input1')
      5 offset_tensor = tf.TensorSpec.from_tensor(tf.constant([0], dtype=tf.int64), name='input2')
----> 6 out = model(**{'input1': input_tensor,'input2': offset_tensor})

2 frames
[/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/function.py](https://vyfdxty3b2s-496ff2e9c6d22116-0-colab.googleusercontent.com/outputframe.html?vrz=colab-20221013-060049-RC00_480849705#) in call(self, ctx, args, cancellation_manager)
    474     if len(args) != len(self.signature.input_arg):
    475       raise ValueError(
--> 476           f"Signature specifies {len(list(self.signature.input_arg))} "
    477           f"arguments, got: {len(args)}.")
    478 

ValueError: Signature specifies 6 arguments, got: 4.

Downgrading to TensorFlow 2.8.3 and removing the `tf.TensorSpec` wrapper (passing the `tf.constant(...)` tensors directly as `input1` and `input2`) solved this issue.

Downgrading to TensorFlow 2.8.3 and removing the TensorSpec wrapper solved this issue.

How do I remove the TensorSpec wrapper?