[Transform] `reuse_at` breaks stride-2 binary convolution
zzzDavid opened this issue · 2 comments
This issue is caused by the `reuse_at` implementation on strided (stride > 1) loops whose induction variable is also used in an `if` condition inside the loop body.
Example
I'll provide the example in HCL code, since it is easy to compare against a NumPy golden model there:
```python
import heterocl as hcl
import heterocl.op.bnn as bnn
import numpy as np

hcl.init()

def bnn_conv(INPUT, w_conv1):
    conv1 = bnn.conv2d_nchw(INPUT, w_conv1, strides=[2, 2], padding=[1, 1],
                            name="conv1", out_dtype=hcl.Int(8))
    return conv1

INPUT = hcl.placeholder((1, 1, 16, 16), "input", hcl.UInt(1))
w_conv1 = hcl.placeholder((16, 1, 3, 3), "w_conv1", hcl.UInt(1))
s = hcl.create_schedule([INPUT, w_conv1], bnn_conv)
LB = s.reuse_at(bnn_conv.conv1_pad, s[bnn_conv.conv1], bnn_conv.conv1.axis[2])
WB = s.reuse_at(LB, s[bnn_conv.conv1], bnn_conv.conv1.axis[3])
f = hcl.build(s)

# create random input data
np_input = np.random.randint(0, 2, size=(1, 1, 16, 16))
np_w_conv1 = np.random.randint(0, 2, size=(16, 1, 3, 3))
hcl_input = hcl.asarray(np_input, dtype=hcl.UInt(1))
hcl_w_conv1 = hcl.asarray(np_w_conv1, dtype=hcl.UInt(1))
hcl_output = hcl.asarray(np.zeros((1, 16, 8, 8)), dtype=hcl.Int(8))
f(hcl_input, hcl_w_conv1, hcl_output)
np_output = hcl_output.asnumpy()

# golden reference
golden = np.zeros((1, 16, 8, 8))
np_input = np.pad(np_input, ((0, 0), (0, 0), (1, 1), (1, 1)), 'constant')
for n in range(1):
    for c in range(16):
        for h in range(8):
            for w in range(8):
                for kh in range(3):
                    for kw in range(3):
                        x = h * 2 + kh
                        y = w * 2 + kw
                        # skip padding positions, matching if_mac
                        if x >= 1 and x < 17 and y >= 1 and y < 17:
                            # map {0, 1} values to {-1, +1} for binary conv
                            inp = np_input[n, 0, x, y]
                            wgt = np_w_conv1[c, 0, kh, kw]
                            inp = -1 if inp == 0 else 1
                            wgt = -1 if wgt == 0 else 1
                            golden[n, c, h, w] += inp * wgt

# compare
assert np.allclose(np_output, golden)
```
edit: updated the golden model to match the `if_mac` behavior.
The issue is that the non-reduction loop's bound is updated, but the induction-variable expressions inside the body are not.
hcl-dialect/lib/Transforms/LoopTransformations.cpp, lines 1696 to 1697 in f2f5e43
In binary convolution, we have an `if_mac` function that checks the row and column axes to skip padding values. The condition looks like this:

```
iv * stride + reduction_iv * dilation >= padding_left
iv * stride + reduction_iv * dilation < padded_width - padding_right
```
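To see how a stale `iv` misfires, here is a hypothetical numeric sketch of the bug as I understand it (the values of `stride`, `dilation`, and the paddings come from the 3x3, stride-2 example above; `reuse_distance = 2` is my assumption for a 3x3 kernel, and the actual rewrite in `reuse_at` may differ in detail):

```python
# Sketch: after reuse_at, the loop variable effectively carries
# new_iv = iv * stride + reuse_distance, but the if_mac guard
# still reads it as if it were the original iv.
stride, reuse_distance, dilation = 2, 2, 1       # assumed for a 3x3 kernel
padding_left, padding_right, padded_width = 1, 1, 18

mismatches = 0
for iv in range(8):                              # original output rows
    new_iv = iv * stride + reuse_distance        # what the rewritten loop carries
    for kh in range(3):                          # reduction axis
        intended = padding_left <= iv * stride + kh * dilation < padded_width - padding_right
        stale    = padding_left <= new_iv * stride + kh * dilation < padded_width - padding_right
        mismatches += int(intended != stale)

# The stale guard diverges from the intended one on many iterations,
# so MACs are skipped or included at the wrong positions.
assert mismatches > 0
```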
And we have updated the induction variable as:

```
new_iv = iv * stride + reuse_distance
```

which means

```
iv = (new_iv - reuse_distance) / stride
```

Therefore, we just need to replace all uses of `iv` with `(new_iv - reuse_distance) / stride`.
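A quick sanity check of the proposed substitution (same assumed constants as the example above; integer division is exact here because `new_iv - reuse_distance` is always a multiple of `stride`):

```python
stride, reuse_distance, dilation = 2, 2, 1       # assumed for a 3x3, stride-2 conv
padding_left, padding_right, padded_width = 1, 1, 18

for iv in range(8):
    new_iv = iv * stride + reuse_distance
    recovered = (new_iv - reuse_distance) // stride   # proposed replacement for iv
    assert recovered == iv
    for kh in range(3):
        original  = padding_left <= iv * stride + kh * dilation < padded_width - padding_right
        rewritten = padding_left <= recovered * stride + kh * dilation < padded_width - padding_right
        assert original == rewritten                  # guard semantics preserved
```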