cornell-zhang/heterocl

Incorrect loop axis reference

chhzh123 opened this issue · 4 comments

It is weird that the following two code snippets have different meanings. The first approach does not add the newly generated loops to C.axis, so C.axis[1] still refers to the original axis at index 1 (jj). The second approach does take the newly generated loops into account, so index 1 refers to the inner loop produced by the first split (ii.inner).

# first approach
s[C].split(C.axis[0], factor=3)
s[C].split(C.axis[1], factor=3)
# second approach
s[C].split(0, factor=3)
s[C].split(1, factor=3)

The test case is from test_schedule_compute.py.

def test_compute_at():
    def _build_kernel():
        hcl.init()
        A = hcl.placeholder((10, 20, 30), name="A")
        B = hcl.compute(A.shape, lambda i, j, m: A[i, j, m] * 2, name="B")
        C = hcl.compute(B.shape, lambda ii, jj, mm: B[ii, jj, mm] + 1, name="C")
        return A, B, C

    def test_case_3():
        A, B, C = _build_kernel()
        s = hcl.create_schedule([A])
        s[B].compute_at(s[C], C.axis[2])
        s[C].split(C.axis[0], factor=3)
        s[C].split(C.axis[1], factor=3)
        ir = hcl.lower(s)
        print(ir)

    test_case_3()

The first approach generates the following code.

// attr [C] storage_scope = "global"
allocate C[int32 * 10 * 20 * 30]
produce C {
  // attr [0] extern_scope = 0
  for "stage_name"="C" (ii.outer, 0, 4) {
    for "stage_name"="C" (ii.inner, 0, 3) {
      for "stage_name"="C" (jj.outer, 0, 7) {
        for "stage_name"="C" (jj.inner, 0, 3) {
          for "stage_name"="C" (mm, 0, 30) {
            if ((jj.inner < (20 - (jj.outer*3)))) {
              if ((ii.inner < (10 - (ii.outer*3)))) {
                // attr [B] storage_scope = "global"
                allocate B[int32 * 1 * 1 * 1]
                produce B {
                  // attr [0] extern_scope = 0
                  B[0] = (A[((mm + ((jj.inner + (jj.outer*3))*30)) + ((ii.inner + (ii.outer*3))*600))]*2)
                }
                C[((mm + ((jj.inner + (jj.outer*3))*30)) + ((ii.inner + (ii.outer*3))*600))] = (B[0] + 1)
              }
            }
          }
        }
      }
    }
  }
}

And the second approach generates the code below.

// attr [C] storage_scope = "global"
allocate C[int32 * 10 * 20 * 30]
produce C {
  // attr [0] extern_scope = 0
  for "stage_name"="C" (ii.outer, 0, 4) {
    for "stage_name"="C" (ii.inner.outer, 0, 1) {
      for "stage_name"="C" (ii.inner.inner, 0, 3) {
        for "stage_name"="C" (jj, 0, 20) {
          for "stage_name"="C" (mm, 0, 30) {
            if ((ii.inner.inner < ((10 - (ii.outer*3)) - (ii.inner.outer*3)))) {
              // attr [B] storage_scope = "global"
              allocate B[int32 * 1 * 1 * 1]
              produce B {
                // attr [0] extern_scope = 0
                B[0] = (A[((mm + (jj*30)) + (((ii.inner.inner + (ii.inner.outer*3)) + (ii.outer*3))*600))]*2)
              }
              C[((mm + (jj*30)) + (((ii.inner.inner + (ii.inner.outer*3)) + (ii.outer*3))*600))] = (B[0] + 1)
            }
          }
        }
      }
    }
  }
}

I think the second approach is correct in this case. After calling s[C].split(C.axis[0], factor=3), we have two new loops, ii.outer and ii.inner, which should be attached to C.axis. C.axis[1] should then refer to ii.inner instead of the original jj. The test program only checks whether allocate B[int32 * 1 * 1 * 1] exists, so it passes without catching the error.
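A stricter check could compare the loop structure of the lowered IR instead of only looking for the B allocation. Below is a minimal sketch, assuming the IR keeps printing loop headers in the form shown in the dumps above. `loop_vars` and `SAMPLE_IR` are hypothetical names used for illustration and are not part of HeteroCL or its test suite.

```python
import re

# Matches loop headers as they are printed in the IR dumps in this issue,
# e.g.  for "stage_name"="C" (ii.outer, 0, 4) {
_LOOP_RE = re.compile(r'for\s+"stage_name"="[^"]*"\s+\(([\w.]+),')


def loop_vars(ir_text):
    """Return the loop variable names in nesting order (outermost first)."""
    return _LOOP_RE.findall(ir_text)


# Loop headers copied from the first dump above (bodies omitted).
SAMPLE_IR = '''
produce C {
  for "stage_name"="C" (ii.outer, 0, 4) {
    for "stage_name"="C" (ii.inner, 0, 3) {
      for "stage_name"="C" (jj.outer, 0, 7) {
        for "stage_name"="C" (jj.inner, 0, 3) {
          for "stage_name"="C" (mm, 0, 30) {
'''

assert loop_vars(SAMPLE_IR) == [
    "ii.outer", "ii.inner", "jj.outer", "jj.inner", "mm"
]
```

In the real test this would presumably be applied to the printed form of the object returned by hcl.lower(s). Which loop order to assert depends on the axis semantics the project settles on.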

I'm not sure how TVM approaches this, but I think ideally we don't want to update the axes, since they refer to the loops of the original program. Suppose we split both the first and the second loop of the original program: the result should be the same regardless of which one we split first. However, if we always update the axes, we need to change the way we write schedules. The following is a more concrete example.

(a)
s[C].split(C.axis[0])
s[C].split(C.axis[1])
(b)
s[C].split(C.axis[1])
s[C].split(C.axis[0])

If the axes are not updated, program (a) should be the same as program (b). However, if the axes are updated, then (a) is no longer the same as (b), which is counter-intuitive (to me at least).
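The ordering argument can be illustrated with a toy model in plain Python, with no HeteroCL involved. Loops are just names in a list and the split factor is ignored. `split_by_name` and `split_by_position` are hypothetical helpers that stand in for the two ways of resolving a split target; they are not how HeteroCL implements split.

```python
def split_by_name(loops, name):
    """Replace loop `name` with name.outer / name.inner, keeping the order."""
    i = loops.index(name)
    return loops[:i] + [name + ".outer", name + ".inner"] + loops[i + 1:]


def split_by_position(loops, pos):
    """Split whichever loop currently sits at position `pos`."""
    return split_by_name(loops, loops[pos])


original = ["ii", "jj", "mm"]
axis = tuple(original)  # fixed handles to the original loops

# Axes not updated: (a) and (b) give the same loop nest.
a = split_by_name(split_by_name(original, axis[0]), axis[1])
b = split_by_name(split_by_name(original, axis[1]), axis[0])
assert a == b == ["ii.outer", "ii.inner", "jj.outer", "jj.inner", "mm"]

# Axes updated (index into the current nest): the order now matters.
a_pos = split_by_position(split_by_position(original, 0), 1)
b_pos = split_by_position(split_by_position(original, 1), 0)
assert a_pos == ["ii.outer", "ii.inner.outer", "ii.inner.inner", "jj", "mm"]
assert b_pos == ["ii.outer", "ii.inner", "jj.outer", "jj.inner", "mm"]
```

The two results for (a) have the same loop names as the two IR dumps above: the handle-based form splits jj second, while the positional form splits ii.inner again.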

We can maintain a new variable like new_axis to keep track of the latest axes. To access the new loops, the ideal way is to use the variables returned by split.

That makes sense. (a) and (b) should generate the same code. The newly generated loops should always be accessed through the returned variables, as in the code below.

outer, inner = s[C].split(C.axis[0])
s[C].split(inner)
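To make the proposal concrete, here is a rough sketch of the bookkeeping in plain Python. `ToyStage` is a hypothetical stand-in, not the HeteroCL stage class: it assumes loops can be identified by name, ignores the split factor, keeps `axis` fixed, and tracks the current loop nest in `new_axis`.

```python
class ToyStage:
    """Toy stand-in for a schedule stage; not the HeteroCL implementation."""

    def __init__(self, names):
        self.axis = tuple(names)     # original axes, never modified
        self.new_axis = list(names)  # latest loop nest, updated by split

    def split(self, loop):
        """Split `loop` (a loop name) and return the (outer, inner) names."""
        i = self.new_axis.index(loop)
        outer, inner = loop + ".outer", loop + ".inner"
        self.new_axis[i:i + 1] = [outer, inner]
        return outer, inner


C = ToyStage(["ii", "jj", "mm"])
outer, inner = C.split(C.axis[0])
C.split(inner)      # reach the new inner loop through the returned handle
C.split(C.axis[1])  # C.axis[1] still refers to the original jj

assert C.axis == ("ii", "jj", "mm")
assert C.new_axis == [
    "ii.outer", "ii.inner.outer", "ii.inner.inner",
    "jj.outer", "jj.inner", "mm",
]
```

With this split of responsibilities, both loop nests from the issue remain expressible: C.axis[...] always names an original loop, and nested splits go through the returned handles.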