Convert to ONNX
AryaAftab opened this issue · 5 comments
AryaAftab commented
Hi,
I want to convert the model to ONNX format, but I get an error. Can anyone help me solve the problem?
Note: I am using Colab for model loading and converting.
Conversion code:
dummy_input = torch.randn(1, 48, 3, 719, 1282, device="cuda")
input_names = ["input"]
output_names = ["output_tracker", "output_visib"]
dynamic_axes_dict = {
    "input": {0: "bs"},
    "output_tracker": {0: "bs"},
    "output_visib": {0: "bs"},
}
torch.onnx.export(
    model,
    dummy_input,
    "cotracker.onnx",
    verbose=False,
    input_names=input_names,
    output_names=output_names,
    dynamic_axes=dynamic_axes_dict,
    export_params=True,
)
Error:
WARNING:py.warnings:/content/co-tracker/cotracker/predictor.py:47: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if queries is None and grid_size == 0:
0%| | 0/1764 [00:00<?, ?it/s]WARNING:py.warnings:/content/co-tracker/cotracker/predictor.py:106: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
assert B == 1
WARNING:py.warnings:/content/co-tracker/cotracker/predictor.py:115: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
assert D == 3
WARNING:py.warnings:/content/co-tracker/cotracker/models/core/cotracker/cotracker.py:224: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
assert B == 1
WARNING:py.warnings:/content/co-tracker/cotracker/models/core/cotracker/cotracker.py:264: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
while ind < T - self.S // 2:
WARNING:py.warnings:/content/co-tracker/cotracker/predictor.py:144: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if backward_tracking:
100%|██████████| 1764/1764 [01:33<00:00, 18.82it/s]
WARNING:py.warnings:/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py:619: UserWarning: ONNX Preprocess - Removing mutation from node aten::fill_ on block input: 'grid_query_frame'. This changes graph semantics. (Triggered internally at ../torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp:350.)
_C._jit_pass_onnx_remove_inplace_ops_for_onnx(graph, module)
============= Diagnostic Run torch.onnx.export version 2.0.1+cu118 =============
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-11-d51d0e1e59b2> in <cell line: 13>()
11 }
12
---> 13 torch.onnx.export(model,
14 dummy_input,
15 "cotracker.onnx",
16 frames
/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/jit_utils.py in _add_attribute(node, key, value, aten)
353 else:
354 kind = "i"
--> 355 return getattr(node, f"{kind}_")(name, value)
356
357
TypeError: z_(): incompatible function arguments. The following argument types are supported:
1. (self: torch._C.Node, arg0: str, arg1: torch.Tensor) -> torch._C.Node
Invoked with: %367 : Tensor = onnx::Constant(), scope: cotracker.predictor.CoTrackerPredictor::
, 'value', 0
(Occurred when translating repeat_interleave).
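For context on all those TracerWarnings: torch.onnx.export traces the model, so any Python-level branch or loop (like `if queries is None` or `while ind < T - self.S // 2`) is evaluated once on the dummy input and frozen into the exported graph. A minimal illustration of the problem, unrelated to CoTracker itself:

```python
import torch

def f(x):
    # Data-dependent Python branch: the tracer cannot record it,
    # so the decision made on the example input becomes a constant
    if x.sum() > 0:
        return x + 1
    return x - 1

# Tracing with a positive input bakes the "+ 1" branch into the graph
traced = torch.jit.trace(f, torch.ones(3))

# A negative input still takes the frozen branch: -1 + 1 == 0
print(traced(-torch.ones(3)))  # tensor([0., 0., 0.])
```

This is exactly why the exported graph may not generalize to inputs that would take a different branch than the dummy input did.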
nikitakaraevv commented
Hi @AryaAftab,
I would suggest converting the underlying model class (i.e. `predictor.model`) to .onnx instead. You can try something like this:
predictor = CoTrackerPredictor(checkpoint)
# All inputs should be resized to 384x512
dummy_input = torch.randn(1, 8, 3, 384, 512, device="cuda")
# We take a video and queried points as input
input_names = ["input_video", "input_queries"]
output_names = ["output_tracks", "output_feature", "output_visib", "output_metadata"]
# Video length is also dynamic
dynamic_axes_dict = {
    "input_video": {0: "batch_size", 1: "video_len"},
    "input_queries": {0: "batch_size", 1: "video_len"},
}
torch.onnx.export(
    predictor.model,
    dummy_input,
    "cotracker.onnx",
    verbose=False,
    input_names=input_names,
    output_names=output_names,
    dynamic_axes=dynamic_axes_dict,
    export_params=True,
)
Could you tell me if this approach works? I haven't tried to convert CoTracker to .onnx yet. It might need some refactoring.
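One detail worth double-checking (my reading of the predictor code, not verified here): the queries tensor is expected to have shape (B, N, 3), where each row is (t, x, y) — the frame index to start tracking from plus pixel coordinates — so the second dynamic axis of `input_queries` is really the number of queried points, not the video length. A sketch of what such an input might look like:

```python
import torch

# Hypothetical queries for two tracked points; format assumed to be (t, x, y)
queries = torch.tensor([
    [0.0, 100.0, 150.0],   # start at frame 0, pixel (100, 150)
    [4.0, 320.0, 240.0],   # start at frame 4, pixel (320, 240)
])
queries = queries.unsqueeze(0)  # add batch dim -> shape (1, 2, 3)
print(queries.shape)
```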
AryaAftab commented
Hi @nikitakaraevv, thanks for your response.
I tried this with several modifications, but I got some errors:
from cotracker.models.core.cotracker.cotracker import CoTracker
from cotracker.models.build_cotracker import build_cotracker_stride_4_wind_8

model = build_cotracker_stride_4_wind_8(
    checkpoint=os.path.join('./checkpoints/cotracker_stride_4_wind_8.pth')
)
device = "cpu"
# All inputs should be resized to 384x512
dummy_input_1 = torch.randn(1, 8, 3, 384, 512, device=device)
dummy_input_2 = torch.randn(1, 8, 3, device=device)
# We take a video and queried points as input
input_names = ["input_video", "input_queries"]
output_names = ["output_tracks", "output_feature", "output_visib", "output_metadata"]
# Video length is also dynamic
dynamic_axes_dict = {
    "input_video": {0: "batch_size", 1: "video_len"},
    "input_queries": {0: "batch_size", 1: "video_len"},
}
torch.onnx.export(
    model,
    (dummy_input_1, dummy_input_2),
    "cotracker.onnx",
    verbose=False,
    input_names=input_names,
    output_names=output_names,
    dynamic_axes=dynamic_axes_dict,
    export_params=True,
    opset_version=16,
)
Error:
============= Diagnostic Run torch.onnx.export version 2.0.1+cu118 =============
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 1 ERROR ========================
ERROR: missing-standard-symbolic-function
=========================================
Exporting the operator 'aten::scaled_dot_product_attention' to ONNX opset version 16 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub: https://github.com/pytorch/pytorch/issues.
None
<Set verbose=True to see more details>
---------------------------------------------------------------------------
UnsupportedOperatorError Traceback (most recent call last)
<ipython-input-13-9985d6d41604> in <cell line: 21>()
19 }
20
---> 21 torch.onnx.export(model,
22 (dummy_input_1, dummy_input_2),
23 "cotracker.onnx",
4 frames
/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py in _run_symbolic_function(graph, block, node, inputs, env, operator_export_type)
1899 return graph_context.op(op_name, *inputs, **attrs, outputs=node.outputsSize()) # type: ignore[attr-defined]
1900
-> 1901 raise errors.UnsupportedOperatorError(
1902 symbolic_function_name,
1903 opset_version,
UnsupportedOperatorError: Exporting the operator 'aten::scaled_dot_product_attention' to ONNX opset version 16 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub: https://github.com/pytorch/pytorch/issues.
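One possible workaround (a sketch, not tested against CoTracker): replace `F.scaled_dot_product_attention` with an equivalent composition of ops that opset 16 does support, either by editing the attention modules or by monkey-patching the function before export:

```python
import math
import torch

def sdpa_manual(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
    # Same math as F.scaled_dot_product_attention, built from matmul/softmax,
    # which export cleanly to ONNX opset 16.
    # Dropout and causal masking are omitted here for brevity (inference only).
    scale = 1.0 / math.sqrt(query.size(-1))
    attn = torch.matmul(query, key.transpose(-2, -1)) * scale
    if attn_mask is not None:
        attn = attn + attn_mask
    attn = attn.softmax(dim=-1)
    return torch.matmul(attn, value)
```

Assigning `torch.nn.functional.scaled_dot_product_attention = sdpa_manual` just before calling `torch.onnx.export` (and restoring the original afterwards) avoids touching the model code at all; the linked PyTorch issue also discusses native export support for this op in newer releases.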
nikitakaraevv commented
Thank you, @AryaAftab!
The issue is discussed here: pytorch/pytorch#97262
Is this helpful?