facebookresearch/meshtalk

I want to convert "context_model.pkl" to "context_model.onnx".

PpEeIi opened this issue · 1 comment

I ran the following code to convert the .pkl file to an .onnx file:

```python
import torch
from models.context_model import ContextModel

# Dummy audio encoding: (batch, T, audio_dim)
randinput = torch.randn(1, 126, 128, device='cuda:0')
T = randinput.shape[1]
# Dummy expression context: (batch, T, heads, classes)
one_hot = torch.randn(1, T, 16, 128, device='cuda:0')
model_path = "pretrained_models"
context_model = ContextModel(classes=128, heads=16, audio_dim=128)
context_model.load(model_path)  # load the pretrained weights (context_model.pkl)
context_model.cuda().eval()
torch.onnx.export(context_model, (0, 0, one_hot, randinput), "context_model_onnx.onnx", verbose=True, opset_version=11,
                  input_names=['t', 'h', 'one_hot', 'audio'])
```

Note: in `context_model.py`, I replaced the body of `def forward(self, expression_one_hot: th.Tensor, audio_code: th.Tensor):` with the body of `def _forward_inference(self, t: int, h: int, context: th.Tensor, audio: th.Tensor):`.

This produces the following warnings:

```
E:\leaf\meshtalk\models\context_model.py:75: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if self.historic_t < t:
E:\leaf\meshtalk\models\context_model.py:85: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if h > 0:
E:\leaf\meshtalk\models\context_model.py:100: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if self.kernel_size > 0 and self.historic_t < t:
E:\leaf\meshtalk\models\context_model.py:103: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if h.shape[-1] < self.receptive_field() - 1:
```

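From what I found online, the cause is that model inputs cannot be used in Python `if` statements: the export traces the model, so each branch condition is evaluated once and frozen as a constant in the graph.
How can I solve this problem?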
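To illustrate what the warnings mean, here is a minimal sketch using a hypothetical toy module `Branchy` (not part of meshtalk; it assumes only PyTorch on CPU). `torch.jit.trace`, which the TorchScript-based `torch.onnx.export` uses for a plain `nn.Module`, evaluates the `if` once with the example input and records only the branch that was taken:

```python
import warnings

import torch


class Branchy(torch.nn.Module):
    """Toy module whose output depends on a Python `if` over a tensor value."""

    def forward(self, x: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
        if h > 0:  # evaluated once at trace time, like `if h > 0:` in context_model.py
            return x + 1
        return x - 1


model = Branchy().eval()
x = torch.zeros(3)

with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # silence the TracerWarning for this demo
    traced = torch.jit.trace(model, (x, torch.tensor(1)))

# Eager mode follows the branch that matches the current input...
eager_out = model(x, torch.tensor(0))    # takes the `else` path: x - 1
# ...but the traced graph replays the branch recorded during tracing.
traced_out = traced(x, torch.tensor(0))  # still x + 1
print(eager_out, traced_out)
```

Calling `traced` with `h = 0` still returns `x + 1`, which is exactly the behavior the TracerWarning describes: the trace does not generalize to other values of `t` or `h`.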

It seems that the ONNX export only supports static compute graphs. I have not worked with ONNX yet, so I don't have a solution for you.
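One possible direction, untested with meshtalk: ONNX itself has `If` and `Loop` operators, and `torch.jit.script`, unlike tracing, compiles Python control flow instead of freezing it. The TorchScript-based `torch.onnx.export` accepts a scripted module, which lets such branches be kept in the exported graph. Whether `ContextModel` can be scripted as written (it keeps Python-side state such as `self.historic_t`) is an open assumption. The sketch below uses a hypothetical toy module `Branchy`:

```python
import torch


class Branchy(torch.nn.Module):
    """Toy module with a data-dependent branch (hypothetical, not from meshtalk)."""

    def forward(self, x: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
        if h > 0:  # compiled as a real conditional by torch.jit.script
            return x + 1
        return x - 1


scripted = torch.jit.script(Branchy().eval())
x = torch.zeros(3)

# Unlike a traced module, the scripted one chooses the branch at run time.
out_pos = scripted(x, torch.tensor(1))  # x + 1
out_neg = scripted(x, torch.tensor(0))  # x - 1

# The scripted module could then be passed to the TorchScript-based exporter, e.g.:
# torch.onnx.export(scripted, (x, torch.tensor(1)), "branchy.onnx", opset_version=11,
#                   input_names=["x", "h"])
print(out_pos, out_neg)
```

For the real model, any code that is not scriptable (for example Python-side caches) would likely need to be rewritten or moved outside the exported module first.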