[BUG] Imperfect shape query propagation in the presence of dead code
Description
The fusion pipeline tries to express array shape queries in expressions (with the `shape` primitive) in terms of the shapes of argument arrays if not doing so would prevent fusion. However, the algorithm seems to miss certain cases; the case I ran into involves a shape query of an array that is not used anywhere outside of that shape query (and is thus dead code).
Steps to reproduce
Consider the following code:
import Data.Array.Accelerate

shapePropTest1 =
  let a7 = use (fromList (Z :. (1 :: Int) :. (1 :: Int)) [5.0 :: Float])
      a8 = map (\x0 -> T2 (x0 * x0) x0) a7
      a9 = map (\(T2 x0 _) -> x0) a8
  in zipWith (+)
       (generate (shape a9) (\_ -> 1.0))
       (map (\(T2 _ tup) -> tup) a8)

shapePropTest2 =
  let a7 = use (fromList (Z :. (1 :: Int) :. (1 :: Int)) [5.0 :: Float])
      a8 = map (\x0 -> T2 (x0 * x0) x0) a7
      a9 = map (\(T2 x0 _) -> x0) a8
  in zipWith (+)
       (generate (shape a8) (\_ -> 1.0))
       (map (\(T2 _ tup) -> tup) a8)

main :: IO ()
main = print shapePropTest1 >> print shapePropTest2
The only difference between these two programs is the argument to `shape`. Note that `a8` and `a9` have the same shape by definition of `map`.
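The shape argument above can be illustrated with a toy model. This is a minimal sketch using only `base`; the `Arr` type and the `shapeSource` and `shapeOf` functions are hypothetical names invented for illustration and are not Accelerate's AST or fusion code. It only shows why a shape query on a chain of `map` nodes can be answered from the underlying input array, regardless of whether the intermediate arrays are used elsewhere.

```haskell
-- Toy model only: these types are hypothetical and are not Accelerate's AST.

-- A shape is modelled as a list of extents, e.g. [1, 1] for Z :. 1 :. 1.
type Shape = [Int]

-- Array-level terms: an input array with a known shape, or a map over a term.
data Arr
  = Use Shape        -- manifest input array
  | Map String Arr   -- element-wise map; the String is just a label
  deriving (Eq, Show)

-- Push a shape query through every 'Map' down to the underlying input.
-- This is sound because 'map' never changes the shape of its argument.
shapeSource :: Arr -> Arr
shapeSource (Map _ inner) = shapeSource inner
shapeSource arr           = arr

-- Evaluate a shape query by looking only at the source array.
shapeOf :: Arr -> Shape
shapeOf arr = case shapeSource arr of
  Use sh  -> sh
  Map _ _ -> error "unreachable: shapeSource strips all Map nodes"

main :: IO ()
main = do
  let a7 = Use [1, 1]
      a8 = Map "square-and-pair" a7
      a9 = Map "first-component" a8
  -- Both queries resolve to the same source array, a7.
  print (shapeSource a9 == shapeSource a8)
  print (shapeOf a9)
```

In this model, `shape a9` and `shape a8` both reduce to a query on `a7`, which is the behaviour the two test programs would need in order to fuse identically.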
The first program `show`s to:
let
  a0 = map (\x0 -> T2 (x0 * x0) x0) (use (Matrix (Z :. 1 :. 1) [ 5.0]))
  a1 = map (\(T2 x0 _) -> x0) a0
  a2 = map (\(_, x0) -> x0) a0
in
generate
  (let T2 x0 x1 = shape a2 ; T2 x2 x3 = shape a1 in T2 (min x2 x0) (min x3 x1))
  (\(T2 x0 x1) -> 1.0 + a2 ! (T2 x0 x1))
The second program `show`s to:
let a0 = use (Matrix (Z :. 1 :. 1) [ 5.0])
in
generate
  (let T2 x0 x1 = shape a0 ; T2 x2 x3 = shape a0 in T2 (min x2 x0) (min x3 x1))
  (\(T2 x0 x1) -> 1.0 + a0 ! (T2 x0 x1))
The first program exhibits suboptimal fusion, even though it is equivalent to the second.
Expected behaviour
Both programs should fuse to the same result, eliminating all `map` calls.
Your environment
- Accelerate: accelerate-1.3.0.0 on Hackage, or commit 162a779f.
- Accelerate backend(s): n/a
- GHC: Stack LTS 16.12 (ghc 8.8.4)
- OS: Arch Linux
Another example of similar behaviour. This time there is no dead code, but there is a more complicated scalar expression to reason about.
let a = use (fromList (Z :. (1 :: Int)) [(1::Float,2::Float)])
    b = map fst a
in generate (I1 (let I1 n1 = shape b ; I1 n2 = shape a in min n1 n2))
            (\(I1 i) -> b ! I1 i)
This `show`s as:
let
  a0 = use (Vector (Z :. 1) [(((), 1.0), 2.0)])
  a1 = map (\(T2 x0 _) -> x0) a0
in
generate (min (let T1 x0 = shape a1 in x0) (let T1 x0 = shape a0 in x0)) (\(T1 x0) -> a1 ! x0)
whereas an optimal fusion result would have been `map fst (use (Vector ...))`.
Notably, if `fst` is replaced by `(*2)` (and the type of `a` is changed to `Vector Float`), then the `map` does fuse. Is this an instance of the thing I ran into earlier, where nodes like `map fst` in the fused AST are not actually compiled as manifest arrays but rather reduce to manipulation of SoA-form arrays?
Two more programs that produce (somewhat) unexpected results:
zipWith (+)
  (zipWith (\x y -> x * 2 * y) a a)
  (map (\x -> x * x) a)
which `show`s as:
\a0 ->
  transform
    (let T1 x0 = shape a0 ; T1 x1 = shape a0 ; T1 x2 = shape a0 in (min (min x1 x0) x2))
    (\(T1 x0) -> x0)
    (\x0 -> 2.0 * x0 * x0 + x0 * x0)
    a0
which could have been `\a -> map (\x -> 2 * x * x + x * x) a`.
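In the `transform` output above, all three extents are `shape a0`, so the nested `min` is redundant. A minimal sketch of a rewrite that would collapse it, assuming a toy scalar language: the `Exp` type and `simplify` function are hypothetical and are not part of Accelerate. I am not claiming the simplifier should be implemented this way, only illustrating the missing step.

```haskell
-- Toy model only: 'Exp' and 'simplify' are hypothetical, not Accelerate's API.

-- Scalar expressions that appear in shape computations.
data Exp
  = ShapeOf String   -- extent of the named array, e.g. "a0"
  | Min Exp Exp      -- intersection of two extents
  deriving (Eq, Show)

-- Simplify bottom-up, collapsing 'min e e' to 'e'.
simplify :: Exp -> Exp
simplify (Min l r)
  | l' == r'  = l'
  | otherwise = Min l' r'
  where
    l' = simplify l
    r' = simplify r
simplify e = e

main :: IO ()
main = do
  let a0 = ShapeOf "a0"
  -- Mirrors min (min x1 x0) x2 where all three are the extent of a0.
  print (simplify (Min (Min a0 a0) a0))
  -- Different arrays: nothing to collapse.
  print (simplify (Min (ShapeOf "a1") a0))
```

Once the extent reduces to a plain `shape a0`, a `transform` with an identity index function over the full shape is just a `map`. In the dead-code cases earlier in this issue, a rewrite like this would only apply after the shape queries have been pushed through `map` to a common source array.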
Program 2:
zipWith (+) a (map (\x -> x) a)
which `show`s as:
\a0 ->
  let a1 = map (\x0 -> x0) a0
  in
  generate
    (let T1 x0 = shape a1 ; T1 x1 = shape a0 in (min x1 x0))
    (\(T1 x0) -> a0 ! x0 + a1 ! x0)
which could have been `\a -> map (\x -> x + x) a`, or indeed `\a -> generate (shape a) (\(T1 x) -> a ! x + a ! x)`.
I'm just collecting these for future reference at this point.