llvm/torch-mlir

Can't lower onnx.Pad with "wrap" mode


Hi all,

I have a use of the onnx.Pad operator with the "wrap" mode, as shown in [0]. To lower it, I run the following command:

torch-mlir-opt --convert-torch-onnx-to-torch --torch-decompose-complex-ops reproducer.mlir

This results in [1]. The problem appears to be that the pad operator gets lowered to torch.aten.constant_pad_nd, which implements the behavior of the "constant" mode. If I then lower this to LLVM and run it on the CPU, I get the wrong kind of padding (constant instead of wrap) and incorrect results.

[0]

  func.func @forward(%arg0: tensor<1x3x4x5xf32> {onnx.name = "x"}, %arg1: tensor<8xi64> {onnx.name = "pads"}) -> (tensor<1x3x6x7xf32> {onnx.name = "y"}) attributes {torch.onnx_meta.opset_version = 19 : si64} {
    %0 = torch_c.from_builtin_tensor %arg0 : tensor<1x3x4x5xf32> -> !torch.vtensor<[1,3,4,5],f32>
    %1 = torch_c.from_builtin_tensor %arg1 : tensor<8xi64> -> !torch.vtensor<[8],si64>
    %2 = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
    %3 = torch_c.from_builtin_tensor %2 : tensor<1xf32> -> !torch.vtensor<[1],f32>
    %none = torch.constant.none
    %4 = torch.operator "onnx.Pad"(%0, %1, %3, %none) {torch.onnx.mode = "wrap", torch.onnx.onnx_node_name = "onnx.Pad_1", torch.onnx_meta.version = 19 : i32} : (!torch.vtensor<[1,3,4,5],f32>, !torch.vtensor<[8],si64>, !torch.vtensor<[1],f32>, !torch.none) -> !torch.vtensor<[1,3,6,7],f32>
    %5 = torch_c.to_builtin_tensor %4 : !torch.vtensor<[1,3,6,7],f32> -> tensor<1x3x6x7xf32>
    return %5 : tensor<1x3x6x7xf32>
  }

[1]

func.func @forward(%arg0: tensor<1x3x4x5xf32> {onnx.name = "x"}, %arg1: tensor<8xi64> {onnx.name = "pads"}) -> (tensor<1x3x6x7xf32> {onnx.name = "y"}) attributes {torch.onnx_meta.opset_version = 19 : si64} {
    %int8 = torch.constant.int 8
    %int7 = torch.constant.int 7
    %int6 = torch.constant.int 6
    %int5 = torch.constant.int 5
    %int4 = torch.constant.int 4
    %int3 = torch.constant.int 3
    %int2 = torch.constant.int 2
    %int1 = torch.constant.int 1
    %int0 = torch.constant.int 0
    %0 = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
    %1 = torch_c.from_builtin_tensor %arg0 : tensor<1x3x4x5xf32> -> !torch.vtensor<[1,3,4,5],f32>
    %2 = torch_c.from_builtin_tensor %arg1 : tensor<8xi64> -> !torch.vtensor<[8],si64>
    %3 = torch_c.from_builtin_tensor %0 : tensor<1xf32> -> !torch.vtensor<[1],f32>
    %4 = torch.aten.item %3 : !torch.vtensor<[1],f32> -> !torch.float
    %5 = torch.aten.slice.Tensor %2, %int0, %int0, %int1, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
    %6 = torch.aten.squeeze.dim %5, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
    %7 = torch.aten.item %6 : !torch.vtensor<[],si64> -> !torch.int
    %8 = torch.aten.slice.Tensor %2, %int0, %int1, %int2, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
    %9 = torch.aten.squeeze.dim %8, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
    %10 = torch.aten.item %9 : !torch.vtensor<[],si64> -> !torch.int
    %11 = torch.aten.slice.Tensor %2, %int0, %int2, %int3, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
    %12 = torch.aten.squeeze.dim %11, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
    %13 = torch.aten.item %12 : !torch.vtensor<[],si64> -> !torch.int
    %14 = torch.aten.slice.Tensor %2, %int0, %int3, %int4, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
    %15 = torch.aten.squeeze.dim %14, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
    %16 = torch.aten.item %15 : !torch.vtensor<[],si64> -> !torch.int
    %17 = torch.aten.slice.Tensor %2, %int0, %int4, %int5, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
    %18 = torch.aten.squeeze.dim %17, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
    %19 = torch.aten.item %18 : !torch.vtensor<[],si64> -> !torch.int
    %20 = torch.aten.slice.Tensor %2, %int0, %int5, %int6, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
    %21 = torch.aten.squeeze.dim %20, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
    %22 = torch.aten.item %21 : !torch.vtensor<[],si64> -> !torch.int
    %23 = torch.aten.slice.Tensor %2, %int0, %int6, %int7, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
    %24 = torch.aten.squeeze.dim %23, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
    %25 = torch.aten.item %24 : !torch.vtensor<[],si64> -> !torch.int
    %26 = torch.aten.slice.Tensor %2, %int0, %int7, %int8, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
    %27 = torch.aten.squeeze.dim %26, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
    %28 = torch.aten.item %27 : !torch.vtensor<[],si64> -> !torch.int
    %29 = torch.prim.ListConstruct %16, %28, %13, %25, %10, %22, %7, %19 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %30 = torch.aten.constant_pad_nd %1, %29, %4 : !torch.vtensor<[1,3,4,5],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[1,3,6,7],f32>
    %31 = torch_c.to_builtin_tensor %30 : !torch.vtensor<[1,3,6,7],f32> -> tensor<1x3x6x7xf32>
    return %31 : tensor<1x3x6x7xf32>
  }
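
For reference, the semantic difference between the two modes can be illustrated with a short PyTorch sketch (not part of the reproducer; the pad amounts of 1 on each side of the last two dims are just one choice that matches the 1x3x4x5 -> 1x3x6x7 shapes above):

```python
import torch
import torch.nn.functional as F

x = torch.arange(1 * 3 * 4 * 5, dtype=torch.float32).reshape(1, 3, 4, 5)

# Pad of 1 on each side of the last two dims (assumed pads matching the
# 1x3x4x5 -> 1x3x6x7 shapes from the reproducer).
pads = (1, 1, 1, 1)

constant = F.pad(x, pads, mode="constant", value=0.0)  # what aten.constant_pad_nd computes
wrap = F.pad(x, pads, mode="circular")                 # ONNX "wrap" semantics (circular padding)

print(constant.shape, wrap.shape)   # both torch.Size([1, 3, 6, 7])
print(torch.equal(constant, wrap))  # False: constant fills with 0, wrap copies from the opposite edge
```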

This seems to be addressed in #3528. Closing the issue.