cornell-zhang/hcl-dialect

Incorrect multiplication result bitwidth

jcasas00 opened this issue · 2 comments

import heterocl as hcl
import numpy as np

def test_mul_bits():
    hcl.init()
    rshape = (1,)
    _a = 1<<16
    _b = 1<<16
    def kernel():
        r = hcl.compute(rshape, lambda _:0, dtype=hcl.UInt(32))
        a = hcl.scalar(_a, "a", dtype=hcl.UInt(32)).v
        b = hcl.scalar(_b, "b", dtype=hcl.UInt(32)).v
        # a*b == 2**32 needs 33 bits, so the intermediate product must be wider than 32 bits
        tmp = hcl.scalar((a*b)>>16, "foo", dtype=hcl.UInt(32))
        r[0] = tmp.v
        return r
    #
    s = hcl.create_schedule([], kernel)
    print(hcl.lower(s))
    hcl_res = hcl.asarray(np.zeros(rshape, dtype=np.uint32), dtype=hcl.UInt(32))
    f = hcl.build(s)
    f(hcl_res)
    np_res = hcl_res.asnumpy()
    golden = np.asarray([(_a*_b)>>16], dtype=np.uint32)
    assert np.array_equal(golden, np_res), f"golden {golden} != np_res {np_res}"

The code above fails, apparently because the intermediate result of (a*b) does not have sufficient bitwidth. The relevant IR is:

%2 = affine.load %1[0] {from = "a", unsigned} : memref<1xi32>
%3 = memref.alloc() {name = "b", unsigned} : memref<1xi32>
affine.store %c65536_i32, %3[0] {to = "b", unsigned} : memref<1xi32>
%4 = affine.load %3[0] {from = "b", unsigned} : memref<1xi32>
%5 = arith.muli %2, %4 {unsigned} : i32

Multiplying two 32-bit operands can produce up to 64 bits, so the result of the multiplication should be 64 bits wide.
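To make the wraparound concrete, here is a minimal pure-Python sketch (not HeteroCL code) that emulates fixed-width unsigned arithmetic with bit masks. `mul_u32` mimics the 32-bit `arith.muli` in the IR above, and `mul_u32_widened` mimics zero-extending both operands to 64 bits before multiplying; both helper names are hypothetical.

```python
# Emulate fixed-width unsigned arithmetic with masks (illustrative sketch only).
MASK32 = (1 << 32) - 1
MASK64 = (1 << 64) - 1


def mul_u32(a: int, b: int) -> int:
    """Unsigned 32-bit multiply that wraps, like `arith.muli ... : i32`."""
    return (a * b) & MASK32


def mul_u32_widened(a: int, b: int) -> int:
    """Zero-extend both 32-bit operands to 64 bits, then multiply."""
    return ((a & MASK32) * (b & MASK32)) & MASK64


a = b = 1 << 16
narrow = mul_u32(a, b) >> 16                   # 2**32 wraps to 0, so this is 0
wide = (mul_u32_widened(a, b) >> 16) & MASK32  # 2**32 >> 16 == 65536, then truncate to uint32
print(narrow, wide)  # 0 65536
```

With `_a = _b = 1<<16`, the 32-bit product 2**32 wraps to 0, which matches the failing test; widening first gives the expected 65536.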

For comparison, the main branch generates the following IR:

  foo[x] = uint32(shift_right((uint64(a[0])*uint64(b[0])), 16))

Automatic casting based on the result bitwidth may complicate the code, lead to excessive resource usage, and cause many issues with SelectOp, so we are still discussing whether we want this feature at the IR level or would rather handle overflow directly at runtime.
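One generic way to detect overflow at runtime without widening the datapath is the classic division check for unsigned multiplication. The sketch below is an illustration only, with a hypothetical helper name; it is not how HeteroCL handles overflow.

```python
from typing import Tuple

MASK32 = (1 << 32) - 1


def mul_u32_checked(a: int, b: int) -> Tuple[int, bool]:
    """Return the wrapped 32-bit product and whether it overflowed.

    The wrapped product p overflowed iff a != 0 and p // a != b:
    without overflow p == a * b exactly, and with overflow p < a * b,
    so p // a < b. No wider intermediate type is needed for the check.
    """
    p = (a * b) & MASK32
    overflowed = a != 0 and p // a != b
    return p, overflowed


print(mul_u32_checked(1 << 16, 1 << 16))  # (0, True)
print(mul_u32_checked(3, 5))              # (15, False)
```

The trade-off is an extra division (or a flag) per multiplication instead of a wider multiplier, which is one reason the choice between IR-level widening and runtime handling is not obvious.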

This issue has been resolved now that the ast branch has been merged. The new frontend introduces a type inference engine with extensible typing rules, and the test case above now gets the correct result type.
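As a rough illustration of what an extensible typing rule for this case could look like, the sketch below registers per-operator rules in a table. `UInt`, `register`, `infer`, and the rules themselves are hypothetical stand-ins, not the API of the new HeteroCL frontend.

```python
from dataclasses import dataclass
from typing import Callable, Dict


@dataclass(frozen=True)
class UInt:
    """Hypothetical stand-in for an unsigned integer type of a given width."""
    bits: int


# Hypothetical registry of typing rules: op name -> (lhs, rhs) -> result type.
TypingRule = Callable[[UInt, UInt], UInt]
RULES: Dict[str, TypingRule] = {}


def register(op: str) -> Callable[[TypingRule], TypingRule]:
    """Decorator that adds a typing rule for `op` to the registry."""
    def wrap(rule: TypingRule) -> TypingRule:
        RULES[op] = rule
        return rule
    return wrap


@register("mul")
def mul_rule(lhs: UInt, rhs: UInt) -> UInt:
    # (2**m - 1) * (2**n - 1) < 2**(m + n), so m + n bits never overflow.
    return UInt(lhs.bits + rhs.bits)


@register("add")
def add_rule(lhs: UInt, rhs: UInt) -> UInt:
    # (2**m - 1) + (2**n - 1) < 2**(max(m, n) + 1)
    return UInt(max(lhs.bits, rhs.bits) + 1)


def infer(op: str, lhs: UInt, rhs: UInt) -> UInt:
    """Look up the rule registered for `op` and apply it."""
    return RULES[op](lhs, rhs)


print(infer("mul", UInt(32), UInt(32)))  # UInt(bits=64)
```

Keeping the rules in a table means a new operator, or a narrower rule chosen to save resources, can be added without changing `infer`.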