[Binding] Eliminate duplicated operations
Closed this issue · 0 comments
chhzh123 commented
Currently, we treat each operation as a new object and rebuild it regardless of whether it was built before, which incurs redundant memory accesses and computation. In the following example, without duplicate-operation detection, the element `A[i]` (the argument `x` of `func`) may be loaded several times.
def test_duplicate():
    """Print the lowered IR of a kernel that uses the same input element three times.

    ``func`` references its argument three times, and it is applied to
    ``A[i]``. This makes the kernel a minimal reproducer for duplicated
    operations: without duplicate detection, the lowered code re-loads
    ``A[i]`` once per use.
    """
    hcl.init(hcl.Int(32))
    A = hcl.placeholder((32,))

    def func(x):
        # `x` appears three times. After inlining `x = A[i]`, the same load
        # is requested three times unless duplicates are detected.
        return (x - 1) * x * (x + 1)

    def test(A):
        # This `A` is the argument supplied by the schedule and shadows the
        # outer placeholder. Presumably it is the same placeholder passed in
        # `[A]` below (confirm against hcl.create_schedule).
        return hcl.compute((32,), lambda i: func(A[i]), "A")

    s = hcl.create_schedule([A], test)
    print(hcl.lower(s))
The generated MLIR code below shows that three separate loads of the same element are issued.
module {
func @top(%arg0: memref<32xi32>) -> memref<32xi32> attributes {itypes = "s", otypes = "s"} {
%0 = memref.alloc() {name = "A"} : memref<32xi32>
affine.for %arg1 = 0 to 32 {
%1 = affine.load %arg0[%arg1] {from = "compute_0"} : memref<32xi32>
%c1_i32 = arith.constant 1 : i32
%2 = arith.subi %1, %c1_i32 : i32
%3 = affine.load %arg0[%arg1] {from = "compute_0"} : memref<32xi32>
%4 = arith.muli %2, %3 : i32
%5 = affine.load %arg0[%arg1] {from = "compute_0"} : memref<32xi32>
%6 = arith.addi %5, %c1_i32 : i32
%7 = arith.muli %4, %6 : i32
affine.store %7, %0[%arg1] {to = "A"} : memref<32xi32>
} {loop_name = "x", stage_name = "A"}
return %0 : memref<32xi32>
}
}
The original HeteroCL implementation also requires three loads. Because TVM emits the whole computation as a single one-line expression, it cannot reuse the loaded operand without expression folding.
// attr [_top] storage_scope = "global"
allocate _top[int32 * 1]
produce _top {
// attr [0] extern_scope = 0
produce A {
// attr [0] extern_scope = 0
for "stage_name"="A" (i, 0, 32) {
A[i] = int32((int96(((int64(placeholder0[i]) + (int64)-1)*int64(placeholder0[i])))*int96((placeholder0[i] + 1))))
}
}
}
After adding duplicate detection, we can reuse previous results and generate much cleaner code. `A[i]` is loaded only once (`%1`), and that value is reused by all three arithmetic operations.
module {
func @top(%arg0: memref<32xi32>) -> memref<32xi32> attributes {itypes = "s", otypes = "s"} {
%0 = memref.alloc() {name = "A"} : memref<32xi32>
affine.for %arg1 = 0 to 32 {
%1 = affine.load %arg0[%arg1] {from = "compute_0"} : memref<32xi32>
%c1_i32 = arith.constant 1 : i32
%2 = arith.subi %1, %c1_i32 : i32
%3 = arith.muli %2, %1 : i32
%4 = arith.addi %1, %c1_i32 : i32
%5 = arith.muli %3, %4 : i32
affine.store %5, %0[%arg1] {to = "A"} : memref<32xi32>
} {loop_name = "x", stage_name = "A"}
return %0 : memref<32xi32>
}
}