Can't lower onnx.Pad with "wrap" mode
josel-amd commented
Hi all,
I have a version of the onnx.Pad operator that uses the "wrap" mode, shown in [0]. To lower it, I run the following command:
torch-mlir-opt --convert-torch-onnx-to-torch --torch-decompose-complex-ops reproducer.mlir
This produces [1]. The problem appears to be that onnx.Pad is lowered to torch.aten.constant_pad_nd, which implements only the "constant" mode; the "wrap" mode attribute is dropped without any error or warning. If I lower this further to LLVM and run it on the CPU, I get constant padding instead of wrap padding, and therefore wrong results.
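To make the difference concrete, here is a minimal pure-Python sketch of the two modes on a 1-D list. The helper names constant_pad_1d and wrap_pad_1d are hypothetical, and pads are assumed non-negative. Wrap mode indexes the input modulo its length, as if the data formed a ring, while constant mode fills the new positions with a fixed value:

```python
def constant_pad_1d(x, begin, end, value=0):
    """Pad a list with `value` on both sides (constant mode, as constant_pad_nd does per dimension)."""
    return [value] * begin + list(x) + [value] * end


def wrap_pad_1d(x, begin, end):
    """Wrap-around padding (ONNX mode="wrap"): output index i reads x[(i - begin) % len(x)]."""
    n = len(x)
    return [x[(i - begin) % n] for i in range(begin + n + end)]


x = [1, 2, 3, 4]
print(constant_pad_1d(x, 2, 1))  # [0, 0, 1, 2, 3, 4, 0]
print(wrap_pad_1d(x, 2, 1))      # [3, 4, 1, 2, 3, 4, 1]
```

Both calls produce the same output length, which is why the type checker and shape inference cannot catch the mismatch; only the values differ.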
[0]
func.func @forward(%arg0: tensor<1x3x4x5xf32> {onnx.name = "x"}, %arg1: tensor<8xi64> {onnx.name = "pads"}) -> (tensor<1x3x6x7xf32> {onnx.name = "y"}) attributes {torch.onnx_meta.opset_version = 19 : si64} {
  %0 = torch_c.from_builtin_tensor %arg0 : tensor<1x3x4x5xf32> -> !torch.vtensor<[1,3,4,5],f32>
  %1 = torch_c.from_builtin_tensor %arg1 : tensor<8xi64> -> !torch.vtensor<[8],si64>
  %2 = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
  %3 = torch_c.from_builtin_tensor %2 : tensor<1xf32> -> !torch.vtensor<[1],f32>
  %none = torch.constant.none
  %4 = torch.operator "onnx.Pad"(%0, %1, %3, %none) {torch.onnx.mode = "wrap", torch.onnx.onnx_node_name = "onnx.Pad_1", torch.onnx_meta.version = 19 : i32} : (!torch.vtensor<[1,3,4,5],f32>, !torch.vtensor<[8],si64>, !torch.vtensor<[1],f32>, !torch.none) -> !torch.vtensor<[1,3,6,7],f32>
  %5 = torch_c.to_builtin_tensor %4 : !torch.vtensor<[1,3,6,7],f32> -> tensor<1x3x6x7xf32>
  return %5 : tensor<1x3x6x7xf32>
}
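For reference, the values [0] should produce can be sketched in pure Python on nested lists. This assumes ONNX's pads layout [x1_begin, x2_begin, ..., x1_end, x2_end, ...] with non-negative values. Because pads is a runtime input in [0], the concrete value [0, 0, 1, 1, 0, 0, 1, 1] below is a hypothetical choice that yields the 1x3x6x7 result shape; wrap_pad and shape are illustrative helpers, not torch-mlir APIs.

```python
def wrap_pad(x, pads):
    """Wrap-pad a nested-list tensor; `pads` uses the ONNX layout [begins..., ends...]."""
    rank = len(pads) // 2
    return _wrap(x, pads[:rank], pads[rank:])


def _wrap(x, begins, ends):
    # Once every padded dimension has been handled, x is a scalar element.
    if not begins:
        return x
    n, b, e = len(x), begins[0], ends[0]
    # Output slice i along this dimension reads input slice (i - b) mod n (torus semantics).
    return [_wrap(x[(i - b) % n], begins[1:], ends[1:]) for i in range(b + n + e)]


def shape(x):
    """Return the dimensions of a rectangular nested list."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0]
    return dims


# Hypothetical 1x3x4x5 input whose values encode their (c, h, w) position.
x = [[[[c * 20 + h * 5 + w for w in range(5)] for h in range(4)] for c in range(3)]]
pads = [0, 0, 1, 1, 0, 0, 1, 1]  # hypothetical: pad H and W by one element on each side
y = wrap_pad(x, pads)
print(shape(y))        # [1, 3, 6, 7]
print(y[0][0][0])      # [19, 15, 16, 17, 18, 19, 15]
```

The first output row is the input's last row (h = 3), itself wrapped along W; with constant_pad_nd the same row would be all zeros.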
[1]
func.func @forward(%arg0: tensor<1x3x4x5xf32> {onnx.name = "x"}, %arg1: tensor<8xi64> {onnx.name = "pads"}) -> (tensor<1x3x6x7xf32> {onnx.name = "y"}) attributes {torch.onnx_meta.opset_version = 19 : si64} {
  %int8 = torch.constant.int 8
  %int7 = torch.constant.int 7
  %int6 = torch.constant.int 6
  %int5 = torch.constant.int 5
  %int4 = torch.constant.int 4
  %int3 = torch.constant.int 3
  %int2 = torch.constant.int 2
  %int1 = torch.constant.int 1
  %int0 = torch.constant.int 0
  %0 = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
  %1 = torch_c.from_builtin_tensor %arg0 : tensor<1x3x4x5xf32> -> !torch.vtensor<[1,3,4,5],f32>
  %2 = torch_c.from_builtin_tensor %arg1 : tensor<8xi64> -> !torch.vtensor<[8],si64>
  %3 = torch_c.from_builtin_tensor %0 : tensor<1xf32> -> !torch.vtensor<[1],f32>
  %4 = torch.aten.item %3 : !torch.vtensor<[1],f32> -> !torch.float
  %5 = torch.aten.slice.Tensor %2, %int0, %int0, %int1, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
  %6 = torch.aten.squeeze.dim %5, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
  %7 = torch.aten.item %6 : !torch.vtensor<[],si64> -> !torch.int
  %8 = torch.aten.slice.Tensor %2, %int0, %int1, %int2, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
  %9 = torch.aten.squeeze.dim %8, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
  %10 = torch.aten.item %9 : !torch.vtensor<[],si64> -> !torch.int
  %11 = torch.aten.slice.Tensor %2, %int0, %int2, %int3, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
  %12 = torch.aten.squeeze.dim %11, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
  %13 = torch.aten.item %12 : !torch.vtensor<[],si64> -> !torch.int
  %14 = torch.aten.slice.Tensor %2, %int0, %int3, %int4, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
  %15 = torch.aten.squeeze.dim %14, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
  %16 = torch.aten.item %15 : !torch.vtensor<[],si64> -> !torch.int
  %17 = torch.aten.slice.Tensor %2, %int0, %int4, %int5, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
  %18 = torch.aten.squeeze.dim %17, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
  %19 = torch.aten.item %18 : !torch.vtensor<[],si64> -> !torch.int
  %20 = torch.aten.slice.Tensor %2, %int0, %int5, %int6, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
  %21 = torch.aten.squeeze.dim %20, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
  %22 = torch.aten.item %21 : !torch.vtensor<[],si64> -> !torch.int
  %23 = torch.aten.slice.Tensor %2, %int0, %int6, %int7, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
  %24 = torch.aten.squeeze.dim %23, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
  %25 = torch.aten.item %24 : !torch.vtensor<[],si64> -> !torch.int
  %26 = torch.aten.slice.Tensor %2, %int0, %int7, %int8, %int1 : !torch.vtensor<[8],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
  %27 = torch.aten.squeeze.dim %26, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
  %28 = torch.aten.item %27 : !torch.vtensor<[],si64> -> !torch.int
  %29 = torch.prim.ListConstruct %16, %28, %13, %25, %10, %22, %7, %19 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %30 = torch.aten.constant_pad_nd %1, %29, %4 : !torch.vtensor<[1,3,4,5],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[1,3,6,7],f32>
  %31 = torch_c.to_builtin_tensor %30 : !torch.vtensor<[1,3,6,7],f32> -> tensor<1x3x6x7xf32>
  return %31 : tensor<1x3x6x7xf32>
}
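The pads handling in [1] appears correct: the lowering slices out the eight pads values and reorders them from ONNX's [begins..., ends...] layout into torch's (begin, end) pairs, starting from the last dimension (%29 lists pads[3], pads[7], pads[2], pads[6], pads[1], pads[5], pads[0], pads[4]). Only the padding mode is lost. A minimal sketch of that reordering, where onnx_pads_to_torch is a hypothetical name:

```python
def onnx_pads_to_torch(pads):
    """Reorder ONNX pads [b0, ..., b_{r-1}, e0, ..., e_{r-1}] into the torch pad list
    [b_{r-1}, e_{r-1}, ..., b0, e0] (last dimension first)."""
    rank = len(pads) // 2
    out = []
    for d in reversed(range(rank)):
        out += [pads[d], pads[d + rank]]
    return out


# Passing the indices themselves reproduces the permutation seen in %29.
print(onnx_pads_to_torch(list(range(8))))  # [3, 7, 2, 6, 1, 5, 0, 4]
```

In PyTorch terms, ONNX "wrap" corresponds to torch.nn.functional.pad with mode="circular", which takes the pad list in this same order. A fix could therefore keep this reordering and emit a mode-aware pad instead of torch.aten.constant_pad_nd; that is an assumption about the shape of a fix, not a description of existing torch-mlir code.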