Incorrect multiplication result bitwidth
jcasas00 opened this issue · 2 comments
jcasas00 commented
import heterocl as hcl
import numpy as np

def test_mul_bits():
    hcl.init()
    rshape = (1,)
    _a = 1<<16
    _b = 1<<16
    def kernel():
        r = hcl.compute(rshape, lambda _:0, dtype=hcl.UInt(32))
        a = hcl.scalar(_a, "a", dtype=hcl.UInt(32)).v
        b = hcl.scalar(_b, "b", dtype=hcl.UInt(32)).v
        # a*b needs 64 bits before the shift; the result is stored as UInt(32)
        tmp = hcl.scalar((a*b)>>16, "foo", dtype=hcl.UInt(32))
        r[0] = tmp.v
        return r
    #
    s = hcl.create_schedule([], kernel)
    print(hcl.lower(s))
    hcl_res = hcl.asarray(np.zeros(rshape, dtype=np.uint32), dtype=hcl.UInt(32))
    f = hcl.build(s)
    f(hcl_res)
    np_res = hcl_res.asnumpy()
    golden = np.asarray([(_a*_b)>>16], dtype=np.uint32)
    assert np.array_equal(golden, np_res), f"golden {golden} != np_res {np_res}"
The above code fails, apparently because the intermediate result of (a*b) does not have sufficient bitwidth. The IR looks like:
%2 = affine.load %1[0] {from = "a", unsigned} : memref<1xi32>
%3 = memref.alloc() {name = "b", unsigned} : memref<1xi32>
affine.store %c65536_i32, %3[0] {to = "b", unsigned} : memref<1xi32>
%4 = affine.load %3[0] {from = "b", unsigned} : memref<1xi32>
%5 = arith.muli %2, %4 {unsigned} : i32
The result of the multiplication should be 64 bits.
In comparison, the main branch version generates the following IR:
foo[x] = uint32(shift_right((uint64(a[0])*uint64(b[0])), 16))
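The arithmetic behind the failure can be reproduced outside HeteroCL. The following is a minimal NumPy sketch, not HeteroCL code: it assumes only that a 32-bit unsigned multiply wraps modulo 2**32, which is how `arith.muli` on `i32` behaves, and compares that against widening both operands to 64 bits first, as the main branch IR does.

```python
import numpy as np

# Same operand values as the test case, held in 32-bit unsigned arrays.
a = np.array([1 << 16], dtype=np.uint32)
b = np.array([1 << 16], dtype=np.uint32)

# Multiply at 32 bits: 2**16 * 2**16 = 2**32 wraps to 0 before the shift.
narrow = (a * b) >> np.uint32(16)

# Widen both operands to 64 bits, multiply, shift, then truncate to 32 bits.
wide = ((a.astype(np.uint64) * b.astype(np.uint64)) >> np.uint64(16)).astype(np.uint32)

print("32-bit intermediate:", narrow)
print("64-bit intermediate:", wide)
```

With a 32-bit intermediate the product is lost before the shift is applied, so the shifted value no longer matches `golden`; with a 64-bit intermediate it does.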
chhzh123 commented
Auto-casting based on the result bitwidth may complicate the code, lead to excessive resource usage, and cause many issues when a SelectOp is involved, so we are still discussing whether to support this feature at the IR level or to handle overflow directly at runtime.
zzzDavid commented
This issue was resolved once the ast branch was merged. The new frontend introduces a type inference engine with extensible typing rules, and the test case above now produces the correct result type.
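To illustrate what an extensible typing rule could look like, here is a purely hypothetical sketch. It is not the HeteroCL implementation: the names `RULES`, `rule`, and `infer_width` are invented for this example, and it covers unsigned integer operands only. The `mul` rule uses the standard fact that the product of a `wa`-bit and a `wb`-bit unsigned value fits in `wa + wb` bits, which matches the `uint64` intermediate shown for the main branch.

```python
from typing import Callable, Dict

# Hypothetical rule table: operation name -> function mapping the two
# unsigned operand bitwidths to the result bitwidth.
RULES: Dict[str, Callable[[int, int], int]] = {}

def rule(op: str):
    """Decorator that registers a width rule for the operation `op`."""
    def register(fn: Callable[[int, int], int]) -> Callable[[int, int], int]:
        RULES[op] = fn
        return fn
    return register

@rule("mul")
def _mul_width(wa: int, wb: int) -> int:
    # The product of two unsigned values needs wa + wb bits.
    return wa + wb

@rule("add")
def _add_width(wa: int, wb: int) -> int:
    # The sum of two unsigned values needs one carry bit beyond the wider operand.
    return max(wa, wb) + 1

def infer_width(op: str, wa: int, wb: int) -> int:
    """Look up the registered rule for `op` and apply it to the operand widths."""
    return RULES[op](wa, wb)

print(infer_width("mul", 32, 32))
```

New operations are supported by registering another function, without touching `infer_width`. A real engine would also need to handle signedness, fixed-point and floating-point types, and a cap on the maximum width.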