cornell-zhang/heterocl

expand func op pass

andrewcaiuuu opened this issue · 0 comments

Hello, I am trying to write a pass that wraps each compute op in its own func op, but it generates IR that fails to lower.
The pass looks like this; is there anything obviously incorrect?

from ..ast import ast
from .pass_manager import Pass
from hcl_mlir.exceptions import *

class ExpandFunc(Pass):
    """Outline each ``ComputeOp`` in ``top`` into its own ``FuncOp``.

    For every ``ast.ComputeOp`` found directly in the body of the function
    named ``"top"``, a new ``ast.FuncOp`` called ``sub_func<i>`` is created
    with that compute op as its only body statement and inserted into the
    AST region. The body of ``top`` is then replaced by one ``ast.CallOp``
    per generated function.
    """

    def __init__(self):
        super().__init__("expand_func")
        # AST currently being rewritten; set by apply().
        self._ast = None
        # FuncOps generated by the current apply() run, in creation order.
        self.subfuncs = []

    def visit(self, op):
        """Rewrite ``op`` in place if it is the ``top`` function.

        Any other op is left untouched.
        """
        if not (isinstance(op, ast.FuncOp) and op.name == "top"):
            return
        self.expand_func(op)
        # NOTE(review): the previous body is discarded entirely, so any
        # statement of `top` that is not a ComputeOp is dropped here, and the
        # values returned by the calls are not bound to anything -- TODO
        # confirm this is intended.
        op.body = [
            ast.CallOp(
                subfunc.name, subfunc.args, subfunc.return_tensors, subfunc.loc
            )
            for subfunc in self.subfuncs
        ]

    def apply(self, _ast):
        """Pass entry point.

        Visits every op that is in ``_ast.region`` when the pass starts and
        returns the same (mutated) ``_ast`` object.
        """
        self._ast = _ast
        # Start from a clean slate so that re-using this pass instance does
        # not emit calls to functions generated by an earlier run.
        self.subfuncs = []
        # expand_func() inserts into _ast.region, so iterate over a snapshot
        # instead of the list that is being mutated.
        for op in list(_ast.region):
            self.visit(op)
        return _ast

    def expand_func(self, scope):
        """Create one ``sub_func<i>`` FuncOp per ComputeOp in ``scope.body``.

        Each new function is inserted into ``self._ast.region`` at index 1
        and appended to ``self.subfuncs``. ``scope.body`` itself is not
        modified here.
        """
        compute_ops = [op for op in scope.body if isinstance(op, ast.ComputeOp)]
        for index, op in enumerate(compute_ops):
            # NOTE(review): only the op's input tensors become function
            # arguments. The tensor the compute op writes to is neither passed
            # in nor returned, so the outlined body presumably refers to a
            # value defined outside the new function -- TODO confirm how the
            # output tensor should be threaded through.
            lower_func_op = ast.FuncOp(
                f"sub_func{index}", op.input_tensors, [op], op.loc
            )
            lower_func_op.level = 1
            # update_level is not defined in this class; presumably inherited
            # from Pass -- verify.
            self.update_level(lower_func_op)
            # Always inserting at index 1 places later functions before
            # earlier ones within the region.
            self._ast.region.insert(1, lower_func_op)
            self.subfuncs.append(lower_func_op)

This is the module used for testing right now:

# Minimal reproduction. Assumes `hcl` refers to the HeteroCL package
# (e.g. `import heterocl as hcl`), which is not shown in this snippet.
# Input placeholder of shape (10,) named "A".
A = hcl.placeholder((10,), "A")
def kernel(A):
    # Element-wise copy: the (10,) result takes A[x] at each index x.
    B = hcl.compute((10,), lambda x: A[x])
    return B

# Create a schedule for the kernel and print what hcl.lower() returns for it.
s = hcl.create_schedule([A], kernel)
print(hcl.lower(s))

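This results in the following IR. The `affine.store` line inside `sub_func0` (bolded in the original post) is wrong: its destination operand is `%0`, the `i32` value that was just loaded, rather than a memref: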


#map0 = affine_map<(d0) -> (d0)>  
#map1 = affine_map<() -> (0)>  
#map2 = affine_map<() -> (10)>  
"builtin.module"() ({  
  "func.func"() ({  
  ^bb0(%arg0: memref<10xi32>):  
    "func.call"(%arg0) {callee = @sub_func0} : (memref<10xi32>) -> ()  
    %0 = "memref.alloc"() {name = "tensor_1", operand_segment_sizes = dense<0> : vector<2xi32>} : () -> memref<10xi32>
    "func.return"(%0) : (memref<10xi32>) -> ()
  }) {function_type = (memref<10xi32>) -> memref<10xi32>, itypes = "s", otypes = "s", sym_name = "top"} : () -> ()
  "func.func"() ({
  ^bb0(%arg0: memref<10xi32>):
    "affine.for"() ({
    ^bb0(%arg1: index):
      %0 = "affine.load"(%arg0, %arg1) {from = "A", map = #map0} : (memref<10xi32>, index) -> i32
     "affine.store"(%0, %0, %arg1) {map = #map0, to = "tensor_1"} : (i32, memref<10xi32>, index) -> ()  
      "affine.yield"() : () -> ()  
    }) {loop_name = "x", lower_bound = #map1, op_name = "tensor_1", step = 1 : i32, upper_bound = #map2} : () -> ()
    "func.return"() : () -> ()
  }) {function_type = (memref<10xi32>) -> (), itypes = "s", otypes = "", sym_name = "sub_func0"} : () -> ()
}) : () -> () 

It passes when that `affine.store` is changed to write into `%arg0` (with `to = "A"`):


#map0 = affine_map<(d0) -> (d0)>  
#map1 = affine_map<() -> (0)>  
#map2 = affine_map<() -> (10)>  
"builtin.module"() ({  
  "func.func"() ({  
  ^bb0(%arg0: memref<10xi32>):  
    "func.call"(%arg0) {callee = @sub_func0} : (memref<10xi32>) -> ()  
    %0 = "memref.alloc"() {name = "tensor_1", operand_segment_sizes = dense<0> : vector<2xi32>} : () -> memref<10xi32>
    "func.return"(%0) : (memref<10xi32>) -> ()
  }) {function_type = (memref<10xi32>) -> memref<10xi32>, itypes = "s", otypes = "s", sym_name = "top"} : () -> ()
  "func.func"() ({
  ^bb0(%arg0: memref<10xi32>):
    "affine.for"() ({
    ^bb0(%arg1: index):
      %0 = "affine.load"(%arg0, %arg1) {from = "A", map = #map0} : (memref<10xi32>, index) -> i32
      "affine.store"(%0, %arg0, %arg1) {map = #map0, to = "A"} : (i32, memref<10xi32>, index) -> () 
      "affine.yield"() : () -> ()  
    }) {loop_name = "x", lower_bound = #map1, op_name = "tensor_1", step = 1 : i32, upper_bound = #map2} : () -> ()
    "func.return"() : () -> ()
  }) {function_type = (memref<10xi32>) -> (), itypes = "s", otypes = "", sym_name = "sub_func0"} : () -> ()
}) : () -> () 

My code is in a fork, linked here:
https://github.com/andrewcaiuuu/heterocl
The test module shown above is `simple.py` in the `playground` folder.

Thank you very much.