tlc-pack/relax

[MetaSchedule][Hexagon] conv2d produces different results after tuning

psrivas2 opened this issue · 4 comments

The following PrimFunc produces different results after tuning on Hexagon.

# Reference (untuned) NHWC conv2d in float16. The input is read at
# (yy * 2 + ry, xx * 2 + rx), i.e. stride 2, with a 7 x 7 window over rc (3)
# and ff (64). The largest row/column index read from lv1 is
# 111 * 2 + 6 = 228 < 230, and no padding is applied inside this function.
@T.prim_func
def conv2d(lv1: T.Buffer[(1, 230, 230, 3), "float16"], param_0: T.Buffer[(7, 7, 3, 64), "float16"], conv2d_nhwc: T.Buffer[(1, 112, 112, 64), "float16"]):
    # function attr dict
    T.func_attr({"global_symbol": "conv2d", "tir.noalias": True})
    # body
    # with T.block("root")
    # Spatial loops i0..i3 (nn, yy, xx, ff) are outermost and the reduction
    # loops i4..i6 (ry, rx, rc) are innermost. Each output element is
    # therefore initialised and fully accumulated, in (ry, rx, rc) order,
    # before the next element is touched.
    for i0, i1, i2, i3, i4, i5, i6 in T.grid(1, 112, 112, 64, 7, 7, 3):
        with T.block("conv2d_nhwc"):
            nn, yy, xx, ff, ry, rx, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6])
            T.reads(lv1[nn, yy * 2 + ry, xx * 2 + rx, rc], param_0[ry, rx, rc, ff])
            T.writes(conv2d_nhwc[nn, yy, xx, ff])
            with T.init():
                conv2d_nhwc[nn, yy, xx, ff] = T.float16(0)
            # The accumulator is the float16 output buffer itself, so every
            # partial sum is a float16 value. The final result therefore
            # depends on the order in which (ry, rx, rc) are visited.
            conv2d_nhwc[nn, yy, xx, ff] = conv2d_nhwc[nn, yy, xx, ff] + lv1[nn, yy * 2 + ry, xx * 2 + rx, rc] * param_0[ry, rx, rc, ff]

After tuning, the PrimFunc is transformed to:

# The same conv2d PrimFunc after MetaSchedule tuning. The loops are tiled,
# the outer tile loop is parallelised, and the ff dimension is vectorised.
@T.prim_func
def conv2d(lv1: T.Buffer[(1, 230, 230, 3), "float16"], param_0: T.Buffer[(7, 7, 3, 64), "float16"], conv2d_nhwc: T.Buffer[(1, 112, 112, 64), "float16"]):
    # function attr dict
    T.func_attr({"tir.noalias": True, "global_symbol": "conv2d"})
    # body
    # with T.block("root")
    # Partial sums are accumulated in this intermediate float16 buffer. The
    # "conv2d_nhwc_global" block below copies them into conv2d_nhwc.
    conv2d_nhwc_global = T.alloc_buffer([1, 112, 112, 64], dtype="float16")
    # 196 = (112 / 4) * (112 / 16): each iteration owns a 4 (yy) x 16 (xx)
    # output tile that covers all 64 ff values.
    for i0_0_i1_0_i2_0_fused in T.parallel(196, annotations={"pragma_auto_unroll_max_step":T.int64(512), "pragma_unroll_explicit":T.int64(1)}):
        for i3_0 in T.serial(1):
            # Zero the tile's 4 x 16 x 64 accumulators.
            for i0_1_init, i1_1_init, i2_1_init, i3_1_init, i0_2_init, i1_2_init, i2_2_init in T.grid(1, 2, 16, 1, 1, 2, 1):
                for i3_2_fused_init in T.vectorized(64):
                    with T.block("conv2d_nhwc_init"):
                        nn = T.axis.spatial(1, i0_1_init + i0_2_init)
                        yy = T.axis.spatial(112, i0_0_i1_0_i2_0_fused // 7 * 4 + i1_1_init * 2 + i1_2_init)
                        xx = T.axis.spatial(112, i2_2_init + i0_0_i1_0_i2_0_fused % 7 * 16 + i2_1_init)
                        ff = T.axis.spatial(64, i3_0 * 64 + i3_1_init * 64 + i3_2_fused_init)
                        T.reads()
                        T.writes(conv2d_nhwc_global[nn, yy, xx, ff])
                        T.block_attr({"meta_schedule.tiling_structure":"SRSRS"})
                        conv2d_nhwc_global[nn, yy, xx, ff] = T.float16(0)
            # Reduction nest. For a fixed output element the visiting order is
            # rx (i5_0) outermost, then ry (i4_1), then rc (i6_1). The untuned
            # function visits ry, rx, rc instead. Because the accumulator is
            # float16, this reassociation can by itself change the rounded
            # result.
            # NOTE(review): rule this out before treating the mismatch as an
            # invalid transformation.
            for i4_0, i5_0, i6_0 in T.grid(1, 7, 1):
                for i0_1, i1_1, i2_1, i3_1, i4_1, i5_1, i6_1, i0_2, i1_2, i2_2 in T.grid(1, 2, 16, 1, 7, 1, 3, 1, 2, 1):
                    for i3_2_fused in T.vectorized(64):
                        with T.block("conv2d_nhwc_update"):
                            nn = T.axis.spatial(1, i0_1 + i0_2)
                            yy = T.axis.spatial(112, i0_0_i1_0_i2_0_fused // 7 * 4 + i1_1 * 2 + i1_2)
                            xx = T.axis.spatial(112, i2_2 + i0_0_i1_0_i2_0_fused % 7 * 16 + i2_1)
                            ff = T.axis.spatial(64, i3_0 * 64 + i3_1 * 64 + i3_2_fused)
                            ry = T.axis.reduce(7, i4_0 * 7 + i4_1)
                            rx = T.axis.reduce(7, i5_0 + i5_1)
                            rc = T.axis.reduce(3, i6_0 * 3 + i6_1)
                            T.reads(conv2d_nhwc_global[nn, yy, xx, ff], lv1[nn, yy * 2 + ry, xx * 2 + rx, rc], param_0[ry, rx, rc, ff])
                            T.writes(conv2d_nhwc_global[nn, yy, xx, ff])
                            T.block_attr({"meta_schedule.tiling_structure":"SRSRS"})
                            conv2d_nhwc_global[nn, yy, xx, ff] = conv2d_nhwc_global[nn, yy, xx, ff] + lv1[nn, yy * 2 + ry, xx * 2 + rx, rc] * param_0[ry, rx, rc, ff]
                # NOTE(review): this copy-back nest sits inside the
                # i4_0/i5_0/i6_0 loop, so it runs 7 times per tile and copies
                # partial sums. Only the last pass (i5_0 == 6) copies the fully
                # reduced values. The repeated copies are redundant, but the
                # last one writes the complete sums.
                for ax0, ax1, ax2 in T.grid(1, 4, 16):
                    for ax3_fused in T.vectorized(64):
                        with T.block("conv2d_nhwc_global"):
                            v0 = T.axis.spatial(1, ax0)
                            v1 = T.axis.spatial(112, i0_0_i1_0_i2_0_fused // 7 * 4 + ax1)
                            v2 = T.axis.spatial(112, i0_0_i1_0_i2_0_fused % 7 * 16 + ax2)
                            v3 = T.axis.spatial(64, ax3_fused)
                            T.reads(conv2d_nhwc_global[v0, v1, v2, v3])
                            T.writes(conv2d_nhwc[v0, v1, v2, v3])
                            conv2d_nhwc[v0, v1, v2, v3] = conv2d_nhwc_global[v0, v1, v2, v3]

The two PrimFuncs produce different results on Hexagon hardware. This needs to be investigated.

Thanks @psrivas2 for reporting the issue!

Two questions that could help us understand the context better:

  • Is it Hexagon-specific, i.e., if we tune conv2d on CPU and GPU, do these incorrect results also occur?
  • Is it only conv2d, i.e., if we tune other kernels on Hexagon, do the kernels before and after tuning give different results?

First, it is Hexagon-specific. On CPU, the tuned kernel's output is the same as the untuned output.
Second, I have only observed this behavior for this specific kernel. For example, after fusion, ResNet has around 31 PrimFuncs. Of those 31, only one PrimFunc—the one containing the above block as one of its fused operations—produced results different from its untuned version.

In addition, this is definitely an incorrect transformation of the untuned PrimFunc, as the two PrimFuncs shown above give different results even when both are run on CPU.

I think I have narrowed it down to the reordering of loops.

On Hexagon, the following two modules, which differ only in the order of loops i3 and i4, produce different numeric results. The maximum difference is 0.5 and the mean difference is 0.0708. This happens only with the fp16 dtype.

@tvm.script.ir_module
class TuningBug:
    # Minimal repro module: the untuned float16 conv2d PrimFunc plus a Relax
    # `main` that invokes it via R.call_tir.
    @T.prim_func
    def conv2d(lv1: T.Buffer[(1, 230, 230, 3), "float16"], param_0: T.Buffer[(7, 7, 3, 64), "float16"], conv2d_nhwc: T.Buffer[(1, 112, 112, 64), "float16"]):
        # function attr dict
        T.func_attr({"tir.noalias": True, "global_symbol": "conv2d"})
        # body
        # with T.block("root")
        # Original loop order: spatial nn, yy, xx, ff outermost, then the
        # reduction ry, rx, rc innermost.
        for i0, i1, i2, i3, i4, i5, i6 in T.grid(1, 112, 112, 64, 7, 7, 3):
            with T.block("conv2d_nhwc"):
                nn, yy, xx, ff, ry, rx, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6])
                T.reads(lv1[nn, yy * 2 + ry, xx * 2 + rx, rc], param_0[ry, rx, rc, ff])
                T.writes(conv2d_nhwc[nn, yy, xx, ff])
                with T.init():
                    conv2d_nhwc[nn, yy, xx, ff] = T.float16(0)
                # Accumulate directly into the float16 output buffer.
                conv2d_nhwc[nn, yy, xx, ff] = (conv2d_nhwc[nn, yy, xx, ff] + lv1[nn, yy * 2 + ry, xx * 2 + rx, rc] * param_0[ry, rx, rc, ff])

    # Relax entry point. param_0's shape is spelled with T.int64 and lv1's
    # with plain ints, but both match the PrimFunc's buffer shapes. The
    # output (1, 112, 112, 64) float16 matches conv2d_nhwc.
    @R.function
    def main(lv1: R.Tensor[(1, 230, 230, 3), "float16"], param_0: R.Tensor[(T.int64(7), T.int64(7), T.int64(3), T.int64(64)), "float16"]):
        with R.dataflow():
            gv = R.call_tir(conv2d, (lv1, param_0), (1, 112, 112, 64), dtype="float16")
            R.output(gv)
        return gv

Reorder loops i3 and i4:

# Build a schedule over the repro module. NOTE(review): `mod` is not
# defined in this snippet; presumably it is TuningBug above — confirm.
sch = tvm.tir.Schedule(mod)
b0 = sch.get_block("conv2d_nhwc", func_name="conv2d")
# One loop RV per loop around the block, outermost first (i0..i6).
i0, i1, i2, i3, i4, i5, i6 = sch.get_loops(b0)
# Move i4 (ry, the first reduction loop) outside i3 (ff).
sch.reorder(i4, i3)

The modified module looks like this:

@tvm.script.ir_module
class TuningBug:
    # The repro module after sch.reorder(i4, i3). Only the loop nest order
    # in conv2d differs; the block body and `main` are unchanged.
    @T.prim_func
    def conv2d(lv1: T.Buffer[(1, 230, 230, 3), "float16"], param_0: T.Buffer[(7, 7, 3, 64), "float16"], conv2d_nhwc: T.Buffer[(1, 112, 112, 64), "float16"]):
        # function attr dict
        T.func_attr({"tir.noalias": True, "global_symbol": "conv2d"})
        # body
        # with T.block("root")
        # Loop order is now nn, yy, xx, ry, ff, rx, rc. For any fixed output
        # element (nn, yy, xx, ff), the reduction still visits (ry, rx, rc)
        # in the same lexicographic order as before the reorder, so the
        # per-element float16 summation order is unchanged at the TIR level.
        # NOTE(review): the observed numeric difference therefore does not
        # come from reassociating the sum. Presumably it comes from code
        # generation, e.g. partial sums kept in a wider register vs. stored
        # to the fp16 buffer between ry iterations, or FMA contraction
        # applied differently. TODO confirm on the generated code.
        for i0, i1, i2, i4, i3, i5, i6 in T.grid(1, 112, 112, 7, 64, 7, 3):
            with T.block("conv2d_nhwc"):
                nn, yy, xx, ff, ry, rx, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6])
                T.reads(lv1[nn, yy * 2 + ry, xx * 2 + rx, rc], param_0[ry, rx, rc, ff])
                T.writes(conv2d_nhwc[nn, yy, xx, ff])
                with T.init():
                    conv2d_nhwc[nn, yy, xx, ff] = T.float16(0)
                conv2d_nhwc[nn, yy, xx, ff] = (conv2d_nhwc[nn, yy, xx, ff] + lv1[nn, yy * 2 + ry, xx * 2 + rx, rc] * param_0[ry, rx, rc, ff])

    # Relax entry point, identical to the pre-reorder module.
    @R.function
    def main(lv1: R.Tensor[(1, 230, 230, 3), "float16"], param_0: R.Tensor[(T.int64(7), T.int64(7), T.int64(3), T.int64(64)), "float16"]):
        with R.dataflow():
            gv = R.call_tir(conv2d, (lv1, param_0), (1, 112, 112, 64), dtype="float16")
            R.output(gv)
        return gv