Check_ReorderLoops fails when reduction is not in reduction syntax
SamirDroubi opened this issue · 4 comments
The following program returns an unknown result on z3 4.12.4.0 (and an unsat result on z3 4.12.5.0) when trying to reorder loops ioo and ioi:
@proc
def foo(n: size, x: [f32][n] @ DRAM, y: [f32][n] @ DRAM,
        result: f32 @ DRAM):
    assert stride(x, 0) == 1
    assert stride(y, 0) == 1
    result_: f32 @ DRAM
    result_ = 0.0
    var0: f32[8] @ AVX2
    for ii in seq(0, 8):
        var0[ii] = 0.0
    for ioo in seq(0, ((7 + n) / 8 - 1) / 4):
        for ioi in seq(0, 4):
            var1: f32[8] @ AVX2
            var2: f32[8] @ AVX2
            for ii in seq(0, 8):
                var1[ii] = x[ii + 8 * (4 * ioo + ioi)]
            for ii in seq(0, 8):
                var2[ii] = y[ii + 8 * (4 * ioo + ioi)]
            for ii in seq(0, 8):
                var0[ii] = var1[ii] * var2[ii] + var0[ii]
    for ii in seq(0, 8):
        result_ += var0[ii]
    result = result_
However, switching the following assignment to a reduction:
var0[ii] = var1[ii] * var2[ii] + var0[ii]
--->
var0[ii] += var1[ii] * var2[ii]
lets the check pass (on both versions). Is this expected behavior or a bug?
Here is the debug info when the statement is in assignment form:
assume M(stride(x,0) == 1 ∧ stride(y,0) == 1 ∧ n > 0)
to verify
∀ioo,∀ioi,∀ioo,∀ioi,M(0 ≤ ioo ∧ ioo < ((7 + n) / 8 - 1) / 4 ∧ (0 ≤ ioi ∧ ioi < 4) ∧ (0 ≤ ioo ∧ ioo < ((7 + n) / 8 - 1) / 4) ∧ (0 ≤ ioi ∧ ioi < 4) ∧ ioo < ioo ∧ ioi < ioi) ⇒ D(∀i0,¬((∃ii,0 ≤ ii ∧ ii < 8 ∧ i0 == ii) ∧ ((∃ii,0 ≤ ii ∧ ii < 8 ∧ i0 == ii) ∨ (∃ii,0 ≤ ii ∧ ii < 8 ∧ i0 == ii)))) ∧ D(∀i0,¬((∃ii,0 ≤ ii ∧ ii < 8 ∧ i0 == ii) ∧ ((∃ii,0 ≤ ii ∧ ii < 8 ∧ i0 == ii) ∨ (∃ii,0 ≤ ii ∧ ii < 8 ∧ i0 == ii))))
ForAll(ioo_257,
ForAll(ioi_258,
ForAll(ioo_293,
ForAll(ioi_294,
Implies(And(And(And(And(And(And(0 <=
ioo_257,
ForAll([div_tmp_313,
div_tmp_314],
Implies(And(And(8*
div_tmp_313 <=
7 + n_250,
7 + n_250 <
8*(div_tmp_313 + 1)),
And(4*div_tmp_314 <=
div_tmp_313 - 1,
div_tmp_313 - 1 <
4*(div_tmp_314 + 1))),
ioo_257 <
div_tmp_314))),
And(0 <= ioi_258,
4 > ioi_258)),
And(0 <= ioo_293,
ForAll([div_tmp_315,
div_tmp_316],
Implies(And(And(8*
div_tmp_315 <=
7 + n_250,
7 + n_250 <
8*(div_tmp_315 + 1)),
And(4*div_tmp_316 <=
div_tmp_315 - 1,
div_tmp_315 - 1 <
4*(div_tmp_316 + 1))),
ioo_293 <
div_tmp_316)))),
And(0 <= ioi_294,
4 > ioi_294)),
ioo_257 < ioo_293),
ioi_294 < ioi_258),
And(ForAll(i0_307,
Not(And(Exists(ii_263,
And(And(0 <= ii_263,
8 > ii_263),
i0_307 == ii_263)),
Or(Exists(ii_263,
And(And(0 <= ii_263,
8 > ii_263),
i0_307 == ii_263)),
Exists(ii_263,
And(And(0 <= ii_263,
8 > ii_263),
i0_307 == ii_263)))))),
ForAll(i0_310,
Not(And(Exists(ii_263,
And(And(0 <= ii_263,
8 > ii_263),
i0_310 == ii_263)),
Or(Exists(ii_263,
And(And(0 <= ii_263,
8 > ii_263),
i0_310 == ii_263)),
Exists(ii_263,
And(And(0 <= ii_263,
8 > ii_263),
i0_310 == ii_263))))))))))))
smtlib2
; benchmark generated from python API
(set-info :status unknown)
(declare-fun n_250 () Int)
(declare-fun y_stride_0_292 () Int)
(declare-fun x_stride_0_291 () Int)
(assert
(and (and (= 1 x_stride_0_291) (= 1 y_stride_0_292)) (< 0 n_250)))
(assert
(let (($x292 (forall ((ioo_257 Int) )(forall ((ioi_258 Int) )(forall ((ioo_293 Int) )(forall ((ioi_294 Int) )(let (($x271 (forall ((i0_310 Int) )(let (($x327 (exists ((ii_263 Int) )(and (and (<= 0 ii_263) (> 8 ii_263)) (= i0_310 ii_263)))
))
(not (and $x327 (or $x327 $x327)))))
))
(let (($x277 (forall ((i0_307 Int) )(let (($x327 (exists ((ii_263 Int) )(and (and (<= 0 ii_263) (> 8 ii_263)) (= i0_307 ii_263)))
))
(not (and $x327 (or $x327 $x327)))))
))
(let (($x303 (forall ((div_tmp_315 Int) (div_tmp_316 Int) )(let (($x356 (and (<= (* 4 div_tmp_316) (- div_tmp_315 1)) (< (- div_tmp_315 1) (* 4 (+ div_tmp_316 1))))))
(let (($x326 (and (<= (* 8 div_tmp_315) (+ 7 n_250)) (< (+ 7 n_250) (* 8 (+ div_tmp_315 1))))))
(let (($x250 (and $x326 $x356)))
(=> $x250 (< ioo_293 div_tmp_316))))))
))
(let (($x146 (forall ((div_tmp_313 Int) (div_tmp_314 Int) )(let (($x356 (and (<= (* 4 div_tmp_314) (- div_tmp_313 1)) (< (- div_tmp_313 1) (* 4 (+ div_tmp_314 1))))))
(let (($x326 (and (<= (* 8 div_tmp_313) (+ 7 n_250)) (< (+ 7 n_250) (* 8 (+ div_tmp_313 1))))))
(let (($x250 (and $x326 $x356)))
(=> $x250 (< ioo_257 div_tmp_314))))))
))
(let (($x148 (and (and (<= 0 ioo_257) $x146) (and (<= 0 ioi_258) (> 4 ioi_258)))))
(let (($x305 (and (and $x148 (and (<= 0 ioo_293) $x303)) (and (<= 0 ioi_294) (> 4 ioi_294)))))
(=> (and (and $x305 (< ioo_257 ioo_293)) (< ioi_294 ioi_258)) (and $x277 $x271)))))))))
)
)
)
))
(not $x292)))
(check-sat)
Here is the debug info when the statement is in reduce form:
assume M(stride(x,0) == 1 ∧ stride(y,0) == 1 ∧ n > 0)
to verify
True
True
smtlib2
; benchmark generated from python API
(set-info :status unknown)
(declare-fun n_250 () Int)
(declare-fun y_stride_0_288 () Int)
(declare-fun x_stride_0_287 () Int)
(assert
(and (and (= 1 x_stride_0_287) (= 1 y_stride_0_288)) (< 0 n_250)))
(assert
(not true))
(check-sat)
It seems like there are two bugs? The passing case seems to dispatch a query that has no effects to verify?
Here is a minimal test that replicates the bug. z3-solver (4.12.4.0) reports unsat (instead of unknown) on this test:
@proc
def foo(n: size, a: f32):
    for i in seq(0, n):
        for j in seq(0, 4):
            a = a + 1.0

foo = reorder_loops(foo, foo.find_loop("i"))
I think this might be expected behavior. The analysis concludes that the body doesn't commute with itself because, in the general case, a statement that reads a and writes a doesn't commute with another statement that reads a and writes a. Here the two instances do commute, precisely because the statement is a reduction. However, there is no logic in the analysis to detect that an assignment represents a reduction.
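To make the commutativity argument above concrete, here is a small sketch in plain Python. It does not use Exo or z3. The helper `commutes` and the sample updates are hypothetical, introduced only for illustration, and integers are used so the comparison is exact.

```python
def commutes(f, g, states):
    """Return True if applying f then g equals g then f on every sample state."""
    return all(g(f(s)) == f(g(s)) for s in states)


# Two instances of a reduction-shaped update: a = a + c
add1 = lambda a: a + 1
add2 = lambda a: a + 2

# A general read-modify-write update that is not a reduction: a = 2 * a
double = lambda a: 2 * a

samples = range(-5, 6)

# Reduction-shaped updates can be reordered freely.
reduction_commutes = commutes(add1, add2, samples)

# A general read/write pair cannot, so an analysis that only sees
# "reads a, writes a" has to answer conservatively.
general_commutes = commutes(add1, double, samples)

print(reduction_commutes, general_commutes)
```

Both `a = a + 1.0` and `a = a * 2.0` have the same read/write footprint, so an analysis that looks only at effects cannot tell them apart without recognizing the reduction pattern.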
One idea is to have an internal rewrite pass within the analysis that folds every assignment into reduction form, where possible, before generating the SMT queries. The rewrites should stay local to the analysis and shouldn't propagate back to the calling scheduling operation.
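A minimal sketch of what such a folding pass could look like, written against a toy AST rather than Exo's actual internal IR. Every class and function name here (`Assign`, `Reduce`, `Read`, `fold_to_reduce`, and so on) is hypothetical, and indices are compared syntactically as strings, which is a simplifying assumption.

```python
from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class Const:
    value: float


@dataclass(frozen=True)
class Read:
    name: str
    index: str = ""  # toy representation: index expression kept as a string


@dataclass(frozen=True)
class Add:
    lhs: "Expr"
    rhs: "Expr"


@dataclass(frozen=True)
class Mul:
    lhs: "Expr"
    rhs: "Expr"


Expr = Union[Const, Read, Add, Mul]


@dataclass(frozen=True)
class Assign:  # name[index] = rhs
    name: str
    index: str
    rhs: Expr


@dataclass(frozen=True)
class Reduce:  # name[index] += rhs
    name: str
    index: str
    rhs: Expr


def reads(e: Expr, name: str) -> bool:
    """Return True if expression e reads the buffer called name."""
    if isinstance(e, Read):
        return e.name == name
    if isinstance(e, (Add, Mul)):
        return reads(e.lhs, name) or reads(e.rhs, name)
    return False


def fold_to_reduce(stmt):
    """Rewrite `x = e + x` or `x = x + e` into `x += e` when e does not read x."""
    if not isinstance(stmt, Assign) or not isinstance(stmt.rhs, Add):
        return stmt
    target = Read(stmt.name, stmt.index)
    candidates = ((stmt.rhs.lhs, stmt.rhs.rhs), (stmt.rhs.rhs, stmt.rhs.lhs))
    for acc, rest in candidates:
        if acc == target and not reads(rest, stmt.name):
            return Reduce(stmt.name, stmt.index, rest)
    return stmt  # not a recognizable reduction; leave unchanged


# var0[ii] = var1[ii] * var2[ii] + var0[ii]
product = Mul(Read("var1", "ii"), Read("var2", "ii"))
folded_vec = fold_to_reduce(Assign("var0", "ii", Add(product, Read("var0", "ii"))))

# a = a + 1.0
folded_scalar = fold_to_reduce(Assign("a", "", Add(Read("a"), Const(1.0))))

# a = a * 2.0 is not a reduction and must stay an assignment
unchanged = fold_to_reduce(Assign("a", "", Mul(Read("a"), Const(2.0))))
```

In this sketch the pass returns new statements and leaves its input untouched, so the analysis could run it on a private copy and the procedure seen by the scheduling operation would keep its original assignment form.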