cornell-zhang/hcl-dialect

Unexpected extra lines getting printed out with hcl.print

Closed this issue · 5 comments

import heterocl as hcl
import numpy as np

def test_print_extra_output():
    hcl.init()
    def kernel():
        x = hcl.scalar(1, "x", dtype=hcl.UInt(8))
        y = hcl.scalar(2, "y", dtype=hcl.UInt(8))
        # Two consecutive prints; exactly two lines of output are expected
        hcl.print(x.v, "x = %d\n")
        hcl.print(y.v, "y = %d\n")

        r = hcl.compute((1,), lambda _:0, dtype=hcl.UInt(32))
        return r
    s = hcl.create_schedule([], kernel)
    # Output buffer for the value returned by the kernel
    hcl_res = hcl.asarray(np.zeros((1,), dtype=np.uint32), dtype=hcl.UInt(32))
    f = hcl.build(s)
    f(hcl_res)

When I run this code, I get:

hcl-dialect-prototype/build/tools/hcl/python_packages/hcl_core/hcl_mlir/exceptions.py:70: RuntimeWarning:
[API] LLVM_BUILD_DIR is not set, print memref feature is not available.
warnings.warn(self.message, category=self.category)
x = 1
y = 2
x = 0 <---- not expecting this output

If either of the hcl.print calls is commented out, the third line is not printed.

This seems to be an issue with the call operation. The following IR reproduces it:

// RUN: hcl-opt --lower-print-ops --jit %s

module {
  func.func @top() -> () {
    %x = arith.constant 0 : i32
    hcl.print(%x) {format="x: %d \n"} : i32
    %y = arith.constant 1 : i32
    hcl.print(%y) {format="y: %d \n"} : i32
    return
  }
}

And the lowered LLVM IR is:

module {
  llvm.mlir.global internal constant @frmt_spec1("y: %.0f \0A")
  llvm.mlir.global internal constant @frmt_spec0("x: %.0f \0A")
  llvm.func @printf(!llvm.ptr<i8>, ...) -> i32
  llvm.func @top() {
    %0 = llvm.mlir.constant(0 : i32) : i32
    %1 = llvm.mlir.addressof @frmt_spec0 : !llvm.ptr<array<9 x i8>>
    %2 = llvm.mlir.constant(0 : index) : i64
    %3 = llvm.getelementptr %1[%2, %2] : (!llvm.ptr<array<9 x i8>>, i64, i64) -> !llvm.ptr<i8>
    %4 = llvm.mlir.constant(0 : i64) : i64
    %5 = llvm.sitofp %4 : i64 to f64
    %6 = llvm.call @printf(%3, %5) : (!llvm.ptr<i8>, f64) -> i32
    %7 = llvm.mlir.constant(1 : i32) : i32
    %8 = llvm.mlir.addressof @frmt_spec1 : !llvm.ptr<array<9 x i8>>
    %9 = llvm.mlir.constant(0 : index) : i64
    %10 = llvm.getelementptr %8[%9, %9] : (!llvm.ptr<array<9 x i8>>, i64, i64) -> !llvm.ptr<i8>
    %11 = llvm.mlir.constant(1 : i64) : i64
    %12 = llvm.sitofp %11 : i64 to f64
    %13 = llvm.call @printf(%10, %12) : (!llvm.ptr<i8>, f64) -> i32
    llvm.return
  }
}

The lowered LLVM dialect IR looks correct. I'm looking into what's causing this issue.

I figured it out: it's an alignment issue with the global strings, e.g. llvm.mlir.global internal constant @frmt_spec1("y: %.0f \0A"). llvm.getelementptr ends up reaching the next global string, so there's an extra print.

I'm still seeing extra output with this kernel:

    def kernel():
        z = hcl.scalar(3, "z", dtype=hcl.UInt(16))
        hcl.print((z.v), "zz=%d ")
        hcl.print((z.v,z.v), "aaaaaaaaaaaaa=%d bbbbbbbb=%d")
        hcl.print((), "    \n")
        #
        r = hcl.compute((1,), lambda _:0, dtype=hcl.UInt(32))
        return r

generates the output:

zz=3 �aaaaaaaaaaaaa=3 bbbbbbbb=3zz=0 �

Note the extra character after the first "zz=3", and the spurious "zz=0" that appears at the end.
It looks like some kind of overflow happens for format strings that don't end with a \n.

An LLVM string requires a \00 terminator, the same as the \0 in a C string. I encountered the same issue when reading and writing files; a more detailed description is here: https://github.com/zzzDavid/hcl-debug/tree/main/read_write_file
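To illustrate why a missing terminator produces extra output, here is a minimal sketch in plain Python using ctypes. It does not use HeteroCL or MLIR. It assumes, purely for illustration, that the two format strings sit back to back in memory, and the helper name read_c_string is hypothetical. ctypes.string_at with no size argument reads until the first NUL byte, which is also how printf finds the end of its format string.

```python
import ctypes

def read_c_string(buf, offset=0):
    """Read bytes starting at `offset` until the first NUL, as a C string reader would."""
    return ctypes.string_at(ctypes.addressof(buf) + offset)

# Assumption for illustration: two format strings packed back to back,
# with no terminator between them. create_string_buffer appends a single
# NUL after the whole payload, so reads never leave the buffer.
unterminated = ctypes.create_string_buffer(b"x: %d \n" + b"y: %d \n")

# The same two strings, each followed by its own NUL terminator.
terminated = ctypes.create_string_buffer(b"x: %d \n\x00" + b"y: %d \n\x00")

# Reading the "first" string runs on into the second one.
first_unterminated = read_c_string(unterminated)

# With terminators, each read stops at the end of its own string.
first_terminated = read_c_string(terminated)
second_terminated = read_c_string(terminated, offset=len(b"x: %d \n\x00"))

print(first_unterminated)
print(first_terminated)
print(second_terminated)
```

In the unterminated case the first read returns both format strings, which would match the symptom of one print call emitting a second, unexpected line. In a real process the bytes after an unterminated global are whatever happens to follow it in memory, so the extra output can also be garbage characters.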

This issue will be closed after test cases are added.