[Reference interpreter] Remove need for canonicalizing away shape patterns
sdasgup3 commented
To evaluate quantized programs, the StableHLO reference interpreter depends on the StablehloLegalizeQuantToInt pass, which introduces CHLO broadcast operations for the scale multiplication/division and the zero-point addition. Legalizing these CHLO operations to StableHLO operations introduces shape operations, which then have to be canonicalized away by a series of canonicalization passes.
We believe that for statically shaped programs we can avoid the CHLO broadcast operations altogether, which would simplify the decomposition pipeline for quantized operations.
For example, the following program
```mlir
func.func @quantized_add() -> tensor<2xf32> {
  %operand1 = stablehlo.constant dense<[1.0, 2.0]> : tensor<2xf32>
  %operand2 = stablehlo.constant dense<[3.0, 4.0]> : tensor<2xf32>
  %q_operand1 = "stablehlo.uniform_quantize"(%operand1) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32, 0.1:-30>>
  %q_operand2 = "stablehlo.uniform_quantize"(%operand2) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32, 0.5:-20>>
  %result = "stablehlo.add"(%q_operand1, %q_operand2) : (tensor<2x!quant.uniform<i8:f32, 0.1:-30>>, tensor<2x!quant.uniform<i8:f32, 0.5:-20>>) -> tensor<2x!quant.uniform<i8:f32, 0.5:-20>>
  %result_f = "stablehlo.uniform_dequantize"(%result) : (tensor<2x!quant.uniform<i8:f32, 0.5:-20>>) -> tensor<2xf32>
  func.return %result_f: tensor<2xf32>
}
```
must go through the following passes to be converted into a pure StableHLO program:
```
--stablehlo-legalize-quant-to-int --chlo-legalize-to-stablehlo --canonicalize --shape-legalize-to-stablehlo --stablehlo-canonicalize-dynamism
```
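As a sanity check on what the example program computes, here is a minimal Python sketch of the arithmetic, element by element. It is not the interpreter's or the pass's code: it assumes the StableHLO specification's dequantize, compute in floating point, quantize semantics for quantized `add`, round-half-to-even, and clamping to the i8 storage range, and it uses Python doubles instead of f32. The names `UniformQuantType`, `uniform_quantize`, `uniform_dequantize`, and `quantized_add` are hypothetical helpers introduced only for illustration.

```python
from dataclasses import dataclass

# i8 storage range of the example's !quant.uniform<i8:f32, ...> types.
I8_MIN, I8_MAX = -128, 127


@dataclass(frozen=True)
class UniformQuantType:
    """Hypothetical stand-in for per-tensor !quant.uniform<i8:f32, scale:zero_point>."""
    scale: float
    zero_point: int


def uniform_quantize(values, qtype):
    # Divide by the scale, round half to even (Python's round), add the
    # zero point, then clamp to the i8 storage range.
    out = []
    for x in values:
        q = round(x / qtype.scale) + qtype.zero_point
        out.append(max(I8_MIN, min(I8_MAX, q)))
    return out


def uniform_dequantize(values, qtype):
    # Subtract the zero point, then multiply by the scale.
    return [(q - qtype.zero_point) * qtype.scale for q in values]


def quantized_add(lhs, lhs_t, rhs, rhs_t, result_t):
    # Assumed reference semantics: dequantize both operands, add in
    # floating point, requantize to the result type.
    lhs_f = uniform_dequantize(lhs, lhs_t)
    rhs_f = uniform_dequantize(rhs, rhs_t)
    return uniform_quantize([a + b for a, b in zip(lhs_f, rhs_f)], result_t)


T1 = UniformQuantType(scale=0.1, zero_point=-30)
T2 = UniformQuantType(scale=0.5, zero_point=-20)

q_operand1 = uniform_quantize([1.0, 2.0], T1)
q_operand2 = uniform_quantize([3.0, 4.0], T2)
result = quantized_add(q_operand1, T1, q_operand2, T2, T2)
result_f = uniform_dequantize(result, T2)
print(q_operand1, q_operand2, result, result_f)
# prints: [-20, -10] [-14, -12] [-12, -8] [4.0, 6.0]
```

Every step applies a single scalar scale and zero point elementwise over a shape that is known statically (here `2`). That is the observation behind the proposal: in the static case, those parameters could plausibly be emitted as constants of the operand's shape rather than going through CHLO broadcasts.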