
[BUG] is generating onnxruntime.capi.onnxruntime_pybind11_state.Fail

lovettchris opened this issue · 1 comment

Describe the bug
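AvgOnnxLatency raises onnxruntime.capi.onnxruntime_pybind11_state.Fail when the exported ONNX model is written to a tempfile.NamedTemporaryFile and then loaded into an ONNX Runtime inference session.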

To Reproduce

```python
from archai.discrete_search.evaluators.onnx_model import AvgOnnxLatency
from archai.discrete_search.search_spaces.config import ArchConfig
from search_space.hgnet import StackedHourglass  # project-local module, not part of the archai package
from archai.discrete_search.api import ArchaiModel

# Build the model from its architecture config and wrap it for Archai
arch_config = ArchConfig.from_file('config.json')
model = StackedHourglass(arch_config, num_classes=18)
archid = "123"
am = ArchaiModel(model, archid)
input_shape = (1, 3, 256, 256)

# ONNX latency evaluator, exporting with opset 11
lat = AvgOnnxLatency(input_shape=input_shape, export_kwargs={'opset_version': 11})
```

Expected behavior
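AvgOnnxLatency should export the model, load it into an ONNX Runtime inference session, and report the latency without raising an error.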



Desktop:

  • OS: Windows 11
  • Virtual Environment: conda
  • Python Version: 3.10

Additional context

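Removing the temp file, writing the model to a file named "model.onnx" instead, and then loading that file in the ONNX inference session works fine. So the problem lies in some interplay between with tempfile.NamedTemporaryFile(delete=False) as tmp_file: and the ONNX inference session. A likely explanation is that the session opens the file by name while the NamedTemporaryFile handle is still open; the Python 3.10 documentation notes that, on Windows, the name cannot be used to open the file a second time while the named temporary file is still open. Perhaps we could just write the ONNX model to the given --output folder, somewhere near the checkpoints?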
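A minimal sketch of one possible workaround, assuming the root cause is the still-open NamedTemporaryFile handle: write the model into a fresh temporary directory instead, so Python holds no open handle when the inference session opens the file by name. The helper name temp_model_path and the default file name are hypothetical, not part of archai. The plain file writes below stand in for torch.onnx.export and InferenceSession so the sketch runs without either library.

```python
import os
import tempfile
from contextlib import contextmanager
from typing import Iterator


@contextmanager
def temp_model_path(filename: str = "model.onnx") -> Iterator[str]:
    """Yield a file path inside a fresh temporary directory (hypothetical helper).

    Python keeps no handle to the file open, so another library can open it
    by name, even on Windows. The directory and everything in it are deleted
    when the with-block exits.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        yield os.path.join(tmp_dir, filename)


if __name__ == "__main__":
    with temp_model_path() as onnx_path:
        # Stand-in for torch.onnx.export(model, sample_input, onnx_path, ...):
        # the writer opens the path itself and closes it when done.
        with open(onnx_path, "wb") as f:
            f.write(b"placeholder model bytes")
        # Stand-in for InferenceSession(onnx_path): a second, independent open
        # by name, made after the writer's handle has been closed.
        with open(onnx_path, "rb") as f:
            print(f.read())
        # Latency measurement would also go here, before cleanup runs.
    print(os.path.exists(onnx_path))  # False: the directory was removed
```

Creating the session and running the timing loop inside the with-block keeps the file alive for as long as it is needed. If anything could still hold the file open at exit, note that Python 3.10 added an ignore_cleanup_errors parameter to tempfile.TemporaryDirectory for that case. Writing to the --output folder, as suggested above, avoids cleanup concerns entirely at the cost of leaving files on disk.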

I posted a minimal repro here; we'll see what they say: