Bug in the Simplify IfThenElse pass due to over-aggressive optimization
wyanzhao opened this issue · 1 comment
wyanzhao commented
Hi, I noticed the bug fix in #456, which tries to fix the Simplify pass wrongly optimizing away an if condition.
However, I am still able to trigger similar bugs with the following code:
import heterocl as hcl


def test_mask():
    """Reproduce the Simplify IfThenElse bug for an ``==`` condition.

    Builds a kernel whose then-branch reassigns the scalar ``v`` tested by
    the branch condition (``v == 0``), then prints the lowered IR so the
    simplified output can be compared with the expected IR.
    """

    def two_stage(A):
        var = hcl.scalar(0, "v", dtype=hcl.Int(32))
        with hcl.if_(var == 0):
            # v is reassigned here, so v == 0 no longer holds for the
            # statements that follow inside this branch.
            var.v = var - 33
            hcl.print(var, "A\n")
            # reachable: v is -33 after the update above, so this condition
            # should not be optimized away either
            with hcl.if_(var == -33):
                hcl.print((var), "B\n")
        with hcl.else_():
            var.v = var - 1
            # this condition should not be optimized away
            with hcl.if_(var == 0):
                hcl.print((), "B\n")
        return A

    A = hcl.placeholder((2,), "A", dtype=hcl.UInt(16))
    s = hcl.create_schedule([A], two_stage)
    print(hcl.lower(s))


if __name__ == "__main__":
    test_mask()
This should generate IR like the following:
produce v {
// attr [0] extern_scope = 0
for "stage_name"="v" (x, 0, 1) {
v[x] = 0
}
}
if ((v[0] == 0)) {
v[0] = (v[0] + -33)
print: v[0]
if ((v[0] == -33)) {
print: v[0]
}
} else {
v[0] = (v[0] + -1)
if ((v[0] == 0)) {
print:
}
}
Instead, HeteroCL currently generates the following IR:
produce v {
// attr [0] extern_scope = 0
for "stage_name"="v" (x, 0, 1) {
v[x] = 0
}
}
if ((v[0] == 0)) {
v[0] = -33
print: 0
} else {
v[0] = (v[0] + -1)
if ((v[0] == 0)) {
print:
}
}
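The second Print stmt, together with its enclosing `if (v[0] == -33)`, is optimized away even though its condition can be met. After `v[0] = (v[0] + -33)` executes in the then branch, `v[0]` is -33. The first print is also wrong: it prints the constant 0 even though `v[0]` has already been set to -33.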
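This bug is triggered by the following code in the IfThenElse handling of the Simplify pass. It substitutes the constant for `some_expr` throughout the whole then case, even when `some_expr` is reassigned inside that case: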
else if (eq && is_const(eq->b) && !or_chain) {
// some_expr = const
then_case = substitute(eq->a, eq->b, then_case);
}
wyanzhao commented
Here is another example, which triggers a bug caused by the analogous `!=` case:
else if (ne && is_const(ne->b) && !and_chain) {
// some_expr != const
else_case = substitute(ne->a, ne->b, else_case);
}
The code is:
import heterocl as hcl


def test_mask():
    """Reproduce the Simplify IfThenElse bug for a ``!=`` condition.

    Builds a kernel whose else-branch (where ``v != 1`` is false) reassigns
    the scalar ``v``, then prints the lowered IR so the simplified output
    can be compared with the expected IR.
    """

    def two_stage(A):
        var = hcl.scalar(1, "v", dtype=hcl.Int(32))
        with hcl.if_(var != 1):
            hcl.print(var, "A\n")
        with hcl.else_():
            # v is reassigned here, so v == 1 no longer holds for the
            # statements that follow inside this branch.
            var.v = var - 1
            hcl.print((var), "B\n")
            # reachable: v is 0 after the update above, so this condition
            # should not be optimized away
            with hcl.if_(var == 0):
                hcl.print((var), "C\n")
        return A

    A = hcl.placeholder((2,), "A", dtype=hcl.UInt(16))
    s = hcl.create_schedule([A], two_stage)
    print(hcl.lower(s))


if __name__ == "__main__":
    test_mask()
The output is:
produce v {
// attr [0] extern_scope = 0
for "stage_name"="v" (x, 0, 1) {
v[x] = 1
}
}
if ((v[0] != 1)) {
print: v[0]
} else {
v[0] = 0
print: 1
}
The correct output should be:
produce v {
// attr [0] extern_scope = 0
for "stage_name"="v" (x, 0, 1) {
v[x] = 1
}
}
if ((v[0] != 1)) {
print: v[0]
} else {
v[0] = (v[0] + -1)
print: v[0]
if ((v[0] == 0)) {
print: 0
}
}
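The if condition guarding the third Print stmt can also be met. After `v[0] = (v[0] + -1)` executes in the else branch, `v[0]` is 0. As in the first example, replacing `v[0]` with the constant (here 1) is only valid until `v[0]` is reassigned. The generated IR also wrongly prints 1 instead of the updated value of `v[0]`.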