[Backend] Outlined function with const tensor cannot pass CPU simulation
chhzh123 opened this issue · 3 comments
chhzh123 commented
See the example below: it works totally fine when the data type is float,
or when no .outline()
primitive is used.
import heterocl as hcl
import numpy as np
import sys
# Convolution problem sizes for the reproduction below.
bs = 4  # batch size
ic, oc = 6, 16  # input / output channels
ih, iw = 8, 8  # input height / width
kh, kw = 3, 3  # kernel height / width
oh, ow = ih - kh + 1, iw - kw + 1  # output size for a valid (no-padding) conv
dtype = hcl.Fixed(26, 20) # hcl.Float()  -- per the report, hcl.Float() avoids the failure
def test_conv2D_const():
    """Reproduce the outlined-const-tensor failure: build a 2D convolution
    whose filter is an hcl.const_tensor, outline the stage, run CPU
    simulation, and compare against a NumPy reference.

    Raises:
        AssertionError: if the simulated output does not match the reference.
    """
    hcl.init(dtype)
    A = hcl.placeholder((bs, ic, ih, iw))
    np_B = np.random.random((oc, ic, kh, kw))

    def conv(A):
        # Reduction axes over input channels and the kernel window.
        rc = hcl.reduce_axis(0, ic)
        rh = hcl.reduce_axis(0, kh)
        rw = hcl.reduce_axis(0, kw)
        # The filter is baked in as a constant tensor — this is what
        # triggers the bad call signature once the stage is outlined.
        F = hcl.const_tensor(np_B, "F", dtype)
        B = hcl.compute(
            (bs, oc, oh, ow),
            lambda n, c, h, w: hcl.sum(
                A[n, rc, h + rh, w + rw] * F[c, rc, rh, rw],
                axis=[rc, rh, rw],
                dtype=dtype,
            ),
            name="B",
            dtype=dtype,
        )
        return B

    s = hcl.create_schedule([A], conv)
    B = conv.B
    # Line buffer / window buffer reuse, then outline the whole stage.
    LB = s.reuse_at(A, s[B], B.axis[2])
    WB = s.reuse_at(LB, s[B], B.axis[3])  # noqa: F841 -- reuse_at mutates the schedule
    s[B].outline()
    print(hcl.lower(s))
    f = hcl.build(s)

    # NumPy reference convolution (valid padding, stride 1).
    np_A = np.random.random((bs, ic, ih, iw))
    np_C = np.zeros((bs, oc, oh, ow), dtype="float")
    for n in range(bs):
        for c in range(oc):
            for y in range(oh):
                for x in range(ow):
                    for rc in range(ic):
                        for rh in range(kh):
                            for rw in range(kw):
                                np_C[n, c, y, x] += (
                                    np_A[n, rc, y + rh, x + rw]
                                    * np_B[c, rc, rh, rw]
                                )
    hcl_A = hcl.asarray(np_A, dtype=dtype)
    # BUGFIX: the output buffer must start from zeros. The original seeded it
    # with np_C (the expected result), so the final allclose() would pass
    # even if the built kernel never wrote anything.
    hcl_C = hcl.asarray(np.zeros((bs, oc, oh, ow)), dtype=dtype)
    f(hcl_A, hcl_C)
    # print(np_C, hcl_C.asnumpy())
    assert np.allclose(np_C, hcl_C.asnumpy())
    print("Passed!")
if __name__ == "__main__":
    # Run the reproduction directly as a script.
    test_conv2D_const()
However, when .outline()
is used, it gives the following error.
python3: /scratch/users/hc676/llvm-project/llvm/lib/IR/Instructions.cpp:508: void llvm::CallInst::init(llvm::FunctionType*, llvm::Value*, llvm::ArrayRef<llvm::Value*>, llvm::ArrayRef<llvm::OperandBundleDefT<llvm::Value*> >, const llvm::Twine&): Assertion `(i >= FTy->getNumParams() || FTy->getParamType(i) == Args[i]->getType()) && "Calling a function with a bad signature!"' failed.
#0 0x00007f58fb83d92f PrintStackTraceSignalHandler(void*) Signals.cpp:0:0
#1 0x00007f58fb83b359 SignalHandler(int) Signals.cpp:0:0
#2 0x00007f5918632630 __restore_rt sigaction.c:0:0
#3 0x00007f591828b387 raise (/lib64/libc.so.6+0x36387)
#4 0x00007f591828ca78 abort (/lib64/libc.so.6+0x37a78)
#5 0x00007f59182841a6 __assert_fail_base (/lib64/libc.so.6+0x2f1a6)
It seems correct from the generated IR.
#set = affine_set<(d0) : (d0 - 2 >= 0)>
module {
memref.global "private" constant @F : memref<16x6x3x3xi64> = dense<"..."> // omitted here
func private @Stage_B(%arg0: memref<4x6x8x8x!hcl.Fixed<26, 20>>, %arg1: memref<16x6x3x3x!hcl.Fixed<26, 20>>, %arg2: memref<4x16x6x6x!hcl.Fixed<26, 20>>) attributes {bit, itypes = "___"} {
%c0 = arith.constant 0 : index
%0 = memref.alloc() {name = "sum_rv"} : memref<1x!hcl.Fixed<26, 20>>
%1 = memref.alloc() {name = "B_reuse_3"} : memref<6x3x3x!hcl.Fixed<26, 20>>
%2 = memref.alloc() {name = "B_reuse_2"} : memref<6x3x8x!hcl.Fixed<26, 20>>
affine.for %arg3 = 0 to 4 {
affine.for %arg4 = 0 to 16 {
affine.for %arg5 = 0 to 8 {
affine.for %arg6 = 0 to 8 {
affine.for %arg7 = 0 to 6 {
%3 = affine.load %2[%arg7, 1, %arg6] : memref<6x3x8x!hcl.Fixed<26, 20>>
affine.store %3, %2[%arg7, 0, %arg6] : memref<6x3x8x!hcl.Fixed<26, 20>>
%4 = affine.load %2[%arg7, 2, %arg6] : memref<6x3x8x!hcl.Fixed<26, 20>>
affine.store %4, %2[%arg7, 1, %arg6] : memref<6x3x8x!hcl.Fixed<26, 20>>
%5 = affine.load %arg0[%arg3, %arg7, %arg5, %arg6] : memref<4x6x8x8x!hcl.Fixed<26, 20>>
affine.store %5, %2[%arg7, 2, %arg6] : memref<6x3x8x!hcl.Fixed<26, 20>>
} {spatial}
affine.if #set(%arg5) {
affine.for %arg7 = 0 to 6 {
affine.for %arg8 = 0 to 3 {
%3 = affine.load %1[%arg7, %arg8, 1] : memref<6x3x3x!hcl.Fixed<26, 20>>
affine.store %3, %1[%arg7, %arg8, 0] : memref<6x3x3x!hcl.Fixed<26, 20>>
%4 = affine.load %1[%arg7, %arg8, 2] : memref<6x3x3x!hcl.Fixed<26, 20>>
affine.store %4, %1[%arg7, %arg8, 1] : memref<6x3x3x!hcl.Fixed<26, 20>>
%5 = affine.load %2[%arg7, %arg8, %arg6] : memref<6x3x8x!hcl.Fixed<26, 20>>
affine.store %5, %1[%arg7, %arg8, 2] : memref<6x3x3x!hcl.Fixed<26, 20>>
} {spatial}
} {spatial}
affine.if #set(%arg6) {
%c0_i32 = arith.constant 0 : i32
%3 = hcl.int_to_fixed(%c0_i32) : i32 -> !hcl.Fixed<26, 20>
affine.store %3, %0[%c0] {to = "sum_rv"} : memref<1x!hcl.Fixed<26, 20>>
affine.for %arg7 = 0 to 6 {
affine.for %arg8 = 0 to 3 {
affine.for %arg9 = 0 to 3 {
%5 = affine.load %1[%arg7, %arg8, %arg9] : memref<6x3x3x!hcl.Fixed<26, 20>>
%6 = affine.load %arg1[%arg4, %arg7, %arg8, %arg9] {from = "const_tensor"} : memref<16x6x3x3x!hcl.Fixed<26, 20>>
%7 = "hcl.mul_fixed"(%5, %6) : (!hcl.Fixed<26, 20>, !hcl.Fixed<26, 20>) -> !hcl.Fixed<26, 20>
%8 = affine.load %0[%c0] {from = "sum_rv"} : memref<1x!hcl.Fixed<26, 20>>
%9 = "hcl.add_fixed"(%7, %8) : (!hcl.Fixed<26, 20>, !hcl.Fixed<26, 20>) -> !hcl.Fixed<26, 20>
affine.store %9, %0[%c0] {to = "sum_rv"} : memref<1x!hcl.Fixed<26, 20>>
} {loop_name = "rx_2", reduction}
} {loop_name = "rx_1", reduction}
} {loop_name = "rx_0", reduction}
%4 = affine.load %0[%c0] {from = "sum_rv"} : memref<1x!hcl.Fixed<26, 20>>
affine.store %4, %arg2[%arg3, %arg4, %arg5 - 2, %arg6 - 2] : memref<4x16x6x6x!hcl.Fixed<26, 20>>
}
}
} {loop_name = "w"}
} {loop_name = "h"}
} {loop_name = "c"}
} {loop_name = "n", stage_name = "B"}
return
}
func @top(%arg0: memref<4x6x8x8x!hcl.Fixed<26, 20>>) -> memref<4x16x6x6x!hcl.Fixed<26, 20>> attributes {itypes = "_", otypes = "_"} {
%0 = hcl.get_global_fixed @F : memref<16x6x3x3x!hcl.Fixed<26, 20>>
%1 = memref.alloc() {name = "B"} : memref<4x16x6x6x!hcl.Fixed<26, 20>>
call @Stage_B(%arg0, %0, %1) : (memref<4x6x8x8x!hcl.Fixed<26, 20>>, memref<16x6x3x3x!hcl.Fixed<26, 20>>, memref<4x16x6x6x!hcl.Fixed<26, 20>>) -> ()
return %1 : memref<4x16x6x6x!hcl.Fixed<26, 20>>
}
}
zzzDavid commented
Ah, this is because FixedToInteger
pass hasn't implemented transformation on call
operation yet. I'll do it.
chhzh123 commented
I rewrote the pass to generate the following code, where get_global_fixed
is inside the function (so no call operation is involved) — but why does it still not work?
func private @Stage_B(%arg0: memref<4x6x8x8x!hcl.Fixed<26, 20>>, %arg1: memref<4x16x6x6x!hcl.Fixed<26, 20>>) attributes {bit, itypes = "__"} {
%c0 = arith.constant 0 : index
%0 = hcl.get_global_fixed @F : memref<16x6x3x3x!hcl.Fixed<26, 20>>
// more computation
return
}
func @top(%arg0: memref<4x6x8x8x!hcl.Fixed<26, 20>>) -> memref<4x16x6x6x!hcl.Fixed<26, 20>> attributes {itypes = "_", otypes = "_"} {
%0 = memref.alloc() {name = "B"} : memref<4x16x6x6x!hcl.Fixed<26, 20>>
call @Stage_B(%arg0, %0) : (memref<4x6x8x8x!hcl.Fixed<26, 20>>, memref<4x16x6x6x!hcl.Fixed<26, 20>>) -> ()
return %0 : memref<4x16x6x6x!hcl.Fixed<26, 20>>
}
zzzDavid commented
I don't think this is related to hcl.get_global_fixed
; the problem is the function-signature transformation when a fixed-point type is involved. Let me fix this now.