exo-lang/exo

Check_ReorderLoops fails when reduction is not in reduction syntax

SamirDroubi opened this issue · 4 comments

The following program returns an unknown result on z3 4.12.4.0 (and an unsat result on z3 4.12.5.0) when trying to reorder loops ioo and ioi:

    @proc
    def foo(n: size, x: [f32][n] @ DRAM, y: [f32][n] @ DRAM,
                      result: f32 @ DRAM):
        assert stride(x, 0) == 1
        assert stride(y, 0) == 1
        result_: f32 @ DRAM
        result_ = 0.0
        var0: f32[8] @ AVX2
        for ii in seq(0, 8):
            var0[ii] = 0.0
        for ioo in seq(0, ((7 + n) / 8 - 1) / 4):
            for ioi in seq(0, 4):
                var1: f32[8] @ AVX2
                var2: f32[8] @ AVX2
                for ii in seq(0, 8):
                    var1[ii] = x[ii + 8 * (4 * ioo + ioi)]
                for ii in seq(0, 8):
                    var2[ii] = y[ii + 8 * (4 * ioo + ioi)]
                for ii in seq(0, 8):
                    var0[ii] = var1[ii] * var2[ii] + var0[ii]
        for ii in seq(0, 8):
            result_ += var0[ii]
        result = result_

However, switching the following assignment to a reduction:

    var0[ii] = var1[ii] * var2[ii] + var0[ii]
    --->
    var0[ii] += var1[ii] * var2[ii]

lets it pass (on both versions). Is this expected behavior or a bug?


Here is the debug info when the statement is in assignment form:

assume  M(stride(x,0) == 1 ∧ stride(y,0) == 1 ∧ n > 0)
to verify
∀ioo,∀ioi,∀ioo,∀ioi,M(0 ≤ ioo ∧ ioo < ((7 + n) / 8 - 1) / 4 ∧ (0 ≤ ioi ∧ ioi < 4) ∧ (0 ≤ ioo ∧ ioo < ((7 + n) / 8 - 1) / 4) ∧ (0 ≤ ioi ∧ ioi < 4) ∧ ioo < ioo ∧ ioi < ioi) ⇒ D(∀i0,¬((∃ii,0 ≤ ii ∧ ii < 8 ∧ i0 == ii) ∧ ((∃ii,0 ≤ ii ∧ ii < 8 ∧ i0 == ii) ∨ (∃ii,0 ≤ ii ∧ ii < 8 ∧ i0 == ii)))) ∧ D(∀i0,¬((∃ii,0 ≤ ii ∧ ii < 8 ∧ i0 == ii) ∧ ((∃ii,0 ≤ ii ∧ ii < 8 ∧ i0 == ii) ∨ (∃ii,0 ≤ ii ∧ ii < 8 ∧ i0 == ii))))
ForAll(ioo_257,
       ForAll(ioi_258,
              ForAll(ioo_293,
                     ForAll(ioi_294,
                            Implies(And(And(And(And(And(And(0 <=
                                        ioo_257,
                                        ForAll([div_tmp_313,
                                        div_tmp_314],
                                        Implies(And(And(8*
                                        div_tmp_313 <=
                                        7 + n_250,
                                        7 + n_250 <
                                        8*(div_tmp_313 + 1)),
                                        And(4*div_tmp_314 <=
                                        div_tmp_313 - 1,
                                        div_tmp_313 - 1 <
                                        4*(div_tmp_314 + 1))),
                                        ioo_257 <
                                        div_tmp_314))),
                                        And(0 <= ioi_258,
                                        4 > ioi_258)),
                                        And(0 <= ioo_293,
                                        ForAll([div_tmp_315,
                                        div_tmp_316],
                                        Implies(And(And(8*
                                        div_tmp_315 <=
                                        7 + n_250,
                                        7 + n_250 <
                                        8*(div_tmp_315 + 1)),
                                        And(4*div_tmp_316 <=
                                        div_tmp_315 - 1,
                                        div_tmp_315 - 1 <
                                        4*(div_tmp_316 + 1))),
                                        ioo_293 <
                                        div_tmp_316)))),
                                        And(0 <= ioi_294,
                                        4 > ioi_294)),
                                        ioo_257 < ioo_293),
                                        ioi_294 < ioi_258),
                                    And(ForAll(i0_307,
                                        Not(And(Exists(ii_263,
                                        And(And(0 <= ii_263,
                                        8 > ii_263),
                                        i0_307 == ii_263)),
                                        Or(Exists(ii_263,
                                        And(And(0 <= ii_263,
                                        8 > ii_263),
                                        i0_307 == ii_263)),
                                        Exists(ii_263,
                                        And(And(0 <= ii_263,
                                        8 > ii_263),
                                        i0_307 == ii_263)))))),
                                        ForAll(i0_310,
                                        Not(And(Exists(ii_263,
                                        And(And(0 <= ii_263,
                                        8 > ii_263),
                                        i0_310 == ii_263)),
                                        Or(Exists(ii_263,
                                        And(And(0 <= ii_263,
                                        8 > ii_263),
                                        i0_310 == ii_263)),
                                        Exists(ii_263,
                                        And(And(0 <= ii_263,
                                        8 > ii_263),
                                        i0_310 == ii_263))))))))))))
smtlib2
; benchmark generated from python API
(set-info :status unknown)
(declare-fun n_250 () Int)
(declare-fun y_stride_0_292 () Int)
(declare-fun x_stride_0_291 () Int)
(assert
 (and (and (= 1 x_stride_0_291) (= 1 y_stride_0_292)) (< 0 n_250)))
(assert
 (let (($x292 (forall ((ioo_257 Int) )(forall ((ioi_258 Int) )(forall ((ioo_293 Int) )(forall ((ioi_294 Int) )(let (($x271 (forall ((i0_310 Int) )(let (($x327 (exists ((ii_263 Int) )(and (and (<= 0 ii_263) (> 8 ii_263)) (= i0_310 ii_263)))
 ))
 (not (and $x327 (or $x327 $x327)))))
 ))
 (let (($x277 (forall ((i0_307 Int) )(let (($x327 (exists ((ii_263 Int) )(and (and (<= 0 ii_263) (> 8 ii_263)) (= i0_307 ii_263)))
 ))
 (not (and $x327 (or $x327 $x327)))))
 ))
 (let (($x303 (forall ((div_tmp_315 Int) (div_tmp_316 Int) )(let (($x356 (and (<= (* 4 div_tmp_316) (- div_tmp_315 1)) (< (- div_tmp_315 1) (* 4 (+ div_tmp_316 1))))))
 (let (($x326 (and (<= (* 8 div_tmp_315) (+ 7 n_250)) (< (+ 7 n_250) (* 8 (+ div_tmp_315 1))))))
 (let (($x250 (and $x326 $x356)))
 (=> $x250 (< ioo_293 div_tmp_316))))))
 ))
 (let (($x146 (forall ((div_tmp_313 Int) (div_tmp_314 Int) )(let (($x356 (and (<= (* 4 div_tmp_314) (- div_tmp_313 1)) (< (- div_tmp_313 1) (* 4 (+ div_tmp_314 1))))))
 (let (($x326 (and (<= (* 8 div_tmp_313) (+ 7 n_250)) (< (+ 7 n_250) (* 8 (+ div_tmp_313 1))))))
 (let (($x250 (and $x326 $x356)))
 (=> $x250 (< ioo_257 div_tmp_314))))))
 ))
 (let (($x148 (and (and (<= 0 ioo_257) $x146) (and (<= 0 ioi_258) (> 4 ioi_258)))))
 (let (($x305 (and (and $x148 (and (<= 0 ioo_293) $x303)) (and (<= 0 ioi_294) (> 4 ioi_294)))))
 (=> (and (and $x305 (< ioo_257 ioo_293)) (< ioi_294 ioi_258)) (and $x277 $x271)))))))))
 )
 )
 )
 ))
 (not $x292)))
(check-sat)

Here is the debug info when the statement is in reduce form:

assume  M(stride(x,0) == 1 ∧ stride(y,0) == 1 ∧ n > 0)
to verify
True
True
smtlib2
; benchmark generated from python API
(set-info :status unknown)
(declare-fun n_250 () Int)
(declare-fun y_stride_0_288 () Int)
(declare-fun x_stride_0_287 () Int)
(assert
 (and (and (= 1 x_stride_0_287) (= 1 y_stride_0_288)) (< 0 n_250)))
(assert
 (not true))
(check-sat)

It seems like there are two bugs? The passing case seems to be dispatching a query without any effects to verify?

Here is a minimal test that replicates the bug. z3-solver (4.12.4.0) reports unsat (instead of unknown) on this test:

    @proc
    def foo(n: size, a: f32): 
        for i in seq(0, n):
            for j in seq(0, 4):
                a = a + 1.0

    foo = reorder_loops(foo, foo.find_loop("i"))

I think this might be expected behavior. The body here doesn't commute with itself because, in the general case, a statement that reads and writes a doesn't commute with another statement that reads and writes a. Here the two iterations do commute, precisely because the statement is a reduction. However, there is no logic in the analysis to detect that an assignment represents a reduction.
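
To make the distinction concrete, here is a small plain-Python illustration (not Exo code; the helper names are made up for this example). Two generic read-modify-write updates of a can give different results depending on their order, while two additive updates give the same result either way. Integers are used to sidestep floating-point rounding.

```python
# Two generic statements that both read and write `a`.
def double(a):
    return a * 2          # models: a = a * 2

def add_one(a):
    return a + 1          # models: a = a + 1

# Swapping the order changes the result, so these do not commute.
double_then_add = add_one(double(3))
add_then_double = double(add_one(3))

# Two reductions into `a` (a += k). Hypothetical helper for illustration.
def reduce_by(k):
    return lambda a: a + k

# Swapping the order of two reductions leaves the result unchanged.
one_then_two = reduce_by(2)(reduce_by(1)(3))
two_then_one = reduce_by(1)(reduce_by(2)(3))
```

An analysis that only sees "reads a, writes a" has to assume the first situation; only knowing the update is a reduction justifies the second.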

One idea is to have an internal rewrite pass within the analysis that folds all assignments into reduction form where possible before generating the SMT queries. The rewrites shouldn't propagate to the calling scheduling operation itself.
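
A rough sketch of what such a folding pass could look like, written against a toy IR rather than Exo's actual internal representation. Every class and function name below (Read, Const, BinOp, Assign, Reduce, reads, fold_to_reduce) is hypothetical and exists only for this illustration. It recognizes x[i] = e + x[i] and x[i] = x[i] + e, and only folds when e itself does not read x.

```python
from dataclasses import dataclass
from typing import Union

# Hypothetical toy IR for illustration; this is NOT Exo's LoopIR.
@dataclass(frozen=True)
class Read:
    name: str
    idx: tuple = ()

@dataclass(frozen=True)
class Const:
    val: float

@dataclass(frozen=True)
class BinOp:
    op: str
    lhs: "Expr"
    rhs: "Expr"

Expr = Union[Read, Const, BinOp]

@dataclass(frozen=True)
class Assign:      # name[idx] = rhs
    name: str
    idx: tuple
    rhs: Expr

@dataclass(frozen=True)
class Reduce:      # name[idx] += rhs
    name: str
    idx: tuple
    rhs: Expr

def reads(e, name):
    """Return True if expression `e` reads buffer `name` anywhere."""
    if isinstance(e, Read):
        return e.name == name or any(reads(i, name) for i in e.idx)
    if isinstance(e, BinOp):
        return reads(e.lhs, name) or reads(e.rhs, name)
    return False

def fold_to_reduce(s):
    """Fold `x[i] = e + x[i]` or `x[i] = x[i] + e` into `x[i] += e`.

    Any statement that does not match is returned unchanged.
    """
    if not (isinstance(s, Assign) and isinstance(s.rhs, BinOp) and s.rhs.op == "+"):
        return s
    target = Read(s.name, s.idx)
    # Try both operand orders: the accumulator may be on either side of `+`.
    for acc, other in ((s.rhs.lhs, s.rhs.rhs), (s.rhs.rhs, s.rhs.lhs)):
        if acc == target and not reads(other, s.name):
            return Reduce(s.name, s.idx, other)
    return s

# The statement from the original report:
#   var0[ii] = var1[ii] * var2[ii] + var0[ii]
ii = Read("ii")
product = BinOp("*", Read("var1", (ii,)), Read("var2", (ii,)))
stmt = Assign("var0", (ii,), BinOp("+", product, Read("var0", (ii,))))
folded = fold_to_reduce(stmt)   # Reduce("var0", (ii,), product)
```

In this sketch the pass returns new statements instead of mutating its input, so the analysis could run it on a private copy and leave the procedure seen by the scheduling operation untouched.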