llvm/torch-mlir

Failed to lower batch normalization (in training mode)

josel-amd opened this issue · 0 comments

Hi all,

I tried to lower a version of 'torch.operator "onnx.BatchNormalization"' that has torch.onnx.training_mode = 1 : si64, but the conversion failed. This is not very important to us, but I still wanted to raise awareness of it! :)

Command used to reproduce

torch-mlir-opt --convert-torch-onnx-to-torch batch_normalization.mlir -debug

Observed result

//===-------------------------------------------===//
Legalizing operation : 'torch.operator'(0x556505c2ab00) {
  %5:3 = "torch.operator"(%0, %1, %2, %3, %4) <{name = "onnx.BatchNormalization"}> {torch.onnx.epsilon = 0.00999999977 : f32, torch.onnx.momentum = 0.899999976 : f32, torch.onnx.onnx_node_name = "onnx.BatchNormalization_0", torch.onnx.training_mode = 1 : si64} : (!torch.vtensor<[2,3,4,5],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>) -> (!torch.vtensor<[2,3,4,5],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>)

  * Fold {
  } -> FAILURE : unable to fold

  * Pattern : 'torch.operator -> ()' {
Trying to match ""
    ** Failure : unsupported conversion: training = true
: conversion failed to apply: "onnx.BatchNormalization", sinceVersion=15
    ** Failure : no matching versioned converter
"" result 0
  } -> FAILURE : pattern failed to match
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//
batch_normalization.mlir:7:10: error: failed to legalize operation 'torch.operator' that was explicitly marked illegal
  %5:3 = torch.operator "onnx.BatchNormalization"(%0, %1, %2, %3, %4) {torch.onnx.epsilon = 0.00999999977 : f32, torch.onnx.momentum = 0.899999976 : f32, torch.onnx.onnx_node_name = "onnx.BatchNormalization_0", torch.onnx.training_mode = 1 : si64} : (!torch.vtensor<[2,3,4,5],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>) -> (!torch.vtensor<[2,3,4,5],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>) 
         ^
batch_normalization.mlir:7:10: note: see current operation: %5:3 = "torch.operator"(%0, %1, %2, %3, %4) <{name = "onnx.BatchNormalization"}> {torch.onnx.epsilon = 0.00999999977 : f32, torch.onnx.momentum = 0.899999976 : f32, torch.onnx.onnx_node_name = "onnx.BatchNormalization_0", torch.onnx.training_mode = 1 : si64} : (!torch.vtensor<[2,3,4,5],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>) -> (!torch.vtensor<[2,3,4,5],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>)

Reproducer:

func.func @forward(%arg0: tensor<2x3x4x5xf32> {onnx.name = "x"}, %arg1: tensor<3xf32> {onnx.name = "s"}, %arg2: tensor<3xf32> {onnx.name = "bias"}, %arg3: tensor<3xf32> {onnx.name = "mean"}, %arg4: tensor<3xf32> {onnx.name = "var"}) -> (tensor<2x3x4x5xf32> {onnx.name = "y"}, tensor<3xf32> {onnx.name = "output_mean"}, tensor<3xf32> {onnx.name = "output_var"}) attributes {torch.onnx_meta.opset_version = 19 : si64} { 
  %0 = torch_c.from_builtin_tensor %arg0 : tensor<2x3x4x5xf32> -> !torch.vtensor<[2,3,4,5],f32>
  %1 = torch_c.from_builtin_tensor %arg1 : tensor<3xf32> -> !torch.vtensor<[3],f32>
  %2 = torch_c.from_builtin_tensor %arg2 : tensor<3xf32> -> !torch.vtensor<[3],f32>
  %3 = torch_c.from_builtin_tensor %arg3 : tensor<3xf32> -> !torch.vtensor<[3],f32>
  %4 = torch_c.from_builtin_tensor %arg4 : tensor<3xf32> -> !torch.vtensor<[3],f32>
  %5:3 = torch.operator "onnx.BatchNormalization"(%0, %1, %2, %3, %4) {torch.onnx.epsilon = 0.00999999977 : f32, torch.onnx.momentum = 0.899999976 : f32, torch.onnx.onnx_node_name = "onnx.BatchNormalization_0", torch.onnx.training_mode = 1 : si64} : (!torch.vtensor<[2,3,4,5],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>) -> (!torch.vtensor<[2,3,4,5],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>) 
  %6 = torch_c.to_builtin_tensor %5#0 : !torch.vtensor<[2,3,4,5],f32> -> tensor<2x3x4x5xf32>
  %7 = torch_c.to_builtin_tensor %5#1 : !torch.vtensor<[3],f32> -> tensor<3xf32>
  %8 = torch_c.to_builtin_tensor %5#2 : !torch.vtensor<[3],f32> -> tensor<3xf32>
  return %6, %7, %8 : tensor<2x3x4x5xf32>, tensor<3xf32>, tensor<3xf32>
}
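For context, the extra work in training mode is that the op normalizes with the *batch* statistics and must also produce the two updated running-statistic outputs, which is presumably what the converter does not yet emit. A rough numpy sketch of my reading of the opset-15 ONNX semantics (function name and NCHW layout assumption are mine, not from torch-mlir):

```python
import numpy as np

def batchnorm_training(x, scale, bias, input_mean, input_var,
                       epsilon=0.01, momentum=0.9):
    """Sketch of ONNX BatchNormalization with training_mode = 1 (opset >= 14),
    per my reading of the ONNX operator spec. x is NCHW; stats are per-channel."""
    axes = (0, 2, 3)                  # reduce over all dims except channel
    current_mean = x.mean(axis=axes)  # batch statistics
    current_var = x.var(axis=axes)    # biased (population) variance
    # Normalize with the batch statistics, not the running ones.
    shape = (1, -1, 1, 1)
    y = (x - current_mean.reshape(shape)) / np.sqrt(
        current_var.reshape(shape) + epsilon)
    y = y * scale.reshape(shape) + bias.reshape(shape)
    # The updated running statistics are the two extra outputs
    # (output_mean, output_var in the reproducer above).
    running_mean = input_mean * momentum + current_mean * (1.0 - momentum)
    running_var = input_var * momentum + current_var * (1.0 - momentum)
    return y, running_mean, running_var
```

With training_mode = 0 the spec instead normalizes with input_mean/input_var directly and returns only y, which is the case the converter already handles.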