spcl/dace

Memlet Error in RedundantArray (simplify)


I found a bug in simplify; as far as I can tell, it is located inside RedundantArray.
The error concerns a Memlet that performs a reshape.
What complicates matters is that the error depends on the processing order: the bug only appears in some runs, which makes it hard to debug.
I was able to find a minimal example that triggers the bug, but not on every run, so you may have to run it several times.

import dace

"""
Minimal example for bug in `RedundantArray`

Essentially, the SDFG performs the following computation:
def foo(A: dace.float64[6, 6, 6]) -> dace.float64[36, 1, 6]:
    return A.reshape((36, 1, 6))

I located the error in the `RedundantArray` transformation, which corrupts the
Memlets. Furthermore, the input array is in Fortran order, but it should also work
with C order, since the error happens before code generation.

I M P O R T A N T
=================
The bug is not deterministic; it depends on the processing order!
It happens if node `a` is removed instead of node `b`, even though that node could
also be removed. If you run the script and it does not fail, try again until it does.
"""




sdfg = dace.SDFG("invalid_sdfg")

_, input_desc = sdfg.add_array(
        "input",
        shape=(6, 6, 6),
        transient=False,
        strides=(1, 6, 36),
        dtype=dace.float64,
)
_, a_desc = sdfg.add_array(
        "a",
        shape=(6, 6, 6),
        transient=True,
        strides=(36, 6, 1),
        dtype=dace.float64,
)
_, b_desc = sdfg.add_array(
        "b",
        shape=(36, 1, 6),
        transient=True,
        strides=(6, 6, 1),
        dtype=dace.float64,
)
_, output_desc = sdfg.add_array(
        "output",
        shape=(36, 1, 6),
        transient=False,
        strides=(6, 6, 1),
        dtype=dace.float64,
)

state = sdfg.add_state("state", is_start_block=True)
input_an = state.add_access("input")
a_an = state.add_access("a")
b_an = state.add_access("b")
output_an = state.add_access("output")

state.add_edge(
        input_an,
        None,
        a_an,
        None,
        dace.Memlet.from_array("input", input_desc),
)

state.add_edge(
        a_an,
        None,
        b_an,
        None,
        dace.Memlet.simple(
            "a",
            subset_str="0:6, 0:6, 0:6",
            other_subset_str="0:36, 0, 0:6",
        )
)

state.add_edge(
        b_an,
        None,
        output_an,
        None,
        dace.Memlet.from_array("b", b_desc),
)

sdfg.validate()
sdfg.simplify(validate=False)

if len(sdfg.arrays) == 2:
    print("All transients were removed, this is new, will it fail?")

elif "a" not in sdfg.arrays:
    print("Array `a` was removed, in the past this indicated that the SDFG is invalid."
          "\nLet's start validation to see what happens.")

elif "b" not in sdfg.arrays:
    print("Array `b` was removed, in the past such an SDFG was valid, try again.")

else:
    print("Something is fishy.")

sdfg.validate()
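
As a sanity check that the `a` -> `b` Memlet in the script is itself a legal reshape, here is a small pure-Python sketch that does not use DaCe at all. The helper `flat_offsets` is hypothetical (not part of DaCe), and the sketch assumes that both sides of the Memlet are traversed with the last dimension varying fastest. Under that assumption, the subset of `a` and the other subset of `b` address the same linear offsets in the same order, so only the Fortran-ordered `input` requires an actual layout change.

```python
from itertools import product


def flat_offsets(ranges, strides):
    """Linear offsets of all indices in `ranges`, last dimension varying fastest.

    Hypothetical helper for illustration only; it is not a DaCe function.
    """
    return [sum(i * s for i, s in zip(idx, strides)) for idx in product(*ranges)]


# Source side of the a -> b Memlet: subset "0:6, 0:6, 0:6" of `a`, strides (36, 6, 1).
src = flat_offsets([range(6), range(6), range(6)], strides=(36, 6, 1))

# Destination side: other subset "0:36, 0, 0:6" of `b`, strides (6, 6, 1).
dst = flat_offsets([range(36), range(1), range(6)], strides=(6, 6, 1))

# The non-transient `input` uses Fortran strides (1, 6, 36).
inp = flat_offsets([range(6), range(6), range(6)], strides=(1, 6, 36))

# Both sides of the reshape touch offsets 0..215 in the same order.
assert src == dst == list(range(216))

# `input` covers the same 216 elements, but in a different order.
assert sorted(inp) == src and inp != src
```

If this reasoning holds, the SDFG before simplification is valid (which matches the first `sdfg.validate()` passing), and any invalid Memlet afterwards has to come from the transformation.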

A deterministic test can be found here, but the bug should be fixed by PR1603.

This issue seems to be solved by PR1603, but be aware of issue 1644, which is similar but happens in a different context.