cornell-zhang/hcl-dialect

[Op] SelectOp short-circuit evaluation

Opened this issue · 7 comments

The current codegen for SelectOp does not support short-circuit evaluation, which can cause problems in specific cases. For example, for hcl.select(i > 0, A[i - 1], 0), we must check i > 0 first and only then load A[i - 1]. However, arith.select takes SSA values as operands, so in MLIR's SSA form both the true and false values are computed before the select itself executes. The following code shows this situation, which leads to an out-of-bounds access of array A when i = 0.

%true = affine.load %A[%i - 1] : memref<10xi32>
%false = arith.constant 0 : i32
%cond = arith.cmpi sgt, %i, %zero : index
%x = arith.select %cond, %true, %false : i32

A general scf.if statement may also have this problem.
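The difference between a select on precomputed operands and a branch can be sketched in Python. This is only an illustration under stated assumptions: `checked_load`, `eager_select` and `lazy_select` are hypothetical helpers, not part of HeteroCL or MLIR, and the bounds check stands in for the out-of-bounds memory access.

```python
def checked_load(A, idx):
    """Load A[idx], raising on out-of-range indices (Python would otherwise wrap -1)."""
    if idx < 0 or idx >= len(A):
        raise IndexError(f"index {idx} out of bounds for length {len(A)}")
    return A[idx]


def eager_select(cond, true_val, false_val):
    """Mimics arith.select: all three operands are already-computed values."""
    return true_val if cond else false_val


def lazy_select(cond, true_fn, false_fn):
    """Mimics a branch: only the chosen side is evaluated."""
    return true_fn() if cond else false_fn()


A = list(range(10, 20))  # hypothetical contents for a 10-element array


def eager_at(i):
    # The load runs while the arguments are being built, before the select.
    return eager_select(i > 0, checked_load(A, i - 1), 0)


def lazy_at(i):
    # The load is wrapped in a function, so it only runs when i > 0.
    return lazy_select(i > 0, lambda: checked_load(A, i - 1), lambda: 0)
```

With these definitions, `lazy_at(0)` returns the false value without touching the array, while `eager_at(0)` raises `IndexError` because the load runs before the condition is consulted.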

Update: The example above is not really a short-circuit case, although it is also about evaluation order. The short-circuit case is the example below, which our flow does not support either.

// expected condition: (i > 0 && A[i - 1] > 0)
// but we currently generate the following code:
// %cond1 and %cond2 are *both* evaluated before the select
%cond1 = arith.cmpi sgt, %i, %zero : index
%A_val = affine.load %A[%i - 1] : memref<10xi32>
%cond2 = arith.cmpi sgt, %A_val, %c0_i32 : i32
%cond3 = arith.andi %cond1, %cond2 : i1
%ret = arith.select %cond3, %true, %false : i32

Outline the hcl.select as a function? I tried this with mlir-opt and it seems to work:

module {
  func @select(%arg0: i32, %arg1: memref<10xi32>) -> i32 {
    %c0_i32 = arith.constant 0 : i32
    %0 = arith.cmpi sgt, %arg0, %c0_i32 : i32
    cond_br %0, ^bb1(%arg0 : i32), ^bb2
  ^bb1(%1: i32):  // pred: ^bb0
    %2 = arith.index_cast %1 : i32 to index
    %3 = affine.load %arg1[%2 - 1] : memref<10xi32>
    return %3 : i32
  ^bb2:  // pred: ^bb0
    %c0_i32_0 = arith.constant 0 : i32
    return %c0_i32_0 : i32
  }
}

Proposal: Use Nested if for Short-Circuit Eval

Build an if nest to support short-circuit behaviour for multiple conditions in hcl.select.

What issue might occur if we don't support short-circuit evaluation?

Consider this condition

hcl.select(hcl.and_(i > 1, A[i - 1] > 0), ..., ...)

If we build this into three conditions:

%cond1 = cmp gt %i %one
%ele = memref.load A[%i - 1]
%cond2 = cmp gt %ele %zero
%cond3 = and %cond1 %cond2

When %i evaluates to zero, memref.load still executes and reads A[-1], which is out of bounds and can raise a segmentation fault. The load was supposed to be short-circuited but was not.

How should we implement the short-circuit behaviour?

We can turn that into nested if operations:

// initialize SelectOp's return value here
%ret_value = some init value
%cond1 = cmp gt %i %one
scf.if (%cond1) {
  %ele = memref.load A[%i - 1]
  %cond2 = cmp gt %ele %zero
  scf.if (%cond2) {
    // true branch
    // update return value here
    ...
  } else {
    // false branch (%cond1 holds but %cond2 does not)
    // update return value here
    ...
  }
} else {
  // false branch
  // update return value here
  ...
}
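The intended behaviour of the if nest can be sketched in Python as follows. This is a minimal model, not the actual codegen: `select_if_nest`, `load`, `guarded`, and the contents of `A` are hypothetical, and each condition is passed as a function so that it is evaluated only when its level of the nest is reached.

```python
def select_if_nest(conds, true_fn, false_fn):
    """Evaluate conds in order; each loop iteration models one level of the if nest."""
    for cond in conds:
        if not cond():
            # Any level that fails must end up in the false branch.
            return false_fn()
    return true_fn()


A = [5, -3, 7]   # hypothetical array contents
trace = []       # indices of the loads that actually executed


def load(i):
    trace.append(i)
    if not 0 <= i < len(A):
        raise IndexError(i)
    return A[i]


def guarded(i):
    # Models hcl.select(hcl.and_(i > 0, A[i - 1] > 0), ..., ...)
    return select_if_nest(
        [lambda: i > 0, lambda: load(i - 1) > 0],
        true_fn=lambda: "true branch",
        false_fn=lambda: "false branch",
    )
```

Calling `guarded(0)` takes the false branch and leaves `trace` empty, so the load is never issued. Calling `guarded(2)` also takes the false branch, but through the second condition, which is why the false branch has to be reachable from every level of the nest.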

We need to add a node for hcl.and_ in our AST

We build hcl.select offline, meaning we build an AST first and then generate the IR by traversing the AST.
The current AST has one drawback in this case: when hcl.and_ has more than two conditions, it turns into a tree of binary And nodes:

           and
          /    \
        and  cond1
       /   \
   cond3   cond2

This makes it inconvenient to build an if nest such as:

if (cond1) {
  if (cond2) {
    if (cond3) {
    }
  }
}

because once the AST is built, we no longer know at the top level how many conditions hcl.and_ has. That count can only be recovered by traversing the And tree.

We can add a single node to represent hcl.and_ and let it hold a list of condition expressions, which makes building the if nest easier.
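A rough Python sketch of this idea is below. The class names `Cond`, `And` and `AndN` are hypothetical and do not correspond to the actual HeteroCL AST. The sketch also assumes that the leaves of the binary tree, read left to right, give the intended evaluation order.

```python
from dataclasses import dataclass
from typing import List, Union


@dataclass
class Cond:
    name: str


@dataclass
class And:
    """Binary node, as in the current tree-shaped AST."""
    lhs: "Expr"
    rhs: "Expr"


@dataclass
class AndN:
    """Proposed single node that holds every condition in a flat list."""
    conds: List[Cond]


Expr = Union[Cond, And]


def flatten(node):
    """Collect the leaf conditions of a binary And tree, left to right."""
    if isinstance(node, And):
        return flatten(node.lhs) + flatten(node.rhs)
    return [node]


def emit_if_nest(node, indent="  "):
    """Emit one nested if per condition; the number of levels is len(node.conds)."""
    lines = []
    for depth, cond in enumerate(node.conds):
        lines.append(f"{indent * depth}if ({cond.name}) {{")
    for depth in reversed(range(len(node.conds))):
        lines.append(f"{indent * depth}}}")
    return "\n".join(lines)


tree = And(And(Cond("cond1"), Cond("cond2")), Cond("cond3"))
print(emit_if_nest(AndN(flatten(tree))))
```

With the flat `AndN` node, the emitter knows the nesting depth up front and needs no recursive walk at IR-generation time. `flatten` is only needed if the binary tree is kept as an intermediate form.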

I think it might make sense for simplicity reasons to separate the evaluation of the condition and the true/false branches. I think you can do this with scf.yield.

%ret_value = some init value
%cond1 = cmp gt %i %one
%cond = scf.if %cond1 -> (i1) {
  %ele = memref.load A[%i - 1]
  %cond2 = cmp gt %ele %zero
  scf.yield %cond2 : i1 
} else {
  %false = arith.constant 0 : i1
  scf.yield %false : i1
}
%res = scf.if %cond -> (out_type) {
  // true branch
} else {
  //false branch
}

It might also make sense to define a specific hcl.select that functions almost identically to scf.if but requires the else branch. This could simplify code generation because it will separate normal if statements from hcl.selects.

%ret_value = some init value
%cond1 = cmp gt %i %one
%cond = scf.if %cond1 -> (i1) {
  %ele = memref.load A[%i - 1]
  %cond2 = cmp gt %ele %zero
  scf.yield %cond2 : i1 
} else {
  %false = arith.constant 0 : i1
  scf.yield %false : i1
}
%res = hcl.select %cond -> (out_type) {
  // true branch
} else {
  //false branch
}

@andrewb1999 Thanks for your suggestions! I think separating the condition and the true/false branches is a good idea, but I have two questions:

  1. What if we have more than two conditions? Do we generate sequential or nested scf.if for %cond? For example, for if (cond1 and cond2 and cond3), there are the two rewrites below. Sequential ifs may incur redundant evaluations, because every scf.if in the chain executes even after an earlier condition has failed. With nested ifs, it may be tricky to yield the final result from the innermost if out to the outermost if.
// sequential if
%cond_12 = scf.if %cond1 () {
  scf.yield %cond2
} else {
  scf.yield %false
}
%cond_123 = scf.if %cond_12 () {
  scf.yield %cond3
} else {
  scf.yield %false
}

// nested if
%cond = scf.if %cond1 () {
  %cond_23 = scf.if %cond2 () {
    scf.yield %cond3
  } else {
    scf.yield %false
  }
  scf.yield %cond_23
} else {
  scf.yield %false
}
  2. If we define our own hcl.select operation, we would probably also need to define hcl.yield. Otherwise, we do not know the return values of the true and false branches.
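The trade-off in question 1 can be made concrete with a small Python model that counts how many if operations execute in each form. This is a sketch under assumptions: `sequential_if` and `nested_if` are hypothetical names, the first condition is computed up front, and later conditions are computed inside the region that yields them, as in the two rewrites above.

```python
def sequential_if(conds):
    """Chain of if ops: every op in the chain runs, even after an earlier failure."""
    tests = 0
    acc = conds[0]()
    for cond in conds[1:]:
        tests += 1                       # this if op always executes
        acc = cond() if acc else False   # yield next condition, or false
    return acc, tests


def nested_if(conds):
    """Nested if ops: an inner op runs only when every outer condition held."""
    tests = 0
    acc = conds[0]()
    for cond in conds[1:]:
        tests += 1                       # if op at this nesting level
        if not acc:
            return False, tests          # else region yields false; inner ops are skipped
        acc = cond()
    return acc, tests


conds = [lambda: False, lambda: True, lambda: True]
print(sequential_if(conds))  # (False, 2)
print(nested_if(conds))      # (False, 1)
```

In this model both forms evaluate the same conditions. The sequential form only pays for extra tests on the already-false intermediate result, whereas the nested form has to thread the yielded value back out through each level.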
  1. The more I think about this, the more I think we need better separation between the representation of hcl.cond and downstream optimizations. Something like nested scf.if seems like the best way to implement this for the LLVM backend, but it seems like a pretty disappointing representation for HLS code generation. It would be nice to have something like the following:
%cond = hcl.and {
  hcl.yield %cond1
} and {
  hcl.yield %cond2
} and {
  hcl.yield %cond3
}

This should be possible using the VariadicRegion class in MLIR TableGen (to support an arbitrary number of and regions). We can then define the semantics of this new operation to be short-circuiting in the same way as && in C. We can also nest conditions in a fairly natural way. When generating HLS code, it is a direct translation to cond1 && cond2 && cond3. When generating LLVM, we can translate it to a nested if as described above, or directly to basic blocks that represent short-circuiting, in the same way we could represent short-circuiting in LLVM itself.
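As a rough illustration of these semantics, the Python sketch below models each region as a function and shows both the direct HLS-style translation and the short-circuiting evaluation order. The helpers `emit_hls_and`, `eval_and` and `region` are hypothetical and are not part of the dialect or its code generators.

```python
def emit_hls_and(cond_exprs):
    """Join condition expressions with &&, parenthesising each operand."""
    return " && ".join(f"({expr})" for expr in cond_exprs)


def eval_and(regions):
    """Run regions in order; all() stops at the first region that yields false."""
    return all(run() for run in regions)


visited = []  # names of the regions that were actually executed


def region(name, value):
    """Build a region that records its execution and yields a fixed value."""
    def run():
        visited.append(name)
        return value
    return run


print(emit_hls_and(["i > 0", "A[i - 1] > 0"]))  # (i > 0) && (A[i - 1] > 0)
result = eval_and([region("cond1", True), region("cond2", False), region("cond3", True)])
```

After the last call, `visited` contains only `cond1` and `cond2`: the third region is never entered, which is the behaviour a lowering to nested ifs or basic blocks would need to preserve.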

Obviously this form of implementation requires more work in TableGen, but to me this is the whole point of using MLIR. It seems worth it to leverage MLIR's features to provide a better high-level representation that we can translate to LLVM later.

  2. Yes, we will need our own hcl.yield. This should mostly be a matter of copying scf.yield and changing some names.

> Obviously this form of implementation requires more work in TableGen, but to me this is the whole point of using MLIR. It seems worth it to leverage MLIR's features to provide a better high-level representation that we can translate to LLVM later.

Right, I agree with that. Basically, we want to retain as much information as we can in MLIR. Those scf.if statements complicate the structure, and we would still need to recover the original conditions from them during codegen, so introducing our own operations is probably a good way to tackle these problems.

I added hcl.and, hcl.or, and hcl.yield with multiple regions to support short-circuit evaluation. Test cases are under test/Operations/logicops. An example with hcl.and:
https://github.com/cornell-zhang/hcl-dialect-prototype/blob/2b91125309baad639f59a6496b39b64e0534b960/test/Operations/logicops/and.mlir#L2-L14

As a next step, I'll implement their lowering passes for the LLVM backend.