cornell-zhang/heterocl

Bug found in Simplify IfThenElse pass, due to overaggressive optimizations

wyanzhao opened this issue

Hi, I noticed the bug fix in #456, which addressed an if condition being wrongly optimized away.
However, I am still able to trigger similar bugs with the following code:

import heterocl as hcl

def test_mask():

    def two_stage(A):
        var = hcl.scalar(0, "v", dtype=hcl.Int(32))
        with hcl.if_(var == 0):
            var.v = var - 33
            hcl.print(var, "A\n")
            # this condition should not be optimized away
            with hcl.if_(var == -33):
                hcl.print(var, "B\n")
        with hcl.else_():
            var.v = var - 1
            with hcl.if_(var == 0):
                hcl.print((), "B\n")
        return A

    A = hcl.placeholder((2,), "A", dtype=hcl.UInt(16))
    s = hcl.create_schedule([A], two_stage)
    print(hcl.lower(s))

if __name__ == "__main__":
    test_mask()

This should generate IR like the following:

  produce v {
    // attr [0] extern_scope = 0
    for "stage_name"="v" (x, 0, 1) {
      v[x] = 0
    }
  }
  if ((v[0] == 0)) {
    v[0] = (v[0] + -33)
    print: v[0]
    if ((v[0] == -33)) {
      print: v[0]
    }
  } else {
    v[0] = (v[0] + -1)
    if ((v[0] == 0)) {
      print:
    }
  }

However, HeteroCL currently generates this IR:

  produce v {
    // attr [0] extern_scope = 0
    for "stage_name"="v" (x, 0, 1) {
      v[x] = 0
    }
  }
  if ((v[0] == 0)) {
    v[0] = -33
    print: 0
  } else {
    v[0] = (v[0] + -1)
    if ((v[0] == 0)) {
      print:
    }
  }

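The second print statement, together with the if (v[0] == -33) check guarding it, is optimized away, even though that condition is met at runtime: v[0] is 0 when the then branch is entered, so after subtracting 33 it equals -33. The first print is also wrong: it prints the constant 0, although v[0] is -33 at that point.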
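To check that the condition really is reachable, here is the same control flow written in plain Python, with an ordinary int standing in for the HeteroCL scalar. This is only an illustration and does not call HeteroCL:

```python
# Plain-Python mirror of the first two_stage example (illustration only,
# no HeteroCL calls): an ordinary int plays the role of the scalar v.
def first_example():
    printed = []  # (label, value) pairs in the order they would be printed
    v = 0
    if v == 0:
        v = v - 33
        printed.append(("A", v))
        if v == -33:  # reachable: v is -33 here
            printed.append(("B", v))
    else:
        v = v - 1
        if v == 0:
            printed.append(("B", None))
    return printed

print(first_example())  # [('A', -33), ('B', -33)]
```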

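This bug is triggered by the following code in the IfThenElse handling of the Simplify pass. When the condition has the form some_expr == const, every read of some_expr in the then branch is replaced by the constant, without checking whether some_expr (here v[0]) is reassigned inside that branch: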

else if (eq && is_const(eq->b) && !or_chain) {
  // some_expr = const
  then_case = substitute(eq->a, eq->b, then_case);
}
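The effect can be reproduced with a toy model. The sketch below is hypothetical and only illustrates the rewrite; it does not use HeteroCL or its C++ substitute. Statements are tuples, reads of v in the then branch are replaced by the constant 0, and an if whose folded condition is false is dropped. On the then branch of the first example it yields the same result as the buggy IR: the assignment folds to -33, the print becomes 0, and the inner if disappears.

```python
# Hypothetical toy model, not HeteroCL's C++ code: it mimics
# substitute(eq->a, eq->b, then_case) followed by constant folding.
# Statements: ("assign", expr) means v = expr, ("print", expr),
# ("if", lhs, k, body) means if (lhs == k) { body }.
# Expressions: "v" (a read of v), an int literal, or ("add", lhs, rhs).

def subst(expr, const):
    """Replace every read of v in expr with const."""
    if expr == "v":
        return const
    if isinstance(expr, tuple):
        op, lhs, rhs = expr
        return (op, subst(lhs, const), subst(rhs, const))
    return expr

def fold(expr):
    """Constant-fold additions whose operands are both int literals."""
    if isinstance(expr, tuple):
        op, lhs, rhs = expr
        lhs, rhs = fold(lhs), fold(rhs)
        if isinstance(lhs, int) and isinstance(rhs, int):
            return lhs + rhs
        return (op, lhs, rhs)
    return expr

def naive_substitute(stmts, const):
    """Replace reads of v everywhere in the branch, ignoring any writes to v."""
    out = []
    for stmt in stmts:
        if stmt[0] == "if":
            _, lhs, k, body = stmt
            cond = fold(subst(lhs, const))
            if cond == k:
                # condition folds to true: keep the body unconditionally
                out.extend(naive_substitute(body, const))
            # otherwise the condition folds to false and the whole if is dropped
        else:
            # assignment targets stay as v; only the right-hand side is rewritten
            out.append((stmt[0], fold(subst(stmt[1], const))))
    return out

# Then branch of the first example; the condition v == 0 gives const = 0.
then_case = [
    ("assign", ("add", "v", -33)),
    ("print", "v"),
    ("if", "v", -33, [("print", "v")]),
]
print(naive_substitute(then_case, 0))  # [('assign', -33), ('print', 0)]
```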

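Another example triggers a bug caused by the analogous branch for some_expr != const, which performs the same substitution on the else branch: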

else if (ne && is_const(ne->b) && !and_chain) {
  // some_expr != const
  else_case = substitute(ne->a, ne->b, else_case);
}
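One possible direction for a fix, sketched with a hypothetical toy representation (this is not HeteroCL's actual fix, and writes_v and guarded_substitute are illustrative names): only propagate the constant into a branch when that branch never writes the variable. Applied to the else branch of the example that follows, the branch is left untouched because it assigns to v, while the inner if (v[0] == 0) body, which only reads v, still receives the constant, matching the expected print: 0.

```python
# Hypothetical sketch, not HeteroCL's implementation: propagate a known
# constant value of v into a branch only if the branch never assigns to v.
# Statements: ("assign", expr) means v = expr, ("print", expr),
# ("if", lhs, k, body) means if (lhs == k) { body }.
# Expressions: "v" (a read of v), an int literal, or ("add", lhs, rhs).

def writes_v(stmts):
    """Return True if any statement, including nested bodies, assigns to v."""
    for stmt in stmts:
        if stmt[0] == "assign":
            return True
        if stmt[0] == "if" and writes_v(stmt[3]):
            return True
    return False

def subst(expr, const):
    """Replace every read of v in expr with const."""
    if expr == "v":
        return const
    if isinstance(expr, tuple):
        op, lhs, rhs = expr
        return (op, subst(lhs, const), subst(rhs, const))
    return expr

def subst_stmts(stmts, const):
    """Substitute const for reads of v; assignment targets stay as v."""
    out = []
    for stmt in stmts:
        if stmt[0] == "if":
            _, lhs, k, body = stmt
            out.append(("if", subst(lhs, const), k, subst_stmts(body, const)))
        else:
            out.append((stmt[0], subst(stmt[1], const)))
    return out

def guarded_substitute(branch, const):
    """Substitute only when the branch cannot change v; otherwise keep it."""
    if writes_v(branch):
        return branch
    return subst_stmts(branch, const)

# Else branch of the second example: v is known to be 1 on entry.
else_case = [
    ("assign", ("add", "v", -1)),
    ("print", "v"),
    ("if", "v", 0, [("print", "v")]),
]
print(guarded_substitute(else_case, 1) == else_case)  # True: v is written
print(guarded_substitute([("print", "v")], 0))        # [('print', 0)]
```

The guard is deliberately conservative: it gives up on the whole branch as soon as v is written anywhere in it. A finer-grained alternative would substitute only in the statements before the first write to v.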

The code is:

import heterocl as hcl

def test_mask():

    def two_stage(A):
        var = hcl.scalar(1, "v", dtype=hcl.Int(32))
        with hcl.if_(var != 1):
            hcl.print(var,"A\n")

        with hcl.else_():
            var.v = var - 1
            hcl.print((var),"B\n")
            with hcl.if_(var == 0):
                hcl.print((var),"C\n")
        return A
    
    A = hcl.placeholder((2,), "A", dtype=hcl.UInt(16))
    s = hcl.create_schedule([A], two_stage)
    print(hcl.lower(s))

if __name__ == "__main__":
    test_mask()

The output is:

  produce v {
    // attr [0] extern_scope = 0
    for "stage_name"="v" (x, 0, 1) {
      v[x] = 1
    }
  }
  if ((v[0] != 1)) {
    print: v[0]
  } else {
    v[0] = 0
    print: 1
  }

The correct output should be:

  produce v {
    // attr [0] extern_scope = 0
    for "stage_name"="v" (x, 0, 1) {
      v[x] = 1
    }
  }
  if ((v[0] != 1)) {
    print: v[0]
  } else {
    v[0] = (v[0] + -1)
    print: v[0]
    if ((v[0] == 0)) {
      print: 0
    }
  }

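The if condition guarding the third print statement can also be met: v[0] is 1 on entry to the else branch, so after subtracting 1 it equals 0. Instead, the pass substitutes 1 for v[0] throughout the else branch, which folds the assignment to v[0] = 0, prints the stale value 1, and drops the if (v[0] == 0) block entirely.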
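As with the first example, a plain-Python mirror of the same control flow (illustration only, no HeteroCL calls) shows that both the B and C prints should fire, each with v equal to 0:

```python
# Plain-Python mirror of the second two_stage example (illustration only,
# no HeteroCL calls): an ordinary int plays the role of the scalar v.
def second_example():
    printed = []  # (label, value) pairs in the order they would be printed
    v = 1
    if v != 1:
        printed.append(("A", v))
    else:
        v = v - 1
        printed.append(("B", v))
        if v == 0:  # reachable: v is 0 here
            printed.append(("C", v))
    return printed

print(second_example())  # [('B', 0), ('C', 0)]
```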