[Op] SelectOp short-circuit evaluation
The current codegen for SelectOp does not support short-circuit evaluation, which can cause problems in some cases. For example, for hcl.select(i > 0, A[i - 1], 0), we need to check i > 0 first and only then load A[i - 1]. However, MLIR's strict SSA form means the select cannot be written as a single expression: the true and false values must be computed first, before arith.select uses them. The following code shows this situation, which leads to an out-of-bounds access of array A when i = 0.
%true = affine.load %A[%i - 1] : memref<10xi32>
%false = arith.constant 0 : i32
%cond = arith.cmpi sgt, %i, %zero : index
%x = arith.select %cond, %true, %false : i32
A general scf.if statement may also have this problem.
Update: The hcl.select example above is not really a short-circuit case, although it is also an evaluation-order issue. The short-circuit case is the example below, which our flow does not support either.
// expected condition: (i > 0 && A[i - 1] > 0)
// but we currently generate the following code
// %cond1 and %cond2 are *both* evaluated before the select
%cond1 = arith.cmpi sgt, %i, %zero : index
%A_val = affine.load %A[%i - 1] : memref<10xi32>
%cond2 = arith.cmpi sgt, %A_val, %zero_i32 : i32
%cond3 = arith.andi %cond1, %cond2 : i1
%ret = arith.select %cond3, %true, %false : i32
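Both problems can be reproduced outside MLIR. The following is a minimal Python sketch, not HeteroCL code: `load`, `eager_select`, `lazy_select`, and `eager_and` are hypothetical helpers, and `load` checks bounds explicitly to stand in for a memref access.

```python
def load(buf, idx):
    """Bounds-checked load standing in for affine.load (hypothetical helper)."""
    if idx < 0 or idx >= len(buf):
        raise IndexError(f"out-of-bounds access at index {idx}")
    return buf[idx]


def eager_select(cond, true_val, false_val):
    """Mirrors arith.select: both values are already computed by the caller."""
    return true_val if cond else false_val


def lazy_select(cond, true_fn, false_fn):
    """Mirrors an if with regions: only the chosen branch is evaluated."""
    return true_fn() if cond else false_fn()


def eager_and(lhs, rhs):
    """Mirrors arith.andi: both operands are computed before the op runs."""
    return lhs and rhs


A = [10, 20, 30]
i = 0

# Evaluation-order case: the load runs although i > 0 is false.
try:
    eager_select(i > 0, load(A, i - 1), 0)
    select_failed = False
except IndexError:
    select_failed = True

# Short-circuit case: the second condition is computed unconditionally.
try:
    eager_and(i > 0, load(A, i - 1) > 0)
    and_failed = False
except IndexError:
    and_failed = True

# Deferring the branch or the second condition avoids the bad load.
lazy_result = lazy_select(i > 0, lambda: load(A, i - 1), lambda: 0)
short_circuit_result = i > 0 and load(A, i - 1) > 0
```

The eager versions fail because Python, like SSA, computes call arguments before the call. The lazy versions defer the load until the guard has been checked.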
Outline the hcl.select as a function? I tried this with mlir-opt and it seems to work:
module {
  func @select(%arg0: i32, %arg1: memref<10xi32>) -> i32 {
    %c0_i32 = arith.constant 0 : i32
    %0 = arith.cmpi sgt, %arg0, %c0_i32 : i32
    cond_br %0, ^bb1(%arg0 : i32), ^bb2
  ^bb1(%1: i32):  // pred: ^bb0
    %2 = arith.index_cast %1 : i32 to index
    %3 = affine.load %arg1[%2 - 1] : memref<10xi32>
    return %3 : i32
  ^bb2:  // pred: ^bb0
    %c0_i32_0 = arith.constant 0 : i32
    return %c0_i32_0 : i32
  }
}
Proposal: Use Nested if for Short-Circuit Eval
Build an if nest to support short-circuit behaviour of multiple conditions in hcl.select.
What issue might occur if we don't support short-circuit
Consider this condition
hcl.select(hcl.and_(i > 1, A[i - 1] > 0), ..., ...)
If we build this into three conditions:
%cond1 = cmp gt %i %one
%ele = memref.load A[%i - 1]
%cond2 = cmp gt %ele %zero
%cond3 = and %cond1 %cond2
When %i evaluates to zero, memref.load is still executed and raises a segmentation fault, because the load should have been short-circuited but was not.
How should we implement the short-circuit behaviour?
We can turn that into nested if operations:
// initialize SelectOp's return value here
%ret_value = some init value
%cond1 = cmp gt %i %one
scf.if (%cond1) {
  %ele = memref.load A[%i - 1]
  %cond2 = cmp gt %ele %zero
  scf.if (%cond2) {
    // true branch
    // update return value here
    ...
  } else {
    // false branch (cond2 is false)
    // update return value here
    ...
  }
} else {
  // false branch (cond1 is false)
  // update return value here
  ...
}
We need to add a node for hcl.and_ in our AST
We are building hcl.select offline, meaning we build an AST first, and then generate the IR by traversing the AST.
The current AST has one drawback in this case: when hcl.and_ has more than one condition, it turns into a tree of And nodes:
        and
       /   \
     and   cond1
    /   \
cond3   cond2
This makes it inconvenient to build an if nest like this:
if (cond1) {
  if (cond2) {
    if (cond3) {
    }
  }
}
because once the AST is built, we no longer know at the top level how many conditions hcl.and_ has. That information can only be recovered by traversing the and tree.
We can add a single node to represent hcl.and_ and let it hold a list of condition expressions, which makes building the if nest easier.
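As a rough illustration of the list-based node, here is a minimal Python sketch that recovers a flat condition list from a tree of binary And nodes. The `Cond` and `And` classes and the `flatten_and` helper are hypothetical and are not the actual HeteroCL AST; the left-to-right operand order is also an assumption, and the real tree may store operands in a different order.

```python
from dataclasses import dataclass
from typing import List, Union


@dataclass
class Cond:
    """Leaf condition (hypothetical AST node)."""
    name: str


@dataclass
class And:
    """Binary And node, as in a tree-shaped AST (hypothetical)."""
    lhs: "Expr"
    rhs: "Expr"


Expr = Union[Cond, And]


def flatten_and(node: Expr) -> List[Cond]:
    """Collect the leaves of a tree of And nodes in left-to-right order."""
    if isinstance(node, And):
        return flatten_and(node.lhs) + flatten_and(node.rhs)
    return [node]


# A three-way and built as a left-leaning tree of binary And nodes.
tree = And(And(Cond("cond1"), Cond("cond2")), Cond("cond3"))
names = [c.name for c in flatten_and(tree)]
```

With a dedicated list-based node this traversal would not be needed at all, since the condition count would be available directly when building the if nest.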
I think it might make sense for simplicity reasons to separate the evaluation of the condition and the true/false branches. I think you can do this with scf.yield.
%ret_value = some init value
%cond1 = cmp gt %i %one
%cond = scf.if %cond1 -> (i1) {
  %ele = memref.load A[%i - 1]
  %cond2 = cmp gt %ele %zero
  scf.yield %cond2 : i1
} else {
  %false = arith.constant 0 : i1
  scf.yield %false : i1
}
%res = scf.if %cond -> (out_type) {
  // true branch
} else {
  // false branch
}
It might also make sense to define a specific hcl.select that functions almost identically to scf.if but requires the else branch. This could simplify code generation because it will separate normal if statements from hcl.selects.
%ret_value = some init value
%cond1 = cmp gt %i %one
%cond = scf.if %cond1 -> (i1) {
  %ele = memref.load A[%i - 1]
  %cond2 = cmp gt %ele %zero
  scf.yield %cond2 : i1
} else {
  %false = arith.constant 0 : i1
  scf.yield %false : i1
}
%res = hcl.select %cond -> (out_type) {
  // true branch
} else {
  // false branch
}
@andrewb1999 Thanks for your suggestions! I think separating the condition and the true/false branches is a good idea, but I have two questions:
- What if we have more than two conditions? Do we generate sequential or nested scf.if for %cond? For example, for if (cond1 and cond2 and cond3), we have the following two rewrite methods. With sequential ifs, we may incur redundant evaluations. With nested ifs, it may be tricky to yield the final result from the innermost if to the outermost if.
// sequential if
%cond_12 = scf.if %cond1 -> (i1) {
  scf.yield %cond2 : i1
} else {
  scf.yield %false : i1
}
%cond_123 = scf.if %cond_12 -> (i1) {
  scf.yield %cond3 : i1
} else {
  scf.yield %false : i1
}
// nested if
%cond = scf.if %cond1 -> (i1) {
  %cond_23 = scf.if %cond2 -> (i1) {
    scf.yield %cond3 : i1
  } else {
    scf.yield %false : i1
  }
  scf.yield %cond_23 : i1
} else {
  scf.yield %false : i1
}
- If we define our own hcl.select operation, we probably also need to define hcl.yield? Otherwise, we do not know the return values of the true and false branches.
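For the nested rewrite in the first question, yielding from the innermost if to the outermost one can be handled by recursion: each level yields the result of the level inside it. The following Python sketch emits pseudo-IR text for a list of condition names. `emit_nested_if` and the `%r<depth>` result names are hypothetical, and the sketch assumes the condition values are given by name, whereas real codegen would compute each condition inside its region.

```python
from typing import List


def emit_nested_if(conds: List[str], depth: int = 0) -> List[str]:
    """Emit pseudo-IR lines for a short-circuit AND of `conds` as nested scf.if.

    Intended for two or more conditions at the top level.
    """
    if not conds:
        raise ValueError("at least one condition is required")
    pad = "  " * depth
    if len(conds) == 1:
        # Innermost level: the last condition is the result.
        return [f"{pad}scf.yield {conds[0]} : i1"]
    head, rest = conds[0], conds[1:]
    lines = [f"{pad}%r{depth} = scf.if {head} -> (i1) {{"]
    lines += emit_nested_if(rest, depth + 1)
    lines += [f"{pad}}} else {{", f"{pad}  scf.yield %false : i1", f"{pad}}}"]
    if depth > 0:
        # Pass the inner result up to the enclosing scf.if.
        lines.append(f"{pad}scf.yield %r{depth} : i1")
    return lines


lines = emit_nested_if(["%cond1", "%cond2", "%cond3"])
text = "\n".join(lines)
```

Printing `text` gives the same shape as the nested if above, with one extra nesting level per additional condition.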
- The more I think about this, the more I think we need better separation between the representation of hcl.cond and downstream optimizations. Something like nested scf.if seems like the best way to implement this for the LLVM backend, but it seems like a pretty disappointing representation for HLS code generation. It would be nice to have something like the following:
%cond = hcl.and {
  hcl.yield %cond1
} and {
  hcl.yield %cond2
} and {
  hcl.yield %cond3
}
This should be possible using the VariadicRegion class in MLIR TableGen (to support an arbitrary number of and regions). We can then define the semantics of this new operation to be short-circuiting, in the same way as && in C. We can also nest conditions in a fairly natural way. When generating HLS code, it is a direct translation to cond1 && cond2 && cond3. When generating LLVM, we can translate it to a nested if as described above, or directly to basic blocks that represent short-circuiting, in the same way we could represent short-circuiting in LLVM itself.
Obviously this form of implementation requires more work in TableGen, but to me this is the whole point of using MLIR. It seems worth it to leverage MLIR's features to provide a better high-level representation that we can translate to LLVM later.
- Yes, we will need our own hcl.yield. This should mostly be copying and pasting scf.yield and changing some names.
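The intended semantics of the multi-region hcl.and can be modeled in a few lines. This is only a Python sketch under assumptions: each region is represented as a zero-argument callable, and `eval_and_regions`, `emit_hls_and`, and `region` are hypothetical names, not part of the dialect.

```python
from typing import Callable, List


def eval_and_regions(regions: List[Callable[[], bool]]) -> bool:
    """Evaluate regions in order; stop at the first one that yields False."""
    for run_region in regions:
        if not run_region():
            return False
    return True


def emit_hls_and(cond_exprs: List[str]) -> str:
    """Direct HLS-style translation: join the conditions with &&."""
    return " && ".join(f"({c})" for c in cond_exprs)


A = [5, -3, 7]
calls = []


def region(name, fn):
    """Wrap a condition so that each evaluation is recorded in `calls`."""
    def run():
        calls.append(name)
        return fn()
    return run


i = 0
result = eval_and_regions([
    region("cond1", lambda: i > 0),
    region("cond2", lambda: A[i - 1] > 0),  # not reached when i == 0
])
hls = emit_hls_and(["i > 0", "A[i - 1] > 0"])
```

Only the first region runs when `i` is zero, which is the behaviour the nested-if or basic-block lowering would have to preserve, while the HLS path can rely on `&&` in the generated C code.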
> Obviously this form of implementation requires more work in TableGen, but to me this is the whole point of using MLIR. It seems worth it to leverage MLIR's features to provide a better high-level representation that we can translate to LLVM later.
Right, I agree with that. Basically, we want to retain as much information as we can in MLIR. Those scf.if statements complicate the structure, and we would still need to recover the original conditions from them during codegen, so introducing our own operations is probably a good way to tackle these problems.
I added hcl.and, hcl.or, and hcl.yield with multiple regions to support short-circuit evaluation. Test cases are under test/Operations/logicops. An example with hcl.and:
https://github.com/cornell-zhang/hcl-dialect-prototype/blob/2b91125309baad639f59a6496b39b64e0534b960/test/Operations/logicops/and.mlir#L2-L14
As a next step, I'll implement their lowering passes for the LLVM backend.