cornell-zhang/heterocl

`@def_` Caveat: No KernelDef in IR Unless Function Is Imported Locally

Description

When outlining a function with the @def_ decorator, the decorated function must be imported locally (inside the algorithm function) for backend code to be generated correctly. With HLS backends, HeteroCL still runs, but the outlined function is missing from the generated HLS code. With the LLVM backend, the HeteroCL program fails with a segmentation fault, which is difficult to debug.

Minimum Example

main.py

import heterocl as hcl
from submodule import submodule # global import 

target = "llvm" # or "vhls"

hcl.init()

def main():
    A = hcl.placeholder((10,10), name="A")
    B = hcl.placeholder((10,10), name="B")

    def algo(A, B):
        # from submodule import submodule # this local import is required 
        submodule(A)
        hcl.update(B, lambda *args : B[args] + A[args])
    
    s = hcl.create_schedule([A, B], func=algo, name="main")
    f = hcl.build(s, target=target, name="main")

if __name__ == "__main__":
    main()

submodule.py

import heterocl as hcl
from heterocl.dsl import def_

@def_([(10,10)])
def submodule(A):
    A = hcl.compute(A.shape, lambda *args : A[args] + 1)

Running python main.py results in a segmentation fault with the LLVM backend.
If we comment out the global import and uncomment the local import, we get the correct result.

Cause

When the import is done at global scope, the submodule function definition is not executed during hcl.create_schedule, so the IR contains only the Call stmt and no KernelDef stmt.
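The cause can be illustrated with a toy model in plain Python. This is only a sketch of the mechanism described above, not HeteroCL's implementation: Schedule, def_, and the tuple-based statement list are hypothetical stand-ins. The point is that a decorator runs when the function is defined, so a function decorated at import time never sees the schedule that is created later.

```python
# Toy model only: Schedule, def_ and the statement tuples are hypothetical
# stand-ins used for illustration, not HeteroCL internals.
_active = []  # stack of schedules currently being built


class Schedule:
    def __init__(self):
        self.stmts = []

    def __enter__(self):
        _active.append(self)
        return self

    def __exit__(self, *exc):
        _active.pop()


def def_(fn):
    """Record a KernelDef in whichever schedule is active when the decorator runs."""
    if _active:
        _active[-1].stmts.append(("KernelDef", fn.__name__))

    def call(*args):
        # Every call site records a Call, whether or not a KernelDef exists.
        if _active:
            _active[-1].stmts.append(("Call", fn.__name__))

    return call


# Decorated at "import time": no schedule is active, so no KernelDef is recorded.
@def_
def global_sub(A):
    pass


def build_with_global():
    with Schedule() as s:
        global_sub("A")
    return s.stmts


def build_with_local():
    with Schedule() as s:
        @def_
        def local_sub(A):  # the decorator runs while the schedule is active
            pass

        local_sub("A")
    return s.stmts
```

In this model, build_with_global() yields a Call with no KernelDef, while build_with_local() yields both, which matches the behavior reported in this issue.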

Proposal: we can add an IR pass to check if all Call stmts have corresponding KernelDef to detect this problem.
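A minimal sketch of what such a check could look like, assuming the IR is flattened into a list of (kind, name) tuples. That representation and the function name find_undefined_calls are assumptions made for illustration; a real pass would walk the actual IR nodes.

```python
def find_undefined_calls(stmts):
    """Return names of called kernels that have no KernelDef, in first-seen order.

    `stmts` is a hypothetical flattened IR: a list of (kind, name) tuples.
    """
    defined = {name for kind, name in stmts if kind == "KernelDef"}
    missing = []
    for kind, name in stmts:
        if kind == "Call" and name not in defined and name not in missing:
            missing.append(name)
    return missing


# IR produced with the global import: a Call without a KernelDef.
broken_ir = [("Call", "submodule"), ("Update", "B")]
# IR produced with the local import: the KernelDef is present.
good_ir = [("KernelDef", "submodule"), ("Call", "submodule"), ("Update", "B")]
```

Raising a descriptive error when the returned list is non-empty would replace the segmentation fault with an actionable message.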

Following up on this issue: is there a proper fix planned?

The local import/declaration workaround seems to have additional restrictions. For example:

import heterocl as hcl

def f1(A, B):
    # Move this block of code inside the do function and it segfaults.
    @hcl.def_([(10,), (10,), ()])
    def comp(A, B, x):
        with hcl.if_(A[x] > B[x]):
            hcl.return_(A[x])
        hcl.return_(B[x])

    def do(A, B, x):
        return comp(A, B, x)

    hcl.update(B, lambda x: do(A, B, x), "f1")

A = hcl.placeholder((10,), "A")
B = hcl.placeholder((10,), "B")

s = hcl.create_schedule([A, B], func=f1, name="main")
print(hcl.lower(s))
f = hcl.build(s)

As is, the code works. But if the def_ block is moved inside the do function (i.e., as local as it can get), it segfaults. So it looks like @def_ functions can be neither at the top level nor at some lower/local level (I speculate this is because the function ends up being defined within a compute API context). This makes it difficult to define @def_'ed modular building blocks.

Our current support for @hcl.def_ is very preliminary. I plan to solve the top-level declaration issue first. For the inner- or lower-level declaration, you suspect correctly. Our original design is closer to C++, which also doesn't allow nested function declarations (unless you use a lambda function). I will think about adding lower-level support later.

Function Outlining API Proposal

As this issue points out, using a Python decorator to specify function outlining can lead to a scope issue: when submodules defined in a different Python file are imported at the global level, the generated outlined function (KernelDef IR node) is not included in the schedule. Besides, adding decorators requires modifying the algorithm specification, so it is not a decoupled customization.

Therefore, we propose a new API called s.outline() to specify which stages to outline in a decoupled way. Since this API specifies function outlining on the schedule, it does not have the scope issue.

  • The input to this function can be either a single stage or two stages, one as the start stage and the other as the end stage.
  • The start and end stages specify a subgraph in the dataflow graph to be outlined (like host-xcel data placement with .to).

The compiler uses the input stages to extract the subgraph to be outlined. When only one stage is given, just that stage is outlined as a function. For the outlined function, the compiler also infers its input and output arguments, builds the function body, and inserts a call operation into the caller function.
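The subgraph extraction and argument inference described above can be sketched in plain Python. This is an illustrative model, not the compiler's algorithm: the READS table, the helper _reaches, and this standalone outline function are hypothetical, and the graph is the 2MM example used later in this proposal. Following the IR shown in this proposal, every buffer written inside the subgraph is treated as an output argument because allocation stays in the caller.

```python
# Hypothetical description of the 2MM dataflow graph: stage -> buffers it reads.
# Each stage writes a buffer that carries the stage's own name.
READS = {
    "out_AB": ["A", "B"],
    "out_ABC": ["out_AB", "C"],
    "E": ["out_ABC", "D"],
}


def _reaches(src, dst):
    """True if dst is src, or dst transitively reads src's output."""
    if src == dst:
        return True
    return any(r in READS and _reaches(src, r) for r in READS[dst])


def outline(start, end=None):
    """Return (stages, inputs, outputs) for the subgraph from start to end."""
    end = start if end is None else end
    # Keep stages that lie on a path from the start stage to the end stage.
    stages = [s for s in READS if _reaches(start, s) and _reaches(s, end)]
    outputs = list(stages)  # buffers written inside the outlined function
    inputs = []
    for s in stages:
        for r in READS[s]:
            # A buffer produced outside the subgraph becomes an input argument.
            if r not in outputs and r not in inputs:
                inputs.append(r)
    return stages, inputs, outputs
```

With this model, outline("out_AB") selects a single stage with inputs A and B, while outline("out_AB", "out_ABC") selects both stages with inputs A, B, and C.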

Example: 2MM

Outlining single stages

2MM performs two matrix multiplications followed by an element-wise addition. It has three stages: out_AB, out_ABC, and E.

We use s.outline() to outline out_AB and out_ABC as two functions:

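import heterocl as hcl

# Sizes and element type are assumed here to match the IR shown below
# (16x22, 22x18, 18x24, f32).
P, Q, R, S = 16, 22, 18, 24
dtype = hcl.Float(32)

A = hcl.placeholder((P, Q), "A")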
B = hcl.placeholder((Q, R), "B")
C = hcl.placeholder((R, S), "C")
D = hcl.placeholder((P, S), "D")

def kernel_2mm(A, B, C, D):

    r = hcl.reduce_axis(0, Q, "r")
    out_AB = hcl.compute(
        (P, R),
        lambda x, y: hcl.sum(A[x, r] * B[r, y], axis=r, dtype=dtype),
        name="out_AB",
    )
    k = hcl.reduce_axis(0, R, "k")
    out_ABC = hcl.compute(
        (P, S),
        lambda x, y: hcl.sum(out_AB[x, k] * C[k, y], axis=k, dtype=dtype),
        name="out_ABC",
    )
    E = hcl.compute(
        D.shape,
        lambda x, y: (out_ABC[x, y] + D[x, y]),
        dtype=dtype,
        name="E",
    )
    return E

s = hcl.create_schedule([A, B, C, D], kernel_2mm)
s.outline(kernel_2mm.out_AB)
s.outline(kernel_2mm.out_ABC)

The IR before outlining:

module {
  func @top(%arg0: memref<16x22xf32>, %arg1: memref<22x18xf32>, %arg2: memref<18x24xf32>, %arg3: memref<16x24xf32>) -> memref<16x24xf32> {
    // Stage out_AB
    %0 = memref.alloc() : memref<16x18xf32>
    affine.for %arg4 = 0 to 16 {
      affine.for %arg5 = 0 to 18 {
        affine.for %arg6 = 0 to 22 {
          ...
        } {loop_name = "r"}
      } {loop_name = "y"}
    } {loop_name = "x", stage_name = "out_AB"}
    
    // Stage out_ABC
    %1 = memref.alloc() : memref<16x24xf32>
    affine.for %arg4 = 0 to 16 {
      affine.for %arg5 = 0 to 24 {
        affine.for %arg6 = 0 to 18 {
          ...
        } {loop_name = "k"}
      } {loop_name = "y"}
    } {loop_name = "x", stage_name = "out_ABC"}
    
    // Stage E
    %2 = memref.alloc() : memref<16x24xf32>
    affine.for %arg4 = 0 to 16 {
      affine.for %arg5 = 0 to 24 {
        ...
      } {loop_name = "y"}
    } {loop_name = "x", stage_name = "E"}
    return %2 : memref<16x24xf32>
  }
}

The IR after outlining:

module {
  func @Stage_out_AB(%arg0: memref<16x22xf32>, %arg1: memref<22x18xf32>, %arg3: memref<16x18xf32>) ->() {
    ...
  }
  func @Stage_out_ABC(%arg0: memref<16x18xf32>, %arg1: memref<18x24xf32>, %arg3: memref<16x24xf32>) ->() {
    ...
  }

  func @top(%arg0: memref<16x22xf32>, %arg1: memref<22x18xf32>, %arg2: memref<18x24xf32>, %arg3: memref<16x24xf32>) -> memref<16x24xf32> {
    // Stage out_AB
    %0 = memref.alloc() : memref<16x18xf32>
    call @Stage_out_AB(%arg0, %arg1, %0)
    
    // Stage out_ABC
    %1 = memref.alloc() : memref<16x24xf32>
    call @Stage_out_ABC(%0, %arg2, %1)
    
    // Stage E
    %2 = memref.alloc() : memref<16x24xf32>
    affine.for %arg4 = 0 to 16 {
      affine.for %arg5 = 0 to 24 {
        ...
      } {loop_name = "y"}
    } {loop_name = "x", stage_name = "E"}
    return %2 : memref<16x24xf32>
  }
}
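The calling convention visible in the outlined IR (the caller allocates each buffer and passes it to the outlined function, which returns nothing) can be mimicked in plain Python. This is only a behavioral sketch using nested lists; the helper _matmul_into and the Python function names are assumptions for illustration and are not generated by HeteroCL.

```python
def _matmul_into(lhs, rhs, out):
    """Write lhs @ rhs into the caller-allocated buffer `out`."""
    for x in range(len(lhs)):
        for y in range(len(rhs[0])):
            out[x][y] = sum(lhs[x][r] * rhs[r][y] for r in range(len(rhs)))


def stage_out_AB(A, B, out_AB):
    # Counterpart of @Stage_out_AB: two inputs, one output buffer, no return value.
    _matmul_into(A, B, out_AB)


def stage_out_ABC(out_AB, C, out_ABC):
    # Counterpart of @Stage_out_ABC.
    _matmul_into(out_AB, C, out_ABC)


def top(A, B, C, D):
    """Caller after outlining: allocate, call the outlined stages, run stage E inline."""
    P, R, S = len(A), len(B[0]), len(C[0])
    out_AB = [[0] * R for _ in range(P)]
    stage_out_AB(A, B, out_AB)
    out_ABC = [[0] * S for _ in range(P)]
    stage_out_ABC(out_AB, C, out_ABC)
    # Stage E stays in the caller because it was not outlined.
    return [[out_ABC[x][y] + D[x][y] for y in range(S)] for x in range(P)]
```

Keeping allocation in the caller is what lets the intermediate buffer out_AB be shared between the two outlined functions.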

Outlining multiple stages as a single function

We use s.outline to specify the subgraph we would like to outline as a function:

s = hcl.create_schedule([A, B, C, D], kernel_2mm)
s.outline(kernel_2mm.out_AB, kernel_2mm.out_ABC)

The IR after outlining:

module {
  func @Stage_outAB_outABC(%arg0: memref<16x22xf32>, %arg1: memref<22x18xf32>, %arg3: memref<16x18xf32>, %arg4: memref<18x24xf32>, %arg5: memref<16x24xf32>) ->() {
    ...
  }

  func @top(%arg0: memref<16x22xf32>, %arg1: memref<22x18xf32>, %arg2: memref<18x24xf32>, %arg3: memref<16x24xf32>) -> memref<16x24xf32> {
    // Stage out_AB and out_ABC
    %0 = memref.alloc() : memref<16x18xf32>
    %1 = memref.alloc() : memref<16x24xf32>
    call @Stage_outAB_outABC(%arg0, %arg1, %0, %arg2, %1)
    
    // Stage E
    %2 = memref.alloc() : memref<16x24xf32>
    affine.for %arg4 = 0 to 16 {
      affine.for %arg5 = 0 to 24 {
        ...
      } {loop_name = "y"}
    } {loop_name = "x", stage_name = "E"}
    return %2 : memref<16x24xf32>
  }
}