Torch to ONNX conversion is very slow
matt-kh opened this issue · 0 comments
matt-kh commented
When converting the Torch model to ONNX, the conversion ran for more than 8 hours without raising any exception.
```python
import torch

# Assumption: adjust the import path to match your installation of UniverSeg.
from universeg.model import UniverSeg

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Model initialization
encoder_blocks = [64,] * 4
weights_url = "https://github.com/JJGO/UniverSeg/releases/download/weights/universeg_v1_nf64_ss64_STA.pt"
model = UniverSeg(encoder_blocks=encoder_blocks)
state_dict = torch.hub.load_state_dict_from_url(weights_url)
model.load_state_dict(state_dict)
_ = model.to(device)
_ = model.eval()

# Dummy inputs
torch.manual_seed(42)
target_image = torch.randn(1, 1, 256, 256, device=device)
support_images = torch.randn(1, 64, 1, 256, 256, device=device)
support_labels = torch.randn(1, 64, 1, 256, 256, device=device)

# ONNX conversion
export_path = "universeg.onnx"  # placeholder output path
input_names = ["target_image", "support_images", "support_labels"]
output_names = ["logits"]
torch.onnx.export(
    model=model,
    args=(target_image, support_images, support_labels),
    f=export_path,
    input_names=input_names,
    output_names=output_names,
    export_params=True,
    do_constant_folding=True,
    dynamic_axes={
        "target_image": {0: "batch", 1: "channel", 2: "height", 3: "width"},
        "support_images": {0: "batch", 1: "support_size", 2: "channel", 3: "height", 4: "width"},
        "support_labels": {0: "batch", 1: "support_size", 3: "height", 4: "width"},
        "logits": {0: "batch", 2: "height", 3: "width"},
    },
    verbose=True,
    opset_version=16,
)
```
During conversion, the following warning is emitted from the einops package (.../einops/einops.py):

```
TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  unknown: Set[str] = {axis for axis in composite_axis if axis_name2known_length[axis] == _unknown_axis_length}
```
However, no exception is raised by this code. I am not sure whether the tracer warnings coming from einops are related to `torch.onnx.export()` running indefinitely.
I would appreciate any help with this issue, thank you.
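For context on what that warning means (a toy demonstration I put together, not code from einops or UniverSeg): when the tracer hits a Python `if` on a tensor value, it converts the tensor to a Python boolean, emits a `TracerWarning`, and bakes the branch taken during tracing into the graph as a constant. That alone explains incorrect generalization to other inputs, though not by itself an 8-hour hang:

```python
import warnings

import torch
import torch.nn as nn

class DataDependent(nn.Module):
    # Toy module: the Python `if` on a tensor value is the same pattern
    # that triggers the TracerWarning inside einops.
    def forward(self, x):
        if x.sum() > 0:  # tensor -> Python bool: branch is baked in as a constant
            return x * 2
        return x * -1

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Traced with a positive input, so the `x * 2` branch is recorded.
    traced = torch.jit.trace(DataDependent(), torch.ones(3))

tracer_warnings = [w for w in caught if issubclass(w.category, torch.jit.TracerWarning)]
print(f"{len(tracer_warnings)} TracerWarning(s) captured")
# The traced graph now always takes the `x * 2` branch, even for negative input.
print(traced(-torch.ones(3)))
```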