Memlet Error in RedundantArray (simplify)
Closed this issue · 2 comments
philip-paul-mueller commented
I found a bug in `simplify`; as far as I can tell, it is located inside `RedundantArray`.
The error concerns a Memlet that performs a reshape.
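To make "reshaping Memlet" concrete: the edge in question copies `a` (shape (6, 6, 6)) into `b` (shape (36, 1, 6)). Assuming it follows the row-major semantics of `A.reshape((36, 1, 6))` given in the reproducer's docstring, element `a[i, j, k]` should land at `b[6*i + j, 0, k]`. The sketch below computes that mapping in plain Python; the helper names are hypothetical and are not part of DaCe.

```python
# Sketch: logical element mapping of the reshaping copy a -> b,
# assuming row-major (C-order) reshape semantics, as in A.reshape((36, 1, 6)).

def reshape_target_index(i, j, k):
    """Map a[i, j, k] (shape (6, 6, 6)) to its position in b (shape (36, 1, 6))."""
    flat = (i * 6 + j) * 6 + k       # row-major linear index into a
    return (flat // 6, 0, flat % 6)  # row-major unravel into shape (36, 1, 6)


def build_mapping():
    """Return a dict from every source index in a to its target index in b."""
    return {
        (i, j, k): reshape_target_index(i, j, k)
        for i in range(6)
        for j in range(6)
        for k in range(6)
    }


mapping = build_mapping()
# Every target index is hit exactly once, so the copy is a bijection.
assert len(set(mapping.values())) == 6 * 6 * 6
print(mapping[(1, 2, 3)])  # (8, 0, 3)
```

A correct rewrite of the Memlets by `RedundantArray` should preserve exactly this correspondence, regardless of which of the two transients it removes.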
A further complication is that the error depends on the processing order: the bug only appears in some runs, which makes it hard to debug.
I was able to find a minimal example that triggers the bug, but not reliably, so you may have to run it several times.
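Because a single run may or may not hit the failing order, a small driver that reruns the reproducer in fresh Python processes can save some manual retrying. This is a minimal sketch, assuming the reproducer is saved as a standalone script (the file name `repro.py` used below is hypothetical) and that a failing `sdfg.validate()` ends the process with a nonzero exit code via an uncaught exception. Whether a fresh process actually changes the order in which `simplify` applies `RedundantArray` is an assumption; the driver only automates "run it multiple times".

```python
import subprocess
import sys


def run_repeatedly(script_path, runs=20):
    """Run a reproducer script in fresh interpreter processes and tally outcomes.

    A run counts as "passed" if the process exits with code 0 and as "failed"
    otherwise (e.g. an uncaught exception raised during validation).
    """
    outcomes = {"passed": 0, "failed": 0}
    for _ in range(runs):
        # Each iteration starts a new Python process running the script.
        result = subprocess.run(
            [sys.executable, script_path],
            capture_output=True,
            text=True,
        )
        if result.returncode == 0:
            outcomes["passed"] += 1
        else:
            outcomes["failed"] += 1
    return outcomes
```

For example, `run_repeatedly("repro.py", runs=50)` returns the pass/fail tally; a nonzero `failed` count means the bad ordering occurred at least once.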
import dace
"""
Minimal example for bug in `RedundantArray`
Essentially, the SDFG performs the following computation:
    def foo(A: dace.float64[6, 6, 6]) -> dace.float64[36, 1, 6]:
        return A.reshape((36, 1, 6))
I located the error in the `RedundantArray` transformation, which corrupts the
Memlets. Furthermore, the input array is in FORTRAN order, but the example should
also work with C order, since the error happens before code generation.
I M P O R T A N T
=================
The bug is not deterministic; it depends on the processing order!
It happens if node `a` is removed instead of node `b`, even though `b` could also
be removed. If you run the script and it does not fail, try again until it happens.
"""
sdfg = dace.SDFG("invalid_sdfg")
_, input_desc = sdfg.add_array(
"input",
shape=(6, 6, 6),
transient=False,
strides=(1, 6, 36),
dtype=dace.float64,
)
_, a_desc = sdfg.add_array(
"a",
shape=(6, 6, 6),
transient=True,
strides=(36, 6, 1),
dtype=dace.float64,
)
_, b_desc = sdfg.add_array(
"b",
shape=(36, 1, 6),
transient=True,
strides=(6, 6, 1),
dtype=dace.float64,
)
_, output_desc = sdfg.add_array(
"output",
shape=(36, 1, 6),
transient=False,
    strides=(6, 6, 1),
dtype=dace.float64,
)
state = sdfg.add_state("state", is_start_block=True)
input_an = state.add_access("input")
a_an = state.add_access("a")
b_an = state.add_access("b")
output_an = state.add_access("output")
state.add_edge(
input_an,
None,
a_an,
None,
dace.Memlet.from_array("input", input_desc),
)
state.add_edge(
a_an,
None,
b_an,
None,
dace.Memlet.simple(
"a",
subset_str="0:6, 0:6, 0:6",
other_subset_str="0:36, 0, 0:6",
)
)
state.add_edge(
b_an,
None,
output_an,
None,
dace.Memlet.from_array("b", b_desc),
)
sdfg.validate()
sdfg.simplify(validate=False)
if len(sdfg.arrays) == 2:
    print("All transients were removed, this is new, will it fail?")
elif "a" not in sdfg.arrays:
    print("Array `a` was removed, in the past this indicated that the SDFG is invalid."
          "\nLet's start validation to see what happens.")
elif "b" not in sdfg.arrays:
    print("Array `b` was removed, in the past such an SDFG was valid, try again.")
else:
    print("Something is fishy.")
sdfg.validate()
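The reshaping edge `a -> b` in the reproducer above is only consistent because its `subset_str` and `other_subset_str` cover the same number of elements (6 * 6 * 6 = 36 * 1 * 6 = 216). Any rewritten Memlet left behind after removing `a` or `b` should still satisfy this. The helper below is a hypothetical hand check, not DaCe's validation logic: it only understands constant half-open `begin:end` ranges (end exclusive, as in Python slicing) and single indices, not DaCe's full subset syntax.

```python
from math import prod


def subset_volume(subset_str):
    """Count the elements covered by a simple subset string such as "0:36, 0, 0:6".

    Hypothetical helper: each comma-separated entry is either a single index
    ("0") or a half-open range "begin:end" with constant bounds and no step.
    """
    sizes = []
    for entry in subset_str.split(","):
        entry = entry.strip()
        if ":" in entry:
            begin, end = (int(x) for x in entry.split(":"))
            sizes.append(end - begin)  # half-open range: end is exclusive
        else:
            sizes.append(1)  # a single index selects one element
    return prod(sizes)


src = subset_volume("0:6, 0:6, 0:6")  # subset on `a`
dst = subset_volume("0:36, 0, 0:6")   # other_subset on `b`
assert src == dst == 216
```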
philip-paul-mueller commented
This issue seems to be solved by PR #1603, but be aware of issue #1644, which is similar but occurs in a different context.