polynomial-to-standard: error on integer width for poly.add when ring coefficientModulus width is less than coefficient width
When running tests/bgv/to_polynomial.mlir with the additional passes -convert-elementwise-to-affine -polynomial-to-standard, the following error occurs:
error: 'arith.extsi' op operand type 'tensor<1024xi32>' and result type 'tensor<1024xi26>' are cast incompatible
%mul = bgv.mul %x, %y : (!ct1, !ct1) -> !ct2
^
note: see current operation: %19 = "arith.extsi"(%13) : (tensor<1024xi32>) -> tensor<1024xi26>
This reduces to the minimal example below, which reproduces the error with only -polynomial-to-standard. Changing the alias line to !p = !p2 resolves the problem:
#my_poly = #polynomial.int_polynomial<1 + x**1024>
#ring1 = #polynomial.ring<coefficientType = i32, coefficientModulus = 33538049 : i32, polynomialModulus=#my_poly>
#ring2 = #polynomial.ring<coefficientType = i25, coefficientModulus = 33538049 : i25, polynomialModulus=#my_poly>
!p1 = !polynomial.polynomial<ring = #ring1>
!p2 = !polynomial.polynomial<ring = #ring2>
!p = !p1
module {
  func.func @polymul(%x : !p, %y : !p) -> (!p) {
    %add = polynomial.add %x, %y : !p
    return %add : !p
  }
}
The related code is in heir/lib/Conversion/PolynomialToStandard/PolynomialToStandard.cpp, lines 522 to 543 at commit b79bccf.
It assumes that coefficientModulus is either a power of two or has the same bit width as coefficientType.
Maybe we should check that coefficientType has the same width as coefficientModulus when the modulus is a prime.
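The i26 in the error message can be reproduced with a little arithmetic. The sketch below is Python, not code from the pass. It assumes the lowering picks an intermediate width of one bit more than the bit length of coefficientModulus - 1, which is an inference from the error output and has not been checked against PolynomialToStandard.cpp. The helper names are hypothetical.

```python
# Modulus used by both rings in the reproducer.
Q = 33538049


def intermediate_width(modulus: int) -> int:
    """Assumed rule: bit length of (modulus - 1), plus one extra bit."""
    return (modulus - 1).bit_length() + 1


def can_sign_extend(src_width: int, dst_width: int) -> bool:
    """arith.extsi only verifies when the result is strictly wider."""
    return dst_width > src_width


width = intermediate_width(Q)
print(f"intermediate width: i{width}")

# Compare the two coefficientType widths from the reproducer.
for storage_width in (25, 32):
    ok = can_sign_extend(storage_width, width)
    print(f"extsi i{storage_width} -> i{width}: {'valid' if ok else 'invalid'}")
```

With coefficientType = i25 the cast widens from 25 to 26 bits and verifies. With i32 it would have to narrow, which arith.extsi rejects, matching the reported error.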
Also, should we consider lowering it to mod_arith so that polynomial-to-standard can be simplified? Currently extsi/extui ops are everywhere.
cc @AlexanderViand @inbelic on this issue
> Also, should we consider lowering it to mod_arith so that polynomial-to-standard can be simplified?
I think so - but I will need to catch up on this particular pass pipeline.
Thanks for bringing this up! I’ll add it to the list of “issues that show why we really need formal verification around mod_arith stuff”!
I agree that switching this pipeline over to mod_arith would be the best way to make sure that at least these kinds of issues only occur in one place xD
polynomial.mul_scalar also has an erroneous lowering (the test only covers the power-of-two branch).
The reproducer is similar to the one above, and it reports an error whether !p is !p1 or !p2:
#my_poly = #polynomial.int_polynomial<1 + x**1024>
#ring1 = #polynomial.ring<coefficientType = i32, coefficientModulus = 33538049 : i32, polynomialModulus=#my_poly>
#ring2 = #polynomial.ring<coefficientType = i25, coefficientModulus = 33538049 : i25, polynomialModulus=#my_poly>
!p1 = !polynomial.polynomial<ring = #ring1>
!p2 = !polynomial.polynomial<ring = #ring2>
!p = !p1
module {
  func.func @polymul(%x : !p, %y : !p) -> (!p) {
    %add = polynomial.sub %x, %y : !p
    return %add : !p
  }
}
Note that polynomial.sub is canonicalized to a polynomial.mul_scalar by -1 followed by a polynomial.add, and further lowering results in:
error: 'arith.remsi' op requires the same type for all operands and results
%add = polynomial.sub %x, %y : !p
^
note: see current operation: %4 = "arith.remsi"(%arg1, %3) : (tensor<1024xi32>, i32) -> tensor<1024xi32>
The IR emitted at this point (one remsi comes from the add lowering, the other from the mul_scalar lowering):
%4 = "arith.constant"() <{value = -1 : i32}> : () -> i32
%5 = "tensor.splat"(%4) : (i32) -> tensor<1024xi32>
%6 = "arith.muli"(%1, %5) <{overflowFlags = #arith.overflow<none>}> : (tensor<1024xi32>, tensor<1024xi32>) -> tensor<1024xi32>
%7 = "arith.constant"() <{value = 33538049 : i32}> : () -> i32
%8 = "arith.remsi"(%1, %7) : (tensor<1024xi32>, i32) -> tensor<1024xi32>
%10 = "arith.constant"() <{value = dense<33538049> : tensor<1024xi26>}> : () -> tensor<1024xi26>
%11 = "arith.extsi"(%3) : (tensor<1024xi32>) -> tensor<1024xi26>
%12 = "arith.extsi"(%8) : (tensor<1024xi32>) -> tensor<1024xi26>
%13 = "arith.addi"(%11, %12) <{overflowFlags = #arith.overflow<none>}> : (tensor<1024xi26>, tensor<1024xi26>) -> tensor<1024xi26>
%14 = "arith.remsi"(%13, %10) : (tensor<1024xi26>, tensor<1024xi26>) -> tensor<1024xi26>
%15 = "arith.trunci"(%14) : (tensor<1024xi26>) -> tensor<1024xi32>
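For reference, the canonicalization of polynomial.sub into a mul_scalar by -1 followed by an add can be checked coefficient by coefficient. The sketch below is Python and models only the math on short coefficient lists. The function names are hypothetical and it does not reproduce the pass or the emitted IR.

```python
# Modulus from the reproducer.
Q = 33538049


def poly_add(x, y, q=Q):
    """Coefficient-wise addition modulo q."""
    return [(a + b) % q for a, b in zip(x, y)]


def poly_mul_scalar(x, s, q=Q):
    """Multiply every coefficient by the scalar s modulo q."""
    return [(a * s) % q for a in x]


def poly_sub(x, y, q=Q):
    """Coefficient-wise subtraction modulo q."""
    return [(a - b) % q for a, b in zip(x, y)]


# Short coefficient lists stand in for the 1024-coefficient tensors.
x = [5, 0, Q - 1, 12345]
y = [7, 3, 1, 12345]

direct = poly_sub(x, y)
rewritten = poly_add(x, poly_mul_scalar(y, -1))
print(direct == rewritten)
```

Python's % returns a non-negative result for a positive modulus, whereas a signed remainder that truncates toward zero, such as arith.remsi, can return a negative value. The sketch therefore shows that the rewrite is sound modulo q, not that the lowered ops produce the same bit patterns.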