expand func op pass
andrewcaiuuu opened this issue · 0 comments
andrewcaiuuu commented
Hello, I am trying to write a pass that wraps each compute op in its own func op, but it generates IR that fails to lower.
The pass looks like this. Is there anything obviously incorrect?
from ..ast import ast
from .pass_manager import Pass
from hcl_mlir.exceptions import *
class ExpandFunc(Pass):
    """Outline every compute op of the ``top`` function into its own function.

    For each ``ast.ComputeOp`` found directly in the body of the ``FuncOp``
    named ``"top"``, a new ``FuncOp`` called ``sub_func<i>`` is created around
    that op and inserted into the module region.  The body of ``top`` is then
    replaced by one ``CallOp`` per generated function.

    NOTE(review): every op in ``top`` that is not a ``ComputeOp`` is dropped
    when the body is replaced -- confirm that this is intended.
    """

    def __init__(self):
        super().__init__("expand_func")
        # Module AST being rewritten; set by apply().
        self._ast = None
        # Sub-functions generated by the most recent apply() call.
        self.subfuncs = []

    def visit(self, op):
        """Rewrite ``op`` if it is the ``top`` function; ignore anything else.

        Parameters
        ----------
        op : object
            An entry of the module region.
        """
        if not (isinstance(op, ast.FuncOp) and op.name == "top"):
            return
        self.expand_func(op)
        # Replace the original body with one call per generated sub-function.
        op.body = [
            ast.CallOp(
                subfunc.name, subfunc.args, subfunc.return_tensors, subfunc.loc
            )
            for subfunc in self.subfuncs
        ]

    def apply(self, _ast):
        """Pass entry point.

        Parameters
        ----------
        _ast : object
            Module AST exposing an iterable, insertable ``region``.

        Returns
        -------
        object
            The same ``_ast`` object, modified in place.
        """
        self._ast = _ast
        # Start from a clean slate so that running the pass again does not
        # emit calls to sub-functions left over from an earlier run.
        self.subfuncs = []
        # Iterate over a snapshot: expand_func() inserts into _ast.region,
        # and a list must not be mutated while it is being iterated.
        for op in list(_ast.region):
            self.visit(op)
        return _ast

    def expand_func(self, scope):
        """Create one sub-function per compute op found in ``scope.body``.

        Each generated function is inserted at index 1 of the module region
        and appended to ``self.subfuncs``.  ``scope.body`` itself is not
        modified here.

        Parameters
        ----------
        scope : ast.FuncOp
            Function whose direct compute ops are outlined.
        """
        compute_ops = [op for op in scope.body if isinstance(op, ast.ComputeOp)]
        for i, op in enumerate(compute_ops):
            # NOTE(review): only op.input_tensors become arguments of the new
            # function; the tensor the compute op writes to is not passed in.
            # This is consistent with the reported failing IR, where the
            # generated function's affine.store has no memref destination.
            # Presumably the output tensor must also be an argument (or be
            # allocated inside the sub-function) -- confirm against the
            # ComputeOp definition.
            lower_func_op = ast.FuncOp(f"sub_func{i}", op.input_tensors, [op], op.loc)
            lower_func_op.level = 1
            # update_level is not defined in this class; presumably inherited
            # from Pass -- TODO confirm.
            self.update_level(lower_func_op)
            # Index 1 places the new function right after the first region
            # entry, so later sub-functions end up before earlier ones.
            self._ast.region.insert(1, lower_func_op)
            self.subfuncs.append(lower_func_op)
This is the module used for testing right now:
# Minimal reproduction: one compute stage that copies a 10-element input.
A = hcl.placeholder((10,), "A")


def kernel(A):
    """Return a compute stage of shape (10,) whose element x is ``A[x]``."""
    return hcl.compute((10,), lambda x: A[x])


schedule = hcl.create_schedule([A], kernel)
lowered = hcl.lower(schedule)
print(lowered)
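This results in the IR below. The wrong line (bolded in the original post) is the `affine.store` inside `sub_func0`: its destination operand is `%0`, the `i32` value that was just loaded, rather than a `memref<10xi32>`: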
#map0 = affine_map<(d0) -> (d0)>
#map1 = affine_map<() -> (0)>
#map2 = affine_map<() -> (10)>
"builtin.module"() ({
"func.func"() ({
^bb0(%arg0: memref<10xi32>):
"func.call"(%arg0) {callee = @sub_func0} : (memref<10xi32>) -> ()
%0 = "memref.alloc"() {name = "tensor_1", operand_segment_sizes = dense<0> : vector<2xi32>} : () -> memref<10xi32>
"func.return"(%0) : (memref<10xi32>) -> ()
}) {function_type = (memref<10xi32>) -> memref<10xi32>, itypes = "s", otypes = "s", sym_name = "top"} : () -> ()
"func.func"() ({
^bb0(%arg0: memref<10xi32>):
"affine.for"() ({
^bb0(%arg1: index):
%0 = "affine.load"(%arg0, %arg1) {from = "A", map = #map0} : (memref<10xi32>, index) -> i32
"affine.store"(%0, %0, %arg1) {map = #map0, to = "tensor_1"} : (i32, memref<10xi32>, index) -> ()
"affine.yield"() : () -> ()
}) {loop_name = "x", lower_bound = #map1, op_name = "tensor_1", step = 1 : i32, upper_bound = #map2} : () -> ()
"func.return"() : () -> ()
}) {function_type = (memref<10xi32>) -> (), itypes = "s", otypes = "", sym_name = "sub_func0"} : () -> ()
}) : () -> ()
It passes when that `affine.store` is changed to store into `%arg0` instead:
#map0 = affine_map<(d0) -> (d0)>
#map1 = affine_map<() -> (0)>
#map2 = affine_map<() -> (10)>
"builtin.module"() ({
"func.func"() ({
^bb0(%arg0: memref<10xi32>):
"func.call"(%arg0) {callee = @sub_func0} : (memref<10xi32>) -> ()
%0 = "memref.alloc"() {name = "tensor_1", operand_segment_sizes = dense<0> : vector<2xi32>} : () -> memref<10xi32>
"func.return"(%0) : (memref<10xi32>) -> ()
}) {function_type = (memref<10xi32>) -> memref<10xi32>, itypes = "s", otypes = "s", sym_name = "top"} : () -> ()
"func.func"() ({
^bb0(%arg0: memref<10xi32>):
"affine.for"() ({
^bb0(%arg1: index):
%0 = "affine.load"(%arg0, %arg1) {from = "A", map = #map0} : (memref<10xi32>, index) -> i32
"affine.store"(%0, %arg0, %arg1) {map = #map0, to = "A"} : (i32, memref<10xi32>, index) -> ()
"affine.yield"() : () -> ()
}) {loop_name = "x", lower_bound = #map1, op_name = "tensor_1", step = 1 : i32, upper_bound = #map2} : () -> ()
"func.return"() : () -> ()
}) {function_type = (memref<10xi32>) -> (), itypes = "s", otypes = "", sym_name = "sub_func0"} : () -> ()
}) : () -> ()
My code is in a fork, linked here:
https://github.com/andrewcaiuuu/heterocl
The test module shown above is the file `simple.py` in the `playground` folder.
Thank you very much.