openxla/stablehlo

[Reference interpreter] Remove need for canonicalizing away shape patterns


To evaluate quantized programs, the StableHLO reference interpreter depends on the StablehloLegalizeQuantToInt pass, which introduces CHLO broadcast operations for the scale multiplication/division and the zero-point addition. Legalizing these CHLO operations to StableHLO operations introduces shape operations, which then have to be canonicalized away by a series of canonicalization passes.
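To make the source of those broadcasts concrete, here is a minimal pure-Python model of requantizing a per-tensor quantized i8 tensor from one scale/zero point to another. It is an illustration of the arithmetic, not the actual lowering performed by StablehloLegalizeQuantToInt; the names `QuantParams`, `requantize` and `clamp_i8` are hypothetical, and rounding uses Python's `round` (round half to even) as an assumption. The point is that every step combines the whole tensor with a single scalar, which in tensor IR becomes an elementwise op whose scalar operand must first be broadcast to the tensor's shape.

```python
from dataclasses import dataclass
from typing import List

# Storage range of the i8 type used in the example program.
I8_MIN, I8_MAX = -128, 127


@dataclass(frozen=True)
class QuantParams:
    """Hypothetical per-tensor parameters: real = (stored - zero_point) * scale."""
    scale: float
    zero_point: int


def clamp_i8(v: int) -> int:
    # Saturate to the i8 storage range.
    return max(I8_MIN, min(I8_MAX, v))


def requantize(stored: List[int], src: QuantParams, dst: QuantParams) -> List[int]:
    """Re-express values stored in the `src` quantized type in the `dst` type.

    Each element is combined with the same scalars (src.zero_point, the
    multiplier src.scale / dst.scale, and dst.zero_point); in tensor IR each
    of these is a tensor-with-scalar elementwise op.
    """
    multiplier = src.scale / dst.scale
    out = []
    for q in stored:
        centered = q - src.zero_point                   # zero-point subtraction
        scaled = round(centered * multiplier)            # scale multiplication
        out.append(clamp_i8(scaled + dst.zero_point))    # zero-point addition
    return out


# Bring %q_operand1 (scale 0.1, zero point -30) into the type (scale 0.5, zero point -20).
print(requantize([-20, -10], QuantParams(0.1, -30), QuantParams(0.5, -20)))  # [-18, -16]
```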

We believe that for statically shaped programs we can avoid the CHLO broadcast operations altogether, which would simplify the decomposition pipeline for quantized operations.
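A sketch of why static shapes remove the need for broadcasts, assuming the rewrite knows the operand's full static shape: the scalar can be materialized directly as a splat `stablehlo.constant` of that shape and combined with a plain `stablehlo.multiply`, so neither a CHLO broadcast op nor any shape computation is emitted. The helpers `tensor_type` and `scalar_multiply` are hypothetical string-emitting illustrations (not StableHLO APIs), and representing a dynamic dimension as `None` is likewise an assumption of this sketch.

```python
from typing import List, Optional, Sequence


def tensor_type(shape: Sequence[Optional[int]], elem: str) -> Optional[str]:
    """MLIR ranked tensor type for `shape`, or None if any dim is dynamic (None)."""
    if any(d is None for d in shape):
        return None
    return "tensor<" + "".join(f"{d}x" for d in shape) + elem + ">"


def scalar_multiply(x: str, scalar: float, shape: Sequence[Optional[int]],
                    elem: str = "f32") -> Optional[List[str]]:
    """Emit `x * scalar` as StableHLO text when the shape of `x` is fully static."""
    ty = tensor_type(shape, elem)
    if ty is None:
        # The broadcast target shape is only known at run time; this is the
        # case that needs chlo.broadcast_multiply (and later shape computations).
        return None
    return [
        # The splat constant already has the operand's shape: no broadcast op needed.
        f"%scalar = stablehlo.constant dense<{scalar!r}> : {ty}",
        f"%scaled = stablehlo.multiply {x}, %scalar : {ty}",
    ]


print("\n".join(scalar_multiply("%x", 0.5, [2])))
# %scalar = stablehlo.constant dense<0.5> : tensor<2xf32>
# %scaled = stablehlo.multiply %x, %scalar : tensor<2xf32>
print(scalar_multiply("%x", 0.5, [None]))  # None
```

For a scalar that is not a compile-time constant, `stablehlo.broadcast_in_dim` with the same static result type could play the role of the splat constant.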

For example, with the current pipeline the following program

```mlir
func.func @quantized_add() -> tensor<2xf32> {
  %operand1 = stablehlo.constant dense<[1.0, 2.0]> : tensor<2xf32>
  %operand2 = stablehlo.constant dense<[3.0, 4.0]> : tensor<2xf32>
  %q_operand1 = "stablehlo.uniform_quantize"(%operand1) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32, 0.1:-30>>
  %q_operand2 = "stablehlo.uniform_quantize"(%operand2) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32, 0.5:-20>>
  %result = "stablehlo.add"(%q_operand1, %q_operand2) : (tensor<2x!quant.uniform<i8:f32, 0.1:-30>>, tensor<2x!quant.uniform<i8:f32, 0.5:-20>>) -> tensor<2x!quant.uniform<i8:f32, 0.5:-20>>
  %result_f = "stablehlo.uniform_dequantize"(%result) : (tensor<2x!quant.uniform<i8:f32, 0.5:-20>>) -> tensor<2xf32>
  func.return %result_f : tensor<2xf32>
}
```

needs to go through the following passes to be converted into a pure StableHLO program:

```
--stablehlo-legalize-quant-to-int --chlo-legalize-to-stablehlo --canonicalize --shape-legalize-to-stablehlo --stablehlo-canonicalize-dynamism
```
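As a reference point for checking whatever pipeline is used, the following pure-Python sketch computes what `@quantized_add` should return, assuming dequantize-add-requantize semantics for the quantized `stablehlo.add`, i8 saturation, and Python's `round` (round half to even). The helper names and the `(scale, zero_point)` tuple representation are hypothetical, and the printed values come from this model, not from running the interpreter.

```python
from typing import List, Tuple

I8_MIN, I8_MAX = -128, 127
Params = Tuple[float, int]  # (scale, zero_point)


def quantize(xs: List[float], p: Params) -> List[int]:
    # stored = clamp(round(x / scale) + zero_point)
    scale, zp = p
    return [max(I8_MIN, min(I8_MAX, round(x / scale) + zp)) for x in xs]


def dequantize(qs: List[int], p: Params) -> List[float]:
    # real = (stored - zero_point) * scale
    scale, zp = p
    return [(q - zp) * scale for q in qs]


def quantized_add(lhs: List[int], lhs_p: Params, rhs: List[int], rhs_p: Params,
                  out_p: Params) -> List[int]:
    # Dequantize both operands, add in f32, requantize to the result type.
    real = [a + b for a, b in zip(dequantize(lhs, lhs_p), dequantize(rhs, rhs_p))]
    return quantize(real, out_p)


P1, P2 = (0.1, -30), (0.5, -20)   # parameters of %q_operand1 and %q_operand2 / %result
q1 = quantize([1.0, 2.0], P1)     # [-20, -10]
q2 = quantize([3.0, 4.0], P2)     # [-14, -12]
result = quantized_add(q1, P1, q2, P2, P2)
print(result, dequantize(result, P2))  # [-12, -8] [4.0, 6.0]
```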