I want to convert "context_model.pkl" to "context_model.onnx".
PpEeIi opened this issue · 1 comment
PpEeIi commented
I am trying to convert the .pkl file to an .onnx file with the following code:
import torch
from models.context_model import ContextModel
randinput = torch.randn(1, 126, 128, device='cuda:0')
T = randinput.shape[1]
one_hot = torch.randn(1, T, 16, 128, device='cuda:0')
model_path = "pretrained_models"
context_model = ContextModel(classes=128, heads=16, audio_dim=128)
context_model.load(model_path)
context_model.cuda().eval()
torch.onnx.export(
    context_model,
    (0, 0, one_hot, randinput),
    "context_model_onnx.onnx",
    verbose=True,
    opset_version=11,
    input_names=['t', 'h', 'one_hot', 'audio'],
)
Note: in context_model.py, I replaced the body of "def forward(self, expression_one_hot: th.Tensor, audio_code: th.Tensor):" with the body of "def _forward_inference(self, t: int, h: int, context: th.Tensor, audio: th.Tensor):".
The following warnings are then reported:
E:\leaf\meshtalk\models\context_model.py:75: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if self.historic_t < t:
E:\leaf\meshtalk\models\context_model.py:85: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if h > 0:
E:\leaf\meshtalk\models\context_model.py:100: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if self.kernel_size > 0 and self.historic_t < t:
E:\leaf\meshtalk\models\context_model.py:103: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if h.shape[-1] < self.receptive_field() - 1:
From what I found online, the cause is that input variables cannot be used in if statements.
How can I solve this problem?
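To make the warning text concrete ("this value will be treated as a constant"), here is a minimal toy tracer in plain Python. It is only a sketch of the general idea and not how PyTorch is implemented; Recorder, trace and model are hypothetical names made up for this illustration.

```python
class Recorder:
    """Toy stand-in for a traced value: carries a number and a shared list of recorded ops."""

    def __init__(self, value, ops):
        self.value = value
        self.ops = ops

    def scale(self, k):
        # Arithmetic is recorded, so it can be replayed later on new inputs.
        self.ops.append(lambda v: v * k)
        return Recorder(self.value * k, self.ops)

    def __gt__(self, threshold):
        # The comparison returns a plain Python bool. Nothing is recorded,
        # so the tracer never learns that a branch depended on the input.
        return self.value > threshold


def trace(fn, example):
    """Run fn once on an example input and return a function that replays the recorded ops."""
    ops = []
    fn(Recorder(example, ops))

    def replay(v):
        for op in ops:
            v = op(v)
        return v

    return replay


def model(x):
    # Data-dependent branch, similar in spirit to `if h > 0:` in the warnings above.
    if x > 0:
        return x.scale(2)
    return x.scale(-1)


traced = trace(model, 3)                    # only the x > 0 branch is recorded
eager_result = model(Recorder(-3, [])).value
print(traced(3), traced(-3), eager_result)  # prints: 6 -6 3
```

The replayed function always multiplies by 2 because the `if` was evaluated once, on the example input, while tracing. Run directly, the same function on -3 takes the other branch and gives 3.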
alexanderrichard commented
It seems that ONNX only supports static compute graphs. I have not worked with ONNX yet, so I don't have a solution for you.
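For readers who hit the same warnings: two directions that are often tried for data-dependent branches are (a) expressing the branch as a tensor operation such as torch.where, so the trace contains both outcomes, and (b) compiling with torch.jit.script, which keeps Python control flow instead of tracing through it. The sketch below shows both on a small made-up function. It is untested against ContextModel, whose branches also depend on internal state such as self.historic_t, so treat it as an assumption-laden illustration rather than a fix; branchy and branchless are hypothetical names.

```python
import torch


def branchy(x: torch.Tensor) -> torch.Tensor:
    # Python `if` on a tensor value: tracing bakes in whichever branch the example input takes.
    if bool(x.sum() > 0):
        return x * 2
    return -x


def branchless(x: torch.Tensor) -> torch.Tensor:
    # Same logic as a tensor op: both outcomes are part of the recorded graph.
    return torch.where(x.sum() > 0, x * 2, -x)


pos = torch.ones(3)
neg = -torch.ones(3)

# Tracing with a positive example emits a TracerWarning for `branchy`.
traced_branchy = torch.jit.trace(branchy, pos)
traced_where = torch.jit.trace(branchless, pos)

# Scripting compiles the function source, so the `if` is preserved.
scripted = torch.jit.script(branchy)

print(traced_branchy(neg))  # follows the branch taken by the example input
print(traced_where(neg))    # selects per call
print(scripted(neg))        # selects per call
```

Note that torch.where evaluates both candidate results on every call, so it only fits branches whose two sides are cheap and valid for all inputs. Branches that change tensor shapes or update module state would need a different restructuring.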