cornell-zhang/hcl-dialect

[Backend] Outlined function with const tensor cannot pass CPU simulation

chhzh123 opened this issue · 3 comments

See the example below: it works fine when the data type is float or when the .outline() primitive is not used.

import heterocl as hcl
import numpy as np
import sys

bs = 4
ic, oc = 6, 16
ih, iw = 8, 8
kh, kw = 3, 3
oh, ow = ih - kh + 1, iw - kw + 1
dtype = hcl.Fixed(26, 20) # hcl.Float()

def test_conv2D_const():
    hcl.init(dtype)
    A = hcl.placeholder((bs, ic, ih, iw))
    np_B = np.random.random((oc, ic, kh, kw))

    def conv(A):
        rc = hcl.reduce_axis(0, ic)
        rh = hcl.reduce_axis(0, kh)
        rw = hcl.reduce_axis(0, kw)
        F = hcl.const_tensor(np_B, "F", dtype)
        B = hcl.compute(
            (bs, oc, oh, ow),
            lambda n, c, h, w: hcl.sum(
                    A[n, rc, h + rh, w + rw] * F[c, rc, rh, rw],
                    axis=[rc, rh, rw],
                    dtype=dtype,
                ),
            name="B",
            dtype=dtype,
        )
        return B

    s = hcl.create_schedule([A], conv)
    B = conv.B
    LB = s.reuse_at(A, s[B], B.axis[2])
    WB = s.reuse_at(LB, s[B], B.axis[3])
    s[B].outline()

    print(hcl.lower(s))
    f = hcl.build(s)

    np_A = np.random.random((bs, ic, ih, iw))
    np_C = np.zeros((bs, oc, oh, ow), dtype="float")

    for n in range(0, bs):
        for c in range(0, oc):
            for y in range(0, oh):
                for x in range(0, ow):
                    for rc in range(0, ic):
                        for rh in range(0, kh):
                            for rw in range(0, kw):
                                np_C[n][c][y][x] += (
                                    np_A[n][rc][y + rh][x + rw]
                                    * np_B[c][rc][rh][rw]
                                )

    hcl_A = hcl.asarray(np_A, dtype=dtype)
    hcl_C = hcl.asarray(np_C, dtype=dtype)

    f(hcl_A, hcl_C)
    # print(np_C, hcl_C.asnumpy())

    assert np.allclose(np_C, hcl_C.asnumpy())
    print("Passed!")


if __name__ == "__main__":
    test_conv2D_const()

However, when .outline() is used, it gives the following error.

python3: /scratch/users/hc676/llvm-project/llvm/lib/IR/Instructions.cpp:508: void llvm::CallInst::init(llvm::FunctionType*, llvm::Value*, llvm::ArrayRef<llvm::Value*>, llvm::ArrayRef<llvm::OperandBundleDefT<llvm::Value*> >, const llvm::Twine&): Assertion `(i >= FTy->getNumParams() || FTy->getParamType(i) == Args[i]->getType()) && "Calling a function with a bad signature!"' failed.
 #0 0x00007f58fb83d92f PrintStackTraceSignalHandler(void*) Signals.cpp:0:0
 #1 0x00007f58fb83b359 SignalHandler(int) Signals.cpp:0:0
 #2 0x00007f5918632630 __restore_rt sigaction.c:0:0
 #3 0x00007f591828b387 raise (/lib64/libc.so.6+0x36387)
 #4 0x00007f591828ca78 abort (/lib64/libc.so.6+0x37a78)
 #5 0x00007f59182841a6 __assert_fail_base (/lib64/libc.so.6+0x2f1a6)

The generated IR looks correct, though.

#set = affine_set<(d0) : (d0 - 2 >= 0)>
module {
  memref.global "private" constant @F : memref<16x6x3x3xi64> = dense<"..."> // omitted here
  func private @Stage_B(%arg0: memref<4x6x8x8x!hcl.Fixed<26, 20>>, %arg1: memref<16x6x3x3x!hcl.Fixed<26, 20>>, %arg2: memref<4x16x6x6x!hcl.Fixed<26, 20>>) attributes {bit, itypes = "___"} {
    %c0 = arith.constant 0 : index
    %0 = memref.alloc() {name = "sum_rv"} : memref<1x!hcl.Fixed<26, 20>>
    %1 = memref.alloc() {name = "B_reuse_3"} : memref<6x3x3x!hcl.Fixed<26, 20>>
    %2 = memref.alloc() {name = "B_reuse_2"} : memref<6x3x8x!hcl.Fixed<26, 20>>
    affine.for %arg3 = 0 to 4 {
      affine.for %arg4 = 0 to 16 {
        affine.for %arg5 = 0 to 8 {
          affine.for %arg6 = 0 to 8 {
            affine.for %arg7 = 0 to 6 {
              %3 = affine.load %2[%arg7, 1, %arg6] : memref<6x3x8x!hcl.Fixed<26, 20>>
              affine.store %3, %2[%arg7, 0, %arg6] : memref<6x3x8x!hcl.Fixed<26, 20>>
              %4 = affine.load %2[%arg7, 2, %arg6] : memref<6x3x8x!hcl.Fixed<26, 20>>
              affine.store %4, %2[%arg7, 1, %arg6] : memref<6x3x8x!hcl.Fixed<26, 20>>
              %5 = affine.load %arg0[%arg3, %arg7, %arg5, %arg6] : memref<4x6x8x8x!hcl.Fixed<26, 20>>
              affine.store %5, %2[%arg7, 2, %arg6] : memref<6x3x8x!hcl.Fixed<26, 20>>
            } {spatial}
            affine.if #set(%arg5) {
              affine.for %arg7 = 0 to 6 {
                affine.for %arg8 = 0 to 3 {
                  %3 = affine.load %1[%arg7, %arg8, 1] : memref<6x3x3x!hcl.Fixed<26, 20>>
                  affine.store %3, %1[%arg7, %arg8, 0] : memref<6x3x3x!hcl.Fixed<26, 20>>
                  %4 = affine.load %1[%arg7, %arg8, 2] : memref<6x3x3x!hcl.Fixed<26, 20>>
                  affine.store %4, %1[%arg7, %arg8, 1] : memref<6x3x3x!hcl.Fixed<26, 20>>
                  %5 = affine.load %2[%arg7, %arg8, %arg6] : memref<6x3x8x!hcl.Fixed<26, 20>>
                  affine.store %5, %1[%arg7, %arg8, 2] : memref<6x3x3x!hcl.Fixed<26, 20>>
                } {spatial}
              } {spatial}
              affine.if #set(%arg6) {
                %c0_i32 = arith.constant 0 : i32
                %3 = hcl.int_to_fixed(%c0_i32) : i32 -> !hcl.Fixed<26, 20>
                affine.store %3, %0[%c0] {to = "sum_rv"} : memref<1x!hcl.Fixed<26, 20>>
                affine.for %arg7 = 0 to 6 {
                  affine.for %arg8 = 0 to 3 {
                    affine.for %arg9 = 0 to 3 {
                      %5 = affine.load %1[%arg7, %arg8, %arg9] : memref<6x3x3x!hcl.Fixed<26, 20>>
                      %6 = affine.load %arg1[%arg4, %arg7, %arg8, %arg9] {from = "const_tensor"} : memref<16x6x3x3x!hcl.Fixed<26, 20>>
                      %7 = "hcl.mul_fixed"(%5, %6) : (!hcl.Fixed<26, 20>, !hcl.Fixed<26, 20>) -> !hcl.Fixed<26, 20>
                      %8 = affine.load %0[%c0] {from = "sum_rv"} : memref<1x!hcl.Fixed<26, 20>>
                      %9 = "hcl.add_fixed"(%7, %8) : (!hcl.Fixed<26, 20>, !hcl.Fixed<26, 20>) -> !hcl.Fixed<26, 20>
                      affine.store %9, %0[%c0] {to = "sum_rv"} : memref<1x!hcl.Fixed<26, 20>>
                    } {loop_name = "rx_2", reduction}
                  } {loop_name = "rx_1", reduction}
                } {loop_name = "rx_0", reduction}
                %4 = affine.load %0[%c0] {from = "sum_rv"} : memref<1x!hcl.Fixed<26, 20>>
                affine.store %4, %arg2[%arg3, %arg4, %arg5 - 2, %arg6 - 2] : memref<4x16x6x6x!hcl.Fixed<26, 20>>
              }
            }
          } {loop_name = "w"}
        } {loop_name = "h"}
      } {loop_name = "c"}
    } {loop_name = "n", stage_name = "B"}
    return
  }
  func @top(%arg0: memref<4x6x8x8x!hcl.Fixed<26, 20>>) -> memref<4x16x6x6x!hcl.Fixed<26, 20>> attributes {itypes = "_", otypes = "_"} {
    %0 = hcl.get_global_fixed @F : memref<16x6x3x3x!hcl.Fixed<26, 20>>
    %1 = memref.alloc() {name = "B"} : memref<4x16x6x6x!hcl.Fixed<26, 20>>
    call @Stage_B(%arg0, %0, %1) : (memref<4x6x8x8x!hcl.Fixed<26, 20>>, memref<16x6x3x3x!hcl.Fixed<26, 20>>, memref<4x16x6x6x!hcl.Fixed<26, 20>>) -> ()
    return %1 : memref<4x16x6x6x!hcl.Fixed<26, 20>>
  }
}

Ah, this is because the FixedToInteger pass hasn't implemented the transformation for the call operation yet. I'll do it.
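To make the mismatch concrete: the pass presumably lowers the signature and body of @Stage_B to integer memrefs, but the call in @top still passes the fixed-point-typed values, which is exactly the "bad signature" the LLVM assertion complains about. A rough sketch of what a consistent lowering of the call would look like (the i64 element width and the memref.get_global op are assumptions on my part, inferred from the @F global above, not actual pass output):

// Sketch only; lowered element width and ops are assumed, not pass output.
func private @Stage_B(memref<4x6x8x8xi64>, memref<16x6x3x3xi64>, memref<4x16x6x6xi64>)
func @top(%arg0: memref<4x6x8x8xi64>) -> memref<4x16x6x6xi64> {
  %0 = memref.get_global @F : memref<16x6x3x3xi64>
  %1 = memref.alloc() {name = "B"} : memref<4x16x6x6xi64>
  call @Stage_B(%arg0, %0, %1) : (memref<4x6x8x8xi64>, memref<16x6x3x3xi64>, memref<4x16x6x6xi64>) -> ()
  return %1 : memref<4x16x6x6xi64>
}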

I rewrote the pass to generate the following code, where get_global_fixed is moved inside the outlined function (so the constant is no longer passed through the call operation), but why does it still not work?

func private @Stage_B(%arg0: memref<4x6x8x8x!hcl.Fixed<26, 20>>, %arg1: memref<4x16x6x6x!hcl.Fixed<26, 20>>) attributes {bit, itypes = "__"} {
  %c0 = arith.constant 0 : index
  %0 = hcl.get_global_fixed @F : memref<16x6x3x3x!hcl.Fixed<26, 20>>
  // more computation
  return
}
func @top(%arg0: memref<4x6x8x8x!hcl.Fixed<26, 20>>) -> memref<4x16x6x6x!hcl.Fixed<26, 20>> attributes {itypes = "_", otypes = "_"} {
  %0 = memref.alloc() {name = "B"} : memref<4x16x6x6x!hcl.Fixed<26, 20>>
  call @Stage_B(%arg0, %0) : (memref<4x6x8x8x!hcl.Fixed<26, 20>>, memref<4x16x6x6x!hcl.Fixed<26, 20>>) -> ()
  return %0 : memref<4x16x6x6x!hcl.Fixed<26, 20>>
}

I don't think this is related to hcl.get_global_fixed; the problem is the function signature transformation when a fixed-point type is involved. Let me fix this now.
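In other words, the pass needs to rewrite the callee signature and the call-site operand types together. A rough sketch of the expected result for the two-argument version above, again assuming i64 as the lowered element type (not actual pass output):

func private @Stage_B(%arg0: memref<4x6x8x8xi64>, %arg1: memref<4x16x6x6xi64>) {
  // lowered body, with @F accessed via memref.get_global
  return
}
func @top(%arg0: memref<4x6x8x8xi64>) -> memref<4x16x6x6xi64> {
  %0 = memref.alloc() {name = "B"} : memref<4x16x6x6xi64>
  call @Stage_B(%arg0, %0) : (memref<4x6x8x8xi64>, memref<4x16x6x6xi64>) -> ()
  return %0 : memref<4x16x6x6xi64>
}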