Failed to lower batch normalization (in training mode)
josel-amd opened this issue · 0 comments
josel-amd commented
Hi all,
I tried to lower a version of 'torch.operator "onnx.BatchNormalization"' that has torch.onnx.training_mode = 1 : si64, but the conversion failed. This is not critical for us, but I still wanted to raise awareness of it! :)
Command used to reproduce
torch-mlir-opt --convert-torch-onnx-to-torch batch_normalization.mlir -debug
Observed result
//===-------------------------------------------===//
Legalizing operation : 'torch.operator'(0x556505c2ab00) {
%5:3 = "torch.operator"(%0, %1, %2, %3, %4) <{name = "onnx.BatchNormalization"}> {torch.onnx.epsilon = 0.00999999977 : f32, torch.onnx.momentum = 0.899999976 : f32, torch.onnx.onnx_node_name = "onnx.BatchNormalization_0", torch.onnx.training_mode = 1 : si64} : (!torch.vtensor<[2,3,4,5],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>) -> (!torch.vtensor<[2,3,4,5],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>)
* Fold {
} -> FAILURE : unable to fold
* Pattern : 'torch.operator -> ()' {
Trying to match ""
** Failure : unsupported conversion: training = true
: conversion failed to apply: "onnx.BatchNormalization", sinceVersion=15
** Failure : no matching versioned converter
"" result 0
} -> FAILURE : pattern failed to match
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//
batch_normalization.mlir:7:10: error: failed to legalize operation 'torch.operator' that was explicitly marked illegal
%5:3 = torch.operator "onnx.BatchNormalization"(%0, %1, %2, %3, %4) {torch.onnx.epsilon = 0.00999999977 : f32, torch.onnx.momentum = 0.899999976 : f32, torch.onnx.onnx_node_name = "onnx.BatchNormalization_0", torch.onnx.training_mode = 1 : si64} : (!torch.vtensor<[2,3,4,5],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>) -> (!torch.vtensor<[2,3,4,5],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>)
^
batch_normalization.mlir:7:10: note: see current operation: %5:3 = "torch.operator"(%0, %1, %2, %3, %4) <{name = "onnx.BatchNormalization"}> {torch.onnx.epsilon = 0.00999999977 : f32, torch.onnx.momentum = 0.899999976 : f32, torch.onnx.onnx_node_name = "onnx.BatchNormalization_0", torch.onnx.training_mode = 1 : si64} : (!torch.vtensor<[2,3,4,5],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>) -> (!torch.vtensor<[2,3,4,5],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>)
Reproducer
func.func @forward(%arg0: tensor<2x3x4x5xf32> {onnx.name = "x"}, %arg1: tensor<3xf32> {onnx.name = "s"}, %arg2: tensor<3xf32> {onnx.name = "bias"}, %arg3: tensor<3xf32> {onnx.name = "mean"}, %arg4: tensor<3xf32> {onnx.name = "var"}) -> (tensor<2x3x4x5xf32> {onnx.name = "y"}, tensor<3xf32> {onnx.name = "output_mean"}, tensor<3xf32> {onnx.name = "output_var"}) attributes {torch.onnx_meta.opset_version = 19 : si64} {
%0 = torch_c.from_builtin_tensor %arg0 : tensor<2x3x4x5xf32> -> !torch.vtensor<[2,3,4,5],f32>
%1 = torch_c.from_builtin_tensor %arg1 : tensor<3xf32> -> !torch.vtensor<[3],f32>
%2 = torch_c.from_builtin_tensor %arg2 : tensor<3xf32> -> !torch.vtensor<[3],f32>
%3 = torch_c.from_builtin_tensor %arg3 : tensor<3xf32> -> !torch.vtensor<[3],f32>
%4 = torch_c.from_builtin_tensor %arg4 : tensor<3xf32> -> !torch.vtensor<[3],f32>
%5:3 = torch.operator "onnx.BatchNormalization"(%0, %1, %2, %3, %4) {torch.onnx.epsilon = 0.00999999977 : f32, torch.onnx.momentum = 0.899999976 : f32, torch.onnx.onnx_node_name = "onnx.BatchNormalization_0", torch.onnx.training_mode = 1 : si64} : (!torch.vtensor<[2,3,4,5],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>) -> (!torch.vtensor<[2,3,4,5],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>)
%6 = torch_c.to_builtin_tensor %5#0 : !torch.vtensor<[2,3,4,5],f32> -> tensor<2x3x4x5xf32>
%7 = torch_c.to_builtin_tensor %5#1 : !torch.vtensor<[3],f32> -> tensor<3xf32>
%8 = torch_c.to_builtin_tensor %5#2 : !torch.vtensor<[3],f32> -> tensor<3xf32>
return %6, %7, %8 : tensor<2x3x4x5xf32>, tensor<3xf32>, tensor<3xf32>
}
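For context, the semantics a training-mode lowering would need to produce can be sketched in NumPy. This follows my reading of the ONNX BatchNormalization spec (opset 15): batch statistics are reduced over all axes except the channel axis, and the running mean/variance outputs are updated with momentum. The function name and random inputs below are illustrative only; the shapes and attribute values mirror the reproducer above.

```python
import numpy as np

def batch_norm_training(x, scale, bias, input_mean, input_var,
                        epsilon=0.01, momentum=0.9):
    # Reduce over every axis except the channel axis (axis 1 for NCHW).
    axes = (0,) + tuple(range(2, x.ndim))
    current_mean = x.mean(axis=axes)
    current_var = x.var(axis=axes)  # biased (population) variance

    # Broadcast per-channel statistics back over the NCHW input.
    shape = (1, -1) + (1,) * (x.ndim - 2)
    y = (x - current_mean.reshape(shape)) / np.sqrt(
        current_var.reshape(shape) + epsilon)
    y = y * scale.reshape(shape) + bias.reshape(shape)

    # Running-stat outputs: running = input * momentum + current * (1 - momentum).
    running_mean = input_mean * momentum + current_mean * (1 - momentum)
    running_var = input_var * momentum + current_var * (1 - momentum)
    return y, running_mean, running_var

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 4, 5)).astype(np.float32)
s = np.ones(3, np.float32)      # scale
b = np.zeros(3, np.float32)     # bias
mean = np.zeros(3, np.float32)  # input running mean
var = np.ones(3, np.float32)    # input running variance
y, out_mean, out_var = batch_norm_training(x, s, b, mean, var)
print(y.shape, out_mean.shape, out_var.shape)  # (2, 3, 4, 5) (3,) (3,)
```

With training_mode = 1 the op has the three results seen in the reproducer (y plus the updated running mean and variance), which is what distinguishes it from the inference-mode form the converter currently handles.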