[Binding] ReduceOp should load element from buffer first
Closed this issue · 0 comments
chhzh123 commented
The current Python binding of ReduceOp returns the reduction variable (rv) itself. However, since we actually create a size-1 buffer for the rv, it is a memref rather than a register (scalar value), so its element must be loaded from the buffer before it is used in other operations. Otherwise, it may cause the following error (example: `sum(A[i, k] * B[k, j]) + 10`):
error: 'arith.addi' op requires the same type for all operands and results
// Verification failed, printing generic form
#map0 = affine_map<(d0) -> (d0)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
#map2 = affine_map<() -> (0)>
#map3 = affine_map<() -> (32)>
"builtin.module"() ({
"builtin.func"() ({
^bb0(%arg0: memref<32x32xi32>, %arg1: memref<32x32xi32>):
%0 = "memref.alloc"() {name = "C", operand_segment_sizes = dense<0> : vector<2xi32>} : () -> memref<32x32xi32>
%1 = "hcl.create_stage_handle"() {stage_name = "C"} : () -> !hcl.StageHandle
%2 = "hcl.create_loop_handle"() {loop_name = "i"} : () -> !hcl.LoopHandle
%3 = "hcl.create_loop_handle"() {loop_name = "j"} : () -> !hcl.LoopHandle
%4 = "memref.alloc"() {name = "C", operand_segment_sizes = dense<0> : vector<2xi32>} : () -> memref<32x32xi32>
"affine.for"() ({
^bb0(%arg2: index):
"affine.for"() ({
^bb0(%arg3: index):
%5 = "memref.alloc"() {name = "sum_rv", operand_segment_sizes = dense<0> : vector<2xi32>} : () -> memref<1xi32>
%6 = "arith.constant"() {value = 0 : index} : () -> index
%7 = "arith.constant"() {value = 0 : i32} : () -> i32
"affine.store"(%7, %5, %6) {map = #map0, to = "sum_rv"} : (i32, memref<1xi32>, index) -> ()
"affine.for"() ({
^bb0(%arg4: index):
%10 = "affine.load"(%arg0, %arg2, %arg4) {from = "A", map = #map1} : (memref<32x32xi32>, index, index) -> i32
%11 = "affine.load"(%arg1, %arg4, %arg3) {from = "B", map = #map1} : (memref<32x32xi32>, index, index) -> i32
%12 = "arith.muli"(%10, %11) : (i32, i32) -> i32
%13 = "affine.load"(%5, %6) {from = "sum_rv", map = #map0} : (memref<1xi32>, index) -> i32
%14 = "arith.addi"(%12, %13) : (i32, i32) -> i32
"affine.store"(%14, %5, %6) {map = #map0, to = "sum_rv"} : (i32, memref<1xi32>, index) -> ()
"affine.yield"() : () -> ()
}) {loop_name = "k", lower_bound = #map2, step = 1 : i32, upper_bound = #map3} : () -> ()
%8 = "arith.constant"() {value = 10 : i32} : () -> i32
%9 = "arith.addi"(%5, %8) : (memref<1xi32>, i32) -> memref<1xi32>
"affine.store"(%9, %4, %arg2, %arg3) {map = #map1, to = "C"} : (memref<1xi32>, memref<32x32xi32>, index, index) -> ()
"affine.yield"() : () -> ()
}) {loop_name = "j", lower_bound = #map2, step = 1 : i32, upper_bound = #map3} : () -> ()
"affine.yield"() : () -> ()
}) {loop_name = "i", lower_bound = #map2, stage_name = "C", step = 1 : i32, upper_bound = #map3} : () -> ()
"std.return"(%4) : (memref<32x32xi32>) -> ()
}) {sym_name = "top", type = (memref<32x32xi32>, memref<32x32xi32>) -> memref<32x32xi32>} : () -> ()
}) : () -> ()
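Here, `%9 = "arith.addi"(%5, %8) : (memref<1xi32>, i32) -> memref<1xi32>` adds the constant `%8` directly to the reduction buffer `%5`, which is a `memref<1xi32>` rather than an `i32`. The binding should first load the element from the buffer (e.g., with `affine.load`) and then add the constant to the loaded `i32` value.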