PyTorch -> ONNX -> TF
xkasberg opened this issue · 3 comments
xkasberg commented
Hello, I am converting a boilerplate text classification model from PyTorch to ONNX to TensorFlow on Google Colab.
The conversion from PyTorch to ONNX works fine, but I get this traceback when converting to TensorFlow:
import os
import onnx
import torch
from onnx_tf.backend import prepare
model = torch.load('text_classification.pt')
onnx_model_path = 'text_classification.onnx'
torch.onnx.export(
model,
args=(sample_input, torch.tensor([0])),
f=onnx_model_path,
opset_version=12,
input_names=['input1', 'input2'],
output_names=['output'],
dynamic_axes={'input1': {0: 'batch'}},
)
model = onnx.load("text_classification.onnx")
onnx.checker.check_model(model)
print(onnx.helper.printable_graph(model.graph))
tf_rep = prepare(model)
tf_rep.export_graph('tf_model')
Traceback:
File "/opt/homebrew/Cellar/python@3.10/3.10.8/Frameworks/Python.framework/Versions/3.10/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/opt/homebrew/Cellar/python@3.10/3.10.8/Frameworks/Python.framework/Versions/3.10/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/Users/kristian/origami/model_conversion/convert_model.py", line 11, in <module>
tf_rep.export_graph("tf_text_classification")
File "/Users/kristian/origami/model_conversion/onnx-tensorflow/onnx_tf/backend_rep.py", line 143, in export_graph
signatures=self.tf_module.__call__.get_concrete_function(
File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/eager/def_function.py", line 1239, in get_concrete_function
concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/eager/def_function.py", line 1219, in _get_concrete_function_garbage_collected
self._initialize(args, kwargs, add_initializers_to=initializers)
File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/eager/def_function.py", line 785, in _initialize
self._stateful_fn._get_concrete_function_internal_garbage_collected( # pylint: disable=protected-access
File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/eager/function.py", line 2523, in _get_concrete_function_internal_garbage_collected
graph_function, _ = self._maybe_define_function(args, kwargs)
File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/eager/function.py", line 2760, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/eager/function.py", line 2670, in _create_graph_function
func_graph_module.func_graph_from_py_func(
File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/framework/func_graph.py", line 1247, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/eager/def_function.py", line 677, in wrapped_fn
out = weak_wrapped_fn().__wrapped__(*args, **kwds)
File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/eager/function.py", line 3317, in bound_method_wrapper
return wrapped_fn(*args, **kwargs)
File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/framework/func_graph.py", line 1233, in autograph_handler
raise e.ag_error_metadata.to_exception(e)
File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/framework/func_graph.py", line 1222, in autograph_handler
return autograph.converted_call(
File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/impl/api.py", line 439, in converted_call
result = converted_f(*effective_args, **kwargs)
File "/var/folders/_l/1kjtk38d4t98jp53gv3rhzdm0000gn/T/__autograph_generated_filejcsv_ou1.py", line 30, in tf____call__
ag__.for_stmt(ag__.ld(self).graph_def.node, None, loop_body, get_state, set_state, (), {'iterate_names': 'node'})
File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 463, in for_stmt
_py_for_stmt(iter_, extra_test, body, None, None)
File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 512, in _py_for_stmt
body(target)
File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 478, in protected_body
original_body(protected_iter)
File "/var/folders/_l/1kjtk38d4t98jp53gv3rhzdm0000gn/T/__autograph_generated_filejcsv_ou1.py", line 23, in loop_body
output_ops = ag__.converted_call(ag__.ld(self).backend._onnx_node_to_tensorflow_op, (ag__.ld(onnx_node), ag__.ld(tensor_dict), ag__.ld(self).handlers), dict(opset=ag__.ld(self).opset, strict=ag__.ld(self).strict), fscope)
File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/impl/api.py", line 439, in converted_call
result = converted_f(*effective_args, **kwargs)
File "/var/folders/_l/1kjtk38d4t98jp53gv3rhzdm0000gn/T/__autograph_generated_filep_m8s0gh.py", line 62, in tf___onnx_node_to_tensorflow_op
ag__.if_stmt(ag__.ld(handlers), if_body_1, else_body_1, get_state_1, set_state_1, ('do_return', 'retval_'), 2)
File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 1363, in if_stmt
_py_if_stmt(cond, body, orelse)
File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 1416, in _py_if_stmt
return body() if cond else orelse()
File "/var/folders/_l/1kjtk38d4t98jp53gv3rhzdm0000gn/T/__autograph_generated_filep_m8s0gh.py", line 56, in if_body_1
ag__.if_stmt(ag__.ld(handler), if_body, else_body, get_state, set_state, ('do_return', 'retval_'), 2)
File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 1363, in if_stmt
_py_if_stmt(cond, body, orelse)
File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 1416, in _py_if_stmt
return body() if cond else orelse()
File "/var/folders/_l/1kjtk38d4t98jp53gv3rhzdm0000gn/T/__autograph_generated_filep_m8s0gh.py", line 48, in if_body
retval_ = ag__.converted_call(ag__.ld(handler).handle, (ag__.ld(node),), dict(tensor_dict=ag__.ld(tensor_dict), strict=ag__.ld(strict)), fscope)
File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/impl/api.py", line 439, in converted_call
result = converted_f(*effective_args, **kwargs)
File "/var/folders/_l/1kjtk38d4t98jp53gv3rhzdm0000gn/T/__autograph_generated_file_9iyzg_3.py", line 41, in tf__handle
ag__.if_stmt(ag__.ld(ver_handle), if_body, else_body, get_state, set_state, ('do_return', 'retval_'), 2)
File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 1363, in if_stmt
_py_if_stmt(cond, body, orelse)
File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 1416, in _py_if_stmt
return body() if cond else orelse()
File "/var/folders/_l/1kjtk38d4t98jp53gv3rhzdm0000gn/T/__autograph_generated_file_9iyzg_3.py", line 33, in if_body
retval_ = ag__.converted_call(ag__.ld(ver_handle), (ag__.ld(node),), dict(**ag__.ld(kwargs)), fscope)
File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/impl/api.py", line 439, in converted_call
result = converted_f(*effective_args, **kwargs)
File "/var/folders/_l/1kjtk38d4t98jp53gv3rhzdm0000gn/T/__autograph_generated_fileap8kh2xb.py", line 12, in tf__version
retval_ = ag__.converted_call(ag__.ld(cls)._common, (ag__.ld(node),), dict(**ag__.ld(kwargs)), fscope)
File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/impl/api.py", line 439, in converted_call
result = converted_f(*effective_args, **kwargs)
File "/var/folders/_l/1kjtk38d4t98jp53gv3rhzdm0000gn/T/__autograph_generated_filer79hpz_8.py", line 222, in tf___common
ag__.if_stmt(ag__.ld(scan_outputs_start_index) == ag__.converted_call(ag__.ld(len), (ag__.ld(body).output,), None, fscope), if_body_3, else_body_3, get_state_7, set_state_7, ('do_return', 'retval_'), 2)
File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 1363, in if_stmt
_py_if_stmt(cond, body, orelse)
File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 1416, in _py_if_stmt
return body() if cond else orelse()
File "/var/folders/_l/1kjtk38d4t98jp53gv3rhzdm0000gn/T/__autograph_generated_filer79hpz_8.py", line 210, in else_body_3
scan_out_final = ag__.converted_call(ag__.ld(tf).cond, (ag__.converted_call(ag__.ld(tf).greater, (ag__.ld(iter_cnt_final), 0), None, fscope), ag__.ld(true_fn), ag__.ld(false_fn)), None, fscope)
File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/impl/api.py", line 377, in converted_call
return _call_unconverted(f, args, kwargs, options)
File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/autograph/impl/api.py", line 459, in _call_unconverted
return f(*args)
File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/util/traceback_utils.py", line 153, in error_handler
raise e.with_traceback(filtered_tb) from None
File "/Users/kristian/.local/share/virtualenvs/model_conversion-WzW74XZf/lib/python3.10/site-packages/tensorflow/python/ops/cond_v2.py", line 379, in _get_compatible_spec
raise TypeError(f"No common supertype of {spec1} and {spec2}.")
TypeError: in user code:
File "/Users/kristian/origami/model_conversion/onnx-tensorflow/onnx_tf/backend_tf_module.py", line 99, in __call__ *
output_ops = self.backend._onnx_node_to_tensorflow_op(onnx_node,
File "/Users/kristian/origami/model_conversion/onnx-tensorflow/onnx_tf/backend.py", line 347, in _onnx_node_to_tensorflow_op *
return handler.handle(node, tensor_dict=tensor_dict, strict=strict)
File "/Users/kristian/origami/model_conversion/onnx-tensorflow/onnx_tf/handlers/handler.py", line 59, in handle *
return ver_handle(node, **kwargs)
File "/Users/kristian/origami/model_conversion/onnx-tensorflow/onnx_tf/handlers/backend/loop.py", line 149, in version_11 *
return cls._common(node, **kwargs)
File "/Users/kristian/origami/model_conversion/onnx-tensorflow/onnx_tf/handlers/backend/loop.py", line 139, in _common *
scan_out_final = tf.cond(tf.greater(iter_cnt_final, 0), true_fn, false_fn)
TypeError: No common supertype of TensorArraySpec(TensorShape([64]), tf.float32, True, True) and TensorArraySpec(TensorShape([64]), tf.float32, None, True).
Link to ONNX file: https://drive.google.com/file/d/1RxFM8igCEmN0lxolUd-DmHsmRFZZWsFQ/view?usp=sharing
xkasberg commented
I am on TensorFlow 2.9.
I was able to get past this issue with TensorFlow 2.8.3, but I cannot run inference on the TensorFlow model:
tf_rep = prepare(model)
tf_rep.export_graph('tf_model')
model = tf.saved_model.load('tf_model')
input_tensor = tf.TensorSpec.from_tensor(tf.constant(text_pipeline(ex_text_str), dtype=tf.int64), name='input1')
offset_tensor = tf.TensorSpec.from_tensor(tf.constant([0], dtype=tf.int64), name='input2')
out = model(**{'input1': input_tensor,'input2': offset_tensor})
Traceback
WARNING:absl:Found untraced functions such as gen_tensor_dict while saving (showing 1 of 1). These functions will not be directly callable after loading.
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
[<ipython-input-49-9e8ae2cb5c8d>](https://vyfdxty3b2s-496ff2e9c6d22116-0-colab.googleusercontent.com/outputframe.html?vrz=colab-20221013-060049-RC00_480849705#) in <module>
4 input_tensor = tf.TensorSpec.from_tensor(tf.constant(text_pipeline(ex_text_str), dtype=tf.int64), name='input1')
5 offset_tensor = tf.TensorSpec.from_tensor(tf.constant([0], dtype=tf.int64), name='input2')
----> 6 out = model(**{'input1': input_tensor,'input2': offset_tensor})
2 frames
[/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/function.py](https://vyfdxty3b2s-496ff2e9c6d22116-0-colab.googleusercontent.com/outputframe.html?vrz=colab-20221013-060049-RC00_480849705#) in call(self, ctx, args, cancellation_manager)
474 if len(args) != len(self.signature.input_arg):
475 raise ValueError(
--> 476 f"Signature specifies {len(list(self.signature.input_arg))} "
477 f"arguments, got: {len(args)}.")
478
ValueError: Signature specifies 6 arguments, got: 4.
xkasberg commented
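Downgrading to TensorFlow 2.8.3 and removing the TensorSpec wrapper (i.e., passing the `tf.constant(...)` tensors to the model directly instead of wrapping them in `tf.TensorSpec.from_tensor`) solved this issue.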
andre20000131 commented
> Downgrading to TensorFlow 2.8.3 and removing the TensorSpec wrapper solved this issue
How do you remove the TensorSpec wrapper?