[Frontend][Op] Wrong indices built from `hcl.pack`
Closed · 1 comment
zzzDavid commented
Description
The frontend generates a wrong index for `hcl.pack`.
- Example test case: `test_compute_pack_unpack::test_pack`
This example packs four elements from the input tensor into one element in the output tensor.
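For reference, the intended packing can be sketched in plain Python. This is not HeteroCL code: `pack_reference` is a hypothetical helper, and the LSB-first bit order is an assumption read off the `hcl.set_slice` bounds in the IR, where inner iteration `j` writes bit `j`.

```python
def pack_reference(bits, factor=4):
    """Pack `factor` consecutive 1-bit values into one integer (LSB first).

    Hypothetical reference model, not part of the HeteroCL API.
    """
    if len(bits) % factor != 0:
        raise ValueError("input length must be a multiple of factor")
    packed = []
    for i in range(len(bits) // factor):
        word = 0
        for j in range(factor):
            # Input element i * factor + j lands in bit j of output word i.
            word |= (bits[i * factor + j] & 1) << j
        packed.append(word)
    return packed


# A 40-element 1-bit input yields a 10-element 4-bit output.
example = pack_reference([1, 0, 0, 0] * 10)
```

Each output element `i` depends on input elements `4 * i` through `4 * i + 3`, so all 40 inputs are consumed.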
In the generated IR:

    module {
      func @top(%arg0: memref<40xi1>) -> memref<10xi4> attributes {bit, extra_itypes = "u", extra_otypes = "u"} {
        %0 = memref.alloc() {name = "compute_0", unsigned} : memref<10xi4>
        affine.for %arg1 = 0 to 10 {
          %1 = memref.alloc() {name = "packed_compute_0", unsigned} : memref<1xi4>
          %c0 = arith.constant 0 : index
          %c0_i4 = arith.constant {unsigned} 0 : i4
          affine.store %c0_i4, %1[0] {to = "packed_compute_0", unsigned} : memref<1xi4>
          %2 = hcl.create_loop_handle "loop_0" : !hcl.LoopHandle
          affine.for %arg2 = 0 to 4 {
            %6 = affine.load %arg0[%arg1] {from = "A", unsigned} : memref<40xi1>
            %c0_1 = arith.constant 0 : index
            %7 = affine.load %1[0] {from = "packed_compute_0", unsigned} : memref<1xi4>
            %c1_i32 = arith.constant 1 : i32
            %8 = arith.index_cast %c1_i32 : i32 to index
            %9 = arith.muli %8, %arg2 : index
            %c1_i32_2 = arith.constant 1 : i32
            %10 = arith.index_cast %c1_i32_2 : i32 to index
            %11 = arith.addi %arg2, %10 : index
            %c1_i32_3 = arith.constant 1 : i32
            %12 = arith.index_cast %c1_i32_3 : i32 to index
            %13 = arith.muli %12, %11 : index
            %c1_i32_4 = arith.constant 1 : i32
            %14 = arith.index_cast %c1_i32_4 : i32 to index
            %15 = arith.subi %13, %14 : index
            hcl.set_slice(%7 : i4, %15, %9, %6 : i1)
            affine.store %7, %1[0] {to = "packed_compute_0", unsigned} : memref<1xi4>
          } {loop_name = "loop_0"}
          %3 = hcl.create_stage_handle "" : !hcl.StageHandle
          %c0_0 = arith.constant 0 : index
          %4 = affine.load %1[0] {from = "packed_compute_0", unsigned} : memref<1xi4>
          %5 = affine.load %1[0] {from = "packed_compute_0", unsigned} : memref<1xi4>
          affine.store %5, %0[%arg1] {to = "compute_0"} : memref<10xi4>
        } {loop_name = "i0", stage_name = "compute_0"}
        hcl.print(%0) {format = "%.0f ", unsigned} : memref<10xi4>
        return %0 : memref<10xi4>
      }
    }

`%arg1` is the induction variable over the packed output tensor and ranges from 0 to 9; `%arg2` is the inner-loop induction variable and ranges from 0 to 3. `%6` should be loaded from the input memref at index `%arg1 * 4 + %arg2` instead of `%arg1`. As generated, the load only ever reads the first 10 elements of the 40-element input.
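The difference between the two index expressions can be checked with a small Python sketch that mirrors the loop bounds in the IR. `indices_read` is a hypothetical helper written for this illustration; it only enumerates which input positions each index expression touches.

```python
def indices_read(index_fn, outer=10, inner=4):
    """Return the sorted set of input indices touched by the loop nest.

    `outer` and `inner` mirror the trip counts of %arg1 and %arg2 in the IR.
    """
    return sorted({index_fn(i, j) for i in range(outer) for j in range(inner)})


# Index as generated: %arg0[%arg1]
buggy = indices_read(lambda i, j: i)

# Index as expected: %arg0[%arg1 * 4 + %arg2]
expected = indices_read(lambda i, j: i * 4 + j)

print(len(buggy), len(expected))
```

With the generated index, only positions 0 through 9 are read and each is read four times; with the expected index, every position from 0 through 39 is read exactly once.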
zzzDavid commented
Fixed by commit: chhzh123/heterocl@3452623