google/heir

polynomial-to-standard: error on integer width for poly.add when ring coefficientModulus width is less than coefficient width

Opened this issue · 3 comments

While experimenting with tests/bgv/to_polynomial.mlir using the additional passes -convert-elementwise-to-affine -polynomial-to-standard, the following error occurred:

error: 'arith.extsi' op operand type 'tensor<1024xi32>' and result type 'tensor<1024xi26>' are cast incompatible
    %mul = bgv.mul %x, %y : (!ct1, !ct1) -> !ct2
           ^
note: see current operation: %19 = "arith.extsi"(%13) : (tensor<1024xi32>) -> tensor<1024xi26>

This reduces to a minimal failing example, run with only -polynomial-to-standard, where changing the line !p = !p1 to !p = !p2 resolves the problem:

#my_poly = #polynomial.int_polynomial<1 + x**1024>
#ring1 = #polynomial.ring<coefficientType = i32, coefficientModulus = 33538049 : i32, polynomialModulus=#my_poly>
#ring2 = #polynomial.ring<coefficientType = i25, coefficientModulus = 33538049 : i25, polynomialModulus=#my_poly>

!p1 = !polynomial.polynomial<ring = #ring1>
!p2 = !polynomial.polynomial<ring = #ring2>

!p = !p1

module {
  func.func @polymul(%x : !p, %y : !p) -> (!p) {
    %add = polynomial.add %x, %y : !p
    return %add : !p
  }
}
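
For reference, the reduced example above can be run through HEIR's heir-opt driver; assuming it is saved as reduced.mlir (the filename is just illustrative), an invocation along the lines of

heir-opt -polynomial-to-standard reduced.mlir

reproduces the error.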

The relevant conversion code is the following:

bool needToExtend =
    mod.zextOrTrunc(coeffTypeMod.getBitWidth()).ult(coeffTypeMod);
if (!needToExtend) {
  auto result = b.create<arith::AddIOp>(adaptor.getLhs(), adaptor.getRhs());
  rewriter.replaceOp(op, result);
  return success();
}

// The arithmetic may spill into higher bit width, so start by extending
// all the types to the smallest bit width that can contain them all.
unsigned nextHigherBitWidth = (mod - 1).getActiveBits() + 1;
auto modIntType = rewriter.getIntegerType(nextHigherBitWidth);
auto modIntTensorType = RankedTensorType::get(type.getShape(), modIntType);

auto cmod = b.create<arith::ConstantOp>(DenseIntElementsAttr::get(
    modIntTensorType, {mod.zextOrTrunc(nextHigherBitWidth)}));

auto signExtensionLhs =
    b.create<arith::ExtSIOp>(modIntTensorType, adaptor.getLhs());
auto signExtensionRhs =
    b.create<arith::ExtSIOp>(modIntTensorType, adaptor.getRhs());

This code assumes that either coefficientModulus is a power of two or that the number of bits needed to represent the modulus matches the width of coefficientType. Here the modulus needs fewer bits than the coefficient type: nextHigherBitWidth works out to (33538049 - 1).getActiveBits() + 1 = 26, which is smaller than the 32-bit operand width, so arith.extsi is asked to narrow tensor<1024xi32> to tensor<1024xi26> and fails verification.

Maybe we should check that coefficientType has the same width as coefficientModulus when the modulus is a prime.
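
As a rough illustration, such a check could live in the ring attribute verifier. The sketch below is a standalone helper with illustrative names, not the actual HEIR verifier; it follows the suggestion above, using a power-of-two test as a stand-in for "not a prime":

#include "llvm/ADT/APInt.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Location.h"
#include "mlir/Support/LogicalResult.h"

// Sketch only: require a non-power-of-two coefficientModulus to occupy the
// full width of coefficientType, so the lowering's width computation cannot
// produce a type narrower than the operands.
mlir::LogicalResult checkRingWidths(mlir::Location loc,
                                    mlir::IntegerType coefficientType,
                                    const llvm::APInt &coefficientModulus) {
  unsigned typeWidth = coefficientType.getWidth();
  // Minimal number of bits needed to represent the (unsigned) modulus value.
  unsigned modWidth = coefficientModulus.getActiveBits();
  if (!coefficientModulus.isPowerOf2() && modWidth != typeWidth)
    return mlir::emitError(loc)
           << "coefficientModulus needs " << modWidth
           << " bits but coefficientType is " << typeWidth << " bits wide";
  return mlir::success();
}

With #ring1 above (i32 with modulus 33538049) this would report 25 vs. 32 and reject the ring up front instead of failing later inside the lowering.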

Also, should we consider lowering it to mod_arith so that polynomial-to-standard can be simplified? Currently extsi/extui ops are everywhere.

cc @AlexanderViand @inbelic on this issue

Also, should we consider lowering it to mod_arith so that polynomial-to-standard can be simplified?

I think so - but I will need to catch up on this particular pass pipeline.

Thanks for bringing this up! I’ll add it to the list of “issues that show why we really need formal verification around mod_arith stuff”!

I agree, switching this pipeline over to mod_arith would be the best thing to make sure at least these kinds of issues only occur in one place xD

polynomial.mul_scalar also has an erroneous lowering (the existing test only covers the power-of-two branch).

The following example is similar to the one above, and it reports an error regardless of whether !p1 or !p2 is used:

#my_poly = #polynomial.int_polynomial<1 + x**1024>
#ring1 = #polynomial.ring<coefficientType = i32, coefficientModulus = 33538049 : i32, polynomialModulus=#my_poly>
#ring2 = #polynomial.ring<coefficientType = i25, coefficientModulus = 33538049 : i25, polynomialModulus=#my_poly>

!p1 = !polynomial.polynomial<ring = #ring1>
!p2 = !polynomial.polynomial<ring = #ring2>

!p = !p1

module {
  func.func @polymul(%x : !p, %y : !p) -> (!p) {
    %add = polynomial.sub %x, %y : !p
    return %add : !p
  }
}

Note that polynomial.sub is canonicalized to a polynomial.mul_scalar by -1 followed by a polynomial.add (i.e. x - y becomes x + (-1 * y)), and the further lowering results in:

error: 'arith.remsi' op requires the same type for all operands and results
    %add = polynomial.sub %x, %y : !p
           ^
note: see current operation: %4 = "arith.remsi"(%arg1, %3) : (tensor<1024xi32>, i32) -> tensor<1024xi32>

The emitted IR is the following (one arith.remsi comes from the add lowering, the other from the mul_scalar lowering); the remsi that mixes a tensor<1024xi32> operand with a scalar i32 modulus is the one that fails to verify:

  %4 = "arith.constant"() <{value = -1 : i32}> : () -> i32
  %5 = "tensor.splat"(%4) : (i32) -> tensor<1024xi32>
  %6 = "arith.muli"(%1, %5) <{overflowFlags = #arith.overflow<none>}> : (tensor<1024xi32>, tensor<1024xi32>) -> tensor<1024xi32>
  %7 = "arith.constant"() <{value = 33538049 : i32}> : () -> i32
  %8 = "arith.remsi"(%1, %7) : (tensor<1024xi32>, i32) -> tensor<1024xi32>
  %10 = "arith.constant"() <{value = dense<33538049> : tensor<1024xi26>}> : () -> tensor<1024xi26>
  %11 = "arith.extsi"(%3) : (tensor<1024xi32>) -> tensor<1024xi26>
  %12 = "arith.extsi"(%8) : (tensor<1024xi32>) -> tensor<1024xi26>
  %13 = "arith.addi"(%11, %12) <{overflowFlags = #arith.overflow<none>}> : (tensor<1024xi26>, tensor<1024xi26>) -> tensor<1024xi26>
  %14 = "arith.remsi"(%13, %10) : (tensor<1024xi26>, tensor<1024xi26>) -> tensor<1024xi26>
  %15 = "arith.trunci"(%14) : (tensor<1024xi26>) -> tensor<1024xi32>
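
Presumably the fix for the mul_scalar lowering is to materialize the modulus as a tensor constant of the operand type before the arith.remsi, the way the add lowering excerpt above already does, so that the op sees identical operand and result types. A minimal sketch against the same builder and variables as that excerpt (type, mod, b), with product standing in for the arith.muli result; this is illustrative, not the actual change:

// Sketch only: splat the modulus into a tensor constant matching the operand
// type so arith.remsi gets identical tensor operand and result types,
// mirroring the DenseIntElementsAttr pattern in the add lowering above.
auto cmodTensor = b.create<arith::ConstantOp>(DenseIntElementsAttr::get(
    type, {mod.zextOrTrunc(type.getElementTypeBitWidth())}));
auto remainder = b.create<arith::RemSIOp>(product, cmodTensor);

The same bit-width extension concern as in the add lowering would still apply on top of this.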