llvm/torch-mlir

Failed to convert "onnx.Pad" into torch

josel-amd opened this issue · 3 comments

Hi all,

I'm trying to import a very simple ONNX model that contains only an onnx.Pad operator (see [0]) and then convert it into torch. To do so, I run the commands listed below:

$ PYTHONPATH=tools/torch-mlir/python_packages/torch_mlir python -m torch_mlir.tools.import_onnx model.onnx -o onnx.mlir
$ torch-mlir-opt --convert-torch-onnx-to-torch onnx.mlir -debug

The ONNX model is attached as onnxmodel.zip, and its textual version is shown below. The IR generated by the importer is listed in [1].

<
   ir_version: 3,
   opset_import: ["" : 6],
   producer_name: "pytorch",
   producer_version: "0.3"
>
torch-jit-export (float[1,1,2,4] 0) => (float[1,1,3,9] 1) {
   1 = Pad <mode = "reflect", pads = [0, 0, 0, 2, 0, 0, 1, 3]> (0)
}
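
To make the pads layout concrete: ONNX stores the pads as all begin values followed by all end values, so [0, 0, 0, 2, 0, 0, 1, 3] means begins [0, 0, 0, 2] and ends [0, 0, 1, 3]. The minimal Python sketch below (pad_output_shape is a hypothetical helper, not part of onnx or torch-mlir) checks that this turns the 1x1x2x4 input into the declared 1x1x3x9 output.

```python
def pad_output_shape(input_shape, pads):
    """Output shape of ONNX Pad given the flat pads list.

    ONNX lays pads out as [x1_begin, x2_begin, ..., x1_end, x2_end, ...],
    i.e. all leading pads first, then all trailing pads.
    """
    rank = len(input_shape)
    if len(pads) != 2 * rank:
        raise ValueError(f"expected {2 * rank} pad values, got {len(pads)}")
    begins, ends = pads[:rank], pads[rank:]
    # Each axis grows by its leading plus its trailing pad.
    return [d + b + e for d, b, e in zip(input_shape, begins, ends)]


# Values taken from the model in this issue.
print(pad_output_shape([1, 1, 2, 4], [0, 0, 0, 2, 0, 0, 1, 3]))  # [1, 1, 3, 9]
```

The last axis gains 2 + 3 elements and the third axis gains 0 + 1, which matches the float[1,1,3,9] output in the model.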

The importer works (i.e., it does not fail or crash) and produces the IR in [1]. However, when I try to convert that IR to torch, I get the following output:

//===-------------------------------------------===//
Legalizing operation : 'torch.operator'(0x559fbc3406b0) {
  %1 = "torch.operator"(%arg0) <{name = "onnx.Pad"}> {torch.onnx.mode = "reflect", torch.onnx.pads = [0 : si64, 0 : si64, 0 : si64, 2 : si64, 0 : si64, 0 : si64, 1 : si64, 3 : si64]} : (!torch.vtensor<[1,1,2,4],f32>) -> !torch.vtensor<[1,1,3,9],f32>

  * Fold {
  } -> FAILURE : unable to fold

  * Pattern : 'torch.operator -> ()' {
Trying to match ""
: conversion failed to apply: "onnx.Pad", sinceVersion=1
    ** Failure : no matching versioned converter
"" result 0
  } -> FAILURE : pattern failed to match
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//
onnx.mlir:4:10: error: failed to legalize operation 'torch.operator' that was explicitly marked illegal
    %0 = torch.operator "onnx.Pad"(%arg0) {torch.onnx.mode = "reflect", torch.onnx.pads = [0 : si64, 0 : si64, 0 : si64, 2 : si64, 0 : si64, 0 : si64, 1 : si64, 3 : si64]} : (!torch.vtensor<[1,1,2,4],f32>) -> !torch.vtensor<[1,1,3,9],f32> 
         ^
onnx.mlir:4:10: note: see current operation: %1 = "torch.operator"(%arg0) <{name = "onnx.Pad"}> {torch.onnx.mode = "reflect", torch.onnx.pads = [0 : si64, 0 : si64, 0 : si64, 2 : si64, 0 : si64, 0 : si64, 1 : si64, 3 : si64]} : (!torch.vtensor<[1,1,2,4],f32>) -> !torch.vtensor<[1,1,3,9],f32>

The problem seems to be multifaceted, but what stands out to me is that the conversion of the torch.operator "onnx.Pad" to torch seems to expect the pads as an operand rather than as an attribute of the operator. At first sight, there appears to be a mismatch between what the importer produces and what the --convert-torch-onnx-to-torch pass expects as input.
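
In concrete terms, the opset-6 form carries pads (and, in constant mode, the fill value) as attributes, which the importer emits as torch.onnx.pads and torch.onnx.mode, whereas from opset 11 on pads and constant_value are inputs and only mode stays an attribute. The sketch below illustrates, in Python rather than torch-mlir's C++ pattern API, how the older attributes map onto the newer operands. PadOperands and legacy_pad_to_operands are hypothetical names, and the bare attribute keys (without the torch.onnx. prefix) are an assumption for readability.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class PadOperands:
    """Operands in the opset >= 11 style: pads and constant_value are inputs."""
    pads: List[int]
    constant_value: Optional[float]
    mode: str


def legacy_pad_to_operands(attrs: dict) -> PadOperands:
    """Hypothetical helper: map attribute-style Pad (opsets 2-10) to operands.

    Before opset 11, `pads` and `value` were attributes; from opset 11 they
    are inputs, while `mode` remains an attribute.
    """
    if "pads" not in attrs:
        raise KeyError("attribute-style Pad requires a 'pads' attribute")
    mode = attrs.get("mode", "constant")
    # The fill value only matters in constant mode; ONNX defaults it to 0.0.
    value = float(attrs.get("value", 0.0)) if mode == "constant" else None
    return PadOperands(pads=list(attrs["pads"]), constant_value=value, mode=mode)


# Attributes as they appear on the op in this issue.
print(legacy_pad_to_operands({"mode": "reflect", "pads": [0, 0, 0, 2, 0, 0, 1, 3]}))
# PadOperands(pads=[0, 0, 0, 2, 0, 0, 1, 3], constant_value=None, mode='reflect')
```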

[0]

(image not reproduced)

[1]

module {
  func.func @"torch-jit-export"(%arg0: !torch.vtensor<[1,1,2,4],f32>) -> !torch.vtensor<[1,1,3,9],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "0.3"} {
    %none = torch.constant.none
    %0 = torch.operator "onnx.Pad"(%arg0) {torch.onnx.mode = "reflect", torch.onnx.pads = [0 : si64, 0 : si64, 0 : si64, 2 : si64, 0 : si64, 0 : si64, 1 : si64, 3 : si64]} : (!torch.vtensor<[1,1,2,4],f32>) -> !torch.vtensor<[1,1,3,9],f32> 
    return %0 : !torch.vtensor<[1,1,3,9],f32>
  }
}

@josel-amd It seems that the versioning for this conversion is a bit too liberal: the pads attribute became an input tensor as of opset version 11, which is the signature our pattern seems to be designed for.

Would you like us to support the older operator signature for onnx.Pad? If not, I'll simply restrict the versioning to opset_version>=11 and recommend that you try updating the opset version of your ONNX model to see whether that gives you valid IR.
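
The two options differ in which converter an opset-6 model reaches. The sketch below models versioned selection the way ONNX resolves an operator schema (pick the registration with the largest since_version not exceeding the model's opset); torch-mlir's own pattern ordering may differ. The registry, the converter names, and the since_version of 2 for the legacy attribute form are illustrative assumptions.

```python
from typing import Dict, Optional

# Hypothetical registry: since_version -> converter name. A converter
# registered at since_version N applies to every opset >= N until a
# registration with a larger since_version takes over.
Registry = Dict[int, str]


def select_converter(registry: Registry, opset: int) -> Optional[str]:
    """Return the converter with the largest since_version <= opset, if any."""
    eligible = [version for version in registry if version <= opset]
    if not eligible:
        return None
    return registry[max(eligible)]


# Current behaviour (log shows sinceVersion=1): opset 6 selects the
# input-style converter, which then fails on the attribute-style op.
too_liberal = {1: "pads_as_input"}
# Proposed restriction: opset 6 no longer matches anything.
restricted = {11: "pads_as_input"}
# Restriction plus a legacy converter for the attribute form.
with_legacy = {2: "pads_as_attribute", 11: "pads_as_input"}

for name, registry in [("too_liberal", too_liberal),
                       ("restricted", restricted),
                       ("with_legacy", with_legacy)]:
    print(name, select_converter(registry, 6), select_converter(registry, 11))
```

Restricting alone turns a confusing mismatch into a clean "no converter" result for opset 6; only the third registry actually lowers the model in this issue.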

Hi @zjgarvey! It would be great if we could also support opset v6, assuming the lowering is trivial. The issues I've been raising are indeed not the most common use cases, which is exactly what I'm going for. Unfortunately, I also don't control the ONNX model, so I can't modify it. Let me know if this is somehow prohibitive.
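
Whatever form opset-6 support takes, its result for this model should match ONNX reflect semantics: values are mirrored around the edge element without repeating it. The pure-Python reference below (reflect_1d, reflect_pad and the helpers are hypothetical, written only to check a lowering by hand) pads a nested list that way. It supports only pads smaller than each axis length, which holds here (2 and 3 on an axis of 4, 1 on an axis of 2).

```python
def reflect_1d(seq, before, after):
    """Reflect-pad one axis (a Python list) without repeating the edge value."""
    n = len(seq)
    if before > n - 1 or after > n - 1:
        # This sketch only supports a single reflection per side.
        raise ValueError("reflect pads must be smaller than the axis length")
    # seq[before:0:-1] mirrors the leading elements, excluding seq[0];
    # seq[-2:-2 - after:-1] mirrors the trailing ones, excluding seq[-1].
    return seq[before:0:-1] + seq + seq[-2:-2 - after:-1]


def _rank(data):
    """Nesting depth of a non-empty, rectangular nested list."""
    rank = 0
    while isinstance(data, list):
        rank += 1
        data = data[0]
    return rank


def _pad_axes(data, begins, ends):
    # Pad the inner axes first, then the current (outermost) one.
    if len(begins) > 1:
        data = [_pad_axes(sub, begins[1:], ends[1:]) for sub in data]
    return reflect_1d(data, begins[0], ends[0])


def reflect_pad(data, pads):
    """ONNX-style reflect Pad; pads = [x1_begin, ..., xn_begin, x1_end, ..., xn_end]."""
    rank = _rank(data)
    if len(pads) != 2 * rank:
        raise ValueError(f"expected {2 * rank} pad values, got {len(pads)}")
    return _pad_axes(data, pads[:rank], pads[rank:])


x = [[[[1, 2, 3, 4], [5, 6, 7, 8]]]]  # shape [1, 1, 2, 4]
print(reflect_pad(x, [0, 0, 0, 2, 0, 0, 1, 3]))
# [[[[3, 2, 1, 2, 3, 4, 3, 2, 1], [7, 6, 5, 6, 7, 8, 7, 6, 5], [3, 2, 1, 2, 3, 4, 3, 2, 1]]]]
```

With NumPy available, np.pad(x, ((0, 0), (0, 0), (0, 1), (2, 3)), mode="reflect") should give the same values; the per-axis (begin, end) tuples are the two halves of pads regrouped.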