Wrong results when using `hcl.max`
sqPoseidon opened this issue
I'm using the `hcl-mlir` branch and want to implement max pooling with `hcl.max`, but I get wrong results for several different input data types. Here's my test case:
```python
import heterocl as hcl
import numpy as np

in_shape = (2, 8)
out_shape = (1, 4)
stride = 2
kernel = 2

# test_dtype = hcl.Int(10)
# test_dtype = hcl.Int(8)
# test_dtype = hcl.Int(6)
test_dtype = hcl.Fixed(12, 6)

def max_pool(data):
    h = hcl.reduce_axis(0, kernel)
    w = hcl.reduce_axis(0, kernel)
    return hcl.compute(
        out_shape,
        lambda hh, ww: hcl.max(data[stride * hh + h, stride * ww + w],
                               axis=[h, w], dtype=test_dtype),
        name="max_pool",
        dtype=test_dtype
    )

A = hcl.placeholder(in_shape, "A", dtype=test_dtype)
s = hcl.create_schedule([A], max_pool)
f = hcl.build(s)

a = hcl.asarray(np.zeros(in_shape), dtype=test_dtype)
# a = hcl.asarray(np.random.randint(0, 10, size=in_shape), dtype=test_dtype)
b = hcl.asarray(np.zeros(out_shape), dtype=test_dtype)
f(a, b)

a_np = a.asnumpy()
print("a_np: ", a_np)
b_np = b.asnumpy()
print("test_dtype: ", test_dtype, ", b_np: ", b_np, flush=True)
```
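To check the HeteroCL output independently of the data type, a plain NumPy reference for the same pooling can help. This is a minimal sketch that assumes no padding and the same `kernel` and `stride` as the test case; `max_pool_ref` is a hypothetical helper name, not part of HeteroCL. It visits the same window as `data[stride * hh + h, stride * ww + w]` with `h, w` in `[0, kernel)`.

```python
import numpy as np

def max_pool_ref(data, kernel=2, stride=2):
    """Reference 2-D max pooling in NumPy (no padding, square kernel)."""
    in_h, in_w = data.shape
    out_h = (in_h - kernel) // stride + 1
    out_w = (in_w - kernel) // stride + 1
    out = np.empty((out_h, out_w), dtype=data.dtype)
    for hh in range(out_h):
        for ww in range(out_w):
            # Same elements as data[stride*hh + h, stride*ww + w], h, w in [0, kernel)
            window = data[stride * hh: stride * hh + kernel,
                          stride * ww: stride * ww + kernel]
            out[hh, ww] = window.max()
    return out

in_shape = (2, 8)
print(max_pool_ref(np.zeros(in_shape, dtype=np.int64)))  # prints [[0 0 0 0]]
```

With the test case's variables in scope, `np.testing.assert_array_equal(b_np, max_pool_ref(a_np))` flags any mismatch between the HeteroCL result and the reference.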
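I use a 2-D zero tensor as the input, with the data type set by `test_dtype`. Since the maximum over an all-zero window is 0, the expected output is `[[0 0 0 0]]` regardless of the data type. Here is what I get for each `test_dtype`: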
- `test_dtype = hcl.Int(10)`: wrong result, `[[193 193 193 193]]`
- `test_dtype = hcl.Int(8)`: correct result, `[[0 0 0 0]]`
- `test_dtype = hcl.Int(6)`: wrong result, `[[1 1 1 1]]`
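To make the expected behavior explicit, here is a minimal pure-Python model of a max reduction over a signed N-bit integer type. It assumes the accumulator starts at the type's minimum two's-complement value, which is a common convention for max reductions but is not confirmed to be what HeteroCL's lowering emits; `int_min` and `max_reduce` are hypothetical helper names. Whatever the bit width, the result over an all-zero window should be 0, so the `Int(10)` and `Int(6)` outputs above cannot come from a correct reduction.

```python
def int_min(bits):
    """Smallest value of a signed two's-complement integer with `bits` bits."""
    return -(1 << (bits - 1))

def max_reduce(values, bits):
    """Fold max over `values`, starting from the type's minimum (assumed init)."""
    acc = int_min(bits)
    for v in values:
        acc = max(acc, v)
    return acc

window = [0, 0, 0, 0]  # one 2x2 window of the all-zero input
for bits in (10, 8, 6):
    # prints e.g. "Int(10): init=-512, result=0"
    print(f"Int({bits}): init={int_min(bits)}, result={max_reduce(window, bits)}")
```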
For `test_dtype = hcl.Fixed(12, 6)`, the run aborts with the following assertion failure:
```
python: xxx/heterocl-mlir/hcl-dialect-prototype/llvm-project/llvm/include/llvm/ADT/ArrayRef.h:442: T& llvm::MutableArrayRef<T>::operator[](size_t) const [with T = mlir::OpOperand; size_t = long unsigned int]: Assertion `Index < this->size() && "Invalid index!"' failed.
#0 0x00007f2044b04b7f PrintStackTraceSignalHandler(void*) Signals.cpp:0:0
#1 0x00007f2044b0255c SignalHandler(int) Signals.cpp:0:0
#2 0x00007f226e633630 __restore_rt sigaction.c:0:0
#3 0x00007f226e28c387 raise (/lib64/libc.so.6+0x36387)
#4 0x00007f226e28da78 abort (/lib64/libc.so.6+0x37a78)
#5 0x00007f226e2851a6 __assert_fail_base (/lib64/libc.so.6+0x2f1a6)
#6 0x00007f226e285252 (/lib64/libc.so.6+0x2f252)
#7 0x00007f20448ee455 llvm::MutableArrayRef<mlir::OpOperand>::operator[](unsigned long) const xxx/heterocl-mlir/hcl-dialect-prototype/llvm-project/llvm/include/llvm/ADT/ArrayRef.h:442:7
#8 0x00007f20448e9e26 mlir::Operation::getOpOperand(unsigned int) xxx/heterocl-mlir/hcl-dialect-prototype/llvm-project/mlir/include/mlir/IR/Operation.h:307:3
......
```
Could you please help look into this problem?