cornell-zhang/allo

[BUG] Loop carried dependences should be SSA values not memory operations


Describe the bug
Neither writing kernels with primitives like matmul() nor using allo.grid() makes use of affine.for's ability to carry iteration arguments (iter_args). For us, this is important for pipelining. Here is the MLIR produced by test_reduce() (shown further down).

module {
  func.func @kernel(%arg0: memref<20xi32>) -> i32 attributes {itypes = "s", otypes = "s", top} {
    %c0_i32 = arith.constant 0 : i32
    %alloc = memref.alloc() {name = "sum"} : memref<1xi32>
    affine.store %c0_i32, %alloc[0] {to = "sum"} : memref<1xi32>
    affine.for %arg1 = 0 to 20 {
      %1 = affine.load %arg0[%arg1] {from = "A"} : memref<20xi32>
      // We have to reload the value here,
      // when it should be forwarded from the last iteration's store.
      %2 = affine.load %alloc[0] {from = "sum"} : memref<1xi32>
      %3 = arith.addi %2, %1 : i32
      affine.store %3, %alloc[0] {to = "sum"} : memref<1xi32>
    } {loop_name = "i", op_name = "S_i_0", reduction}
    %0 = affine.load %alloc[0] {from = "sum"} : memref<1xi32>
    return %0 : i32
  }
}

To Reproduce
The linalg dialect compounds the issue, because it lowers linalg to affine loops without an accumulator:

def test_linalg_matmul():
    N = 16
    from allo import matmul

    def kernel(A: int32[N, N], B: int32[N, N]) -> int32[N, N]:
        return matmul(A, B)

    s = allo.customize(kernel)
    print(s.module)

But even with an explicit accumulator in a single memref cell, I can't get it to be raised to SSA values:

def test_reduce():
    N = 20

    def kernel(A: int32[N]) -> int32:
        sum: int32 = 0
        for i in allo.reduction(N):
            sum += A[i]
        return sum

    s = allo.customize(kernel)
    print(s.module)

Buggy output
I was not hopeful that the existing MLIR passes would help with this issue, but I tried anyway by running mlir-opt --convert-linalg-to-affine-loops --affine-scalrep --lower-affine --convert-scf-to-cf --mem2reg

The mem2reg pass is only expected to work on unstructured control flow, but I could not get it to work even there.

Expected behavior
Here is an example of how we do matmul in affine that uses iteration arguments to assist the pipelining pass:

  affine.for %arg3 = 0 to 16 {
    affine.for %arg4 = 0 to 16 {
      %sum = affine.for %arg5 = 0 to 16
          iter_args(%sum_iter = %c0_i32) -> (i32) {
        %2 = affine.load %A[%arg3, %arg5] : memref<16x16xi32>
        %3 = affine.load %B[%arg5, %arg4] : memref<16x16xi32>
        %4 = arith.muli %2, %3 : i32
        %sum_next = arith.addi %4, %sum_iter : i32
        affine.yield %sum_next : i32
      }
      affine.store %sum, %C[%arg3, %arg4] : memref<16x16xi32>
    }
  }

Perhaps MLIR already has the right patterns or passes to accomplish this, but I haven't found them yet. We may have to write our own pass or lower the AST differently.

I agree that generating iteration variables may be helpful for some compiler passes. However, it is not easy to determine from the frontend whether a variable is a reduction variable, so we currently do not support this feature. The allo.reduction function is just an annotation; it does not generate a loop with iteration variables.

I haven't figured out a good way to resolve this issue. A more sophisticated frontend analysis pass might help generate this kind of reduction loop.
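To make the frontend-analysis idea concrete, here is a minimal sketch using only Python's standard ast module. It is not how allo works today, and the function name reduction_candidates is made up for illustration. The heuristic flags a scalar name as a reduction candidate when it is assigned at function level before a loop and then updated with an augmented assignment inside that loop.

```python
import ast

# The test_reduce() kernel from this issue, as source text. ast.parse only
# checks syntax, so int32, N, and allo do not need to be importable here.
SOURCE = """
def kernel(A: int32[N]) -> int32:
    sum: int32 = 0
    for i in allo.reduction(N):
        sum += A[i]
    return sum
"""


def reduction_candidates(source):
    """Return {loop_var: [names]} for scalars that are initialized before a
    top-level loop and updated with an augmented assignment inside it."""
    tree = ast.parse(source)
    result = {}
    for func in ast.walk(tree):
        if not isinstance(func, ast.FunctionDef):
            continue
        defined = set()  # scalar names assigned so far at function level
        for stmt in func.body:
            if isinstance(stmt, ast.Assign):
                targets = stmt.targets
            elif isinstance(stmt, ast.AnnAssign):
                targets = [stmt.target]
            else:
                targets = []
            defined.update(t.id for t in targets if isinstance(t, ast.Name))

            if isinstance(stmt, ast.For) and isinstance(stmt.target, ast.Name):
                # Plain-name targets only: `C[i] += x` is a memory update,
                # not a scalar carried across iterations.
                result[stmt.target.id] = [
                    node.target.id
                    for node in ast.walk(stmt)
                    if isinstance(node, ast.AugAssign)
                    and isinstance(node.target, ast.Name)
                    and node.target.id in defined
                ]
    return result


print(reduction_candidates(SOURCE))
```

For the kernel above this reports sum as a candidate for loop i. The heuristic misses forms such as sum = sum + A[i], conditional updates, and loops nested inside other statements, which is part of why doing this robustly in the frontend is hard.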

I think I found a solution to this problem: https://github.com/cornell-zhang/amc-dialect/pull/64

This looks cool! Could you provide an example of the original MLIR code and the code after this pass? @andrewb1999

Yeah so this is the code before the pass:

module {
  func.func @kernel(%arg0: memref<20xi32>) -> i32 attributes {itypes = "s", otypes = "s", top} {
    %c0_i32 = arith.constant 0 : i32
    %alloc = memref.alloc() {name = "sum"} : memref<1xi32>
    affine.store %c0_i32, %alloc[0] {to = "sum"} : memref<1xi32>
    affine.for %arg1 = 0 to 20 {
      %1 = affine.load %arg0[%arg1] {from = "A"} : memref<20xi32>
      %2 = affine.load %alloc[0] {from = "sum"} : memref<1xi32>
      %3 = arith.addi %2, %1 : i32
      affine.store %3, %alloc[0] {to = "sum"} : memref<1xi32>
    } {loop_name = "i", op_name = "S_i_0", reduction}
    %0 = affine.load %alloc[0] {from = "sum"} : memref<1xi32>
    return %0 : i32
  }
}

and this is the code after the pass:

module {
  func.func @kernel(%arg0: memref<20xi32>) -> i32 attributes {itypes = "s", otypes = "s", top} {
    %c0_i32 = arith.constant 0 : i32
    %alloc = memref.alloc() {name = "sum"} : memref<1xi32>
    affine.store %c0_i32, %alloc[0] {to = "sum"} : memref<1xi32>
    %0 = affine.load %alloc[0] {from = "sum"} : memref<1xi32>
    %1 = affine.for %arg1 = 0 to 20 iter_args(%arg2 = %0) -> (i32) {
      %3 = affine.load %arg0[%arg1] {from = "A"} : memref<20xi32>
      %4 = arith.addi %arg2, %3 : i32
      affine.yield %4 : i32
    }
    affine.store %1, %alloc[0] : memref<1xi32>
    %2 = affine.load %alloc[0] {from = "sum"} : memref<1xi32>
    return %2 : i32
  }
}

You can see that the load and store on sum inside the loop have been removed and replaced with iter_args and an affine.yield. The sum memref should then be removable entirely using store-load forwarding.
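As a rough illustration of the rewrite shown in the before/after listings, here is a toy model in plain Python. It is not the code from the linked PR and does not use any MLIR API. The tuple-based mini IR and the name promote_to_iter_arg are assumptions made up for this sketch, and it only handles the simplest shape: exactly one load of the cell followed by exactly one store to it.

```python
def promote_to_iter_arg(body, memref, iter_arg="%acc"):
    """Carry a single-cell memref through a loop as a value instead of memory.

    `body` is a list of tuples in a made-up mini IR:
      ("load",  result, memref, index)   result = memref[index]
      ("store", value,  memref, index)   memref[index] = value
      (opname,  result, lhs, rhs)        two-operand arithmetic
    `memref` is assumed to hold exactly one cell, so indices are not compared.
    Returns (new_body, yielded_value), or None if the pattern does not match.
    """
    loads = [op for op in body if op[0] == "load" and op[2] == memref]
    stores = [op for op in body if op[0] == "store" and op[2] == memref]
    # Only the simplest shape is handled: one load, then one store.
    if len(loads) != 1 or len(stores) != 1:
        return None
    load, store = loads[0], stores[0]
    if body.index(load) > body.index(store):
        return None

    loaded, yielded = load[1], store[1]
    new_body = []
    for op in body:
        if op is load or op is store:
            continue  # both memory accesses to the cell disappear
        # Uses of the loaded value now read the iteration argument instead.
        new_body.append(tuple(iter_arg if x == loaded else x for x in op))
    if yielded == loaded:
        yielded = iter_arg  # the cell was stored back unchanged
    return new_body, yielded


# Loop body of the "before" listing, transcribed into the mini IR.
body = [
    ("load", "%1", "%arg0", "%arg1"),
    ("load", "%2", "%alloc", 0),
    ("addi", "%3", "%2", "%1"),
    ("store", "%3", "%alloc", 0),
]
new_body, yielded = promote_to_iter_arg(body, "%alloc", iter_arg="%arg2")
print(new_body, yielded)
```

The result keeps the load from A, rewrites the add to read %arg2, and yields %3, which mirrors the structure of the loop after the pass. A real pass would also have to prove that nothing else in the loop can alias the cell, and it would insert the initial load before the loop and the final store after it.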