spcl/dace

Codegen: View Shadows Array in NestedSDFG

philip-paul-mueller opened this issue · 1 comment

I encountered a strange error; after some experiments, I concluded the following:

  • There is a nested SDFG that has a transient named, e.g., "X".
  • In the parent SDFG, there is a view that has the same name, i.e., "X".

In that case, the code generator emits the following for the nested SDFG:

	X = new float[N];

without declaring X; i.e., the declaration `float* X;`, which should precede the allocation, is missing.

Currently, I only have a large SDFG that exhibits this pattern.
The error vanishes once I rename the view in the top-level SDFG or turn it into an array.
I attached the SDFG and the scripts that show how I "fixed" the error.

I will also try to come up with a smaller example.
example.zip

I have now generated a reproducer:

import dace

# Element type and shape shared by every array (and the view) in this reproducer.
shape = (10, 10)
dtype = dace.float64

def make_nsdfg():
    """
    Generate a nested SDFG that runs the following computation:

    ```python
    def nested_comp(
            input1: dace.float64[10, 10],
            input2: dace.float64[10, 10],
    ) -> dace.float64[10, 10]:
        X = input1 + input2
        return X * input2
    ```

    The important thing here is that the temporary is called `X`, the same
    name that `make_main_sdfg` (by default) gives to its view.

    Returns:
        The validated and simplified nested `dace.SDFG`.
    """
    sdfg = dace.SDFG("Nested_SDFG")

    # (name, is_transient): `X` is the only transient, i.e. the SDFG-internal
    # temporary whose declaration goes missing in the generated code.
    array_names = [("input1", False), ("input2", False), ("output", False), ("X", True)]
    for array_name, is_transient in array_names:
        sdfg.add_array(
            array_name,
            dtype=dtype,
            shape=shape,
            transient=is_transient,
        )

    state1 = sdfg.add_state("init_state", is_start_block=True)
    state2 = sdfg.add_state_after(state1, "out_state")

    # First state: X = input1 + input2 (element-wise).
    state1.add_mapped_tasklet(
        "first_addition",
        map_ranges=[("__i0", "0:10"), ("__i1", "0:10")],
        code="__out = __in1 + __in2",
        inputs={"__in1": dace.Memlet("input1[__i0, __i1]"),
                "__in2": dace.Memlet("input2[__i0, __i1]"),
        },
        outputs={"__out": dace.Memlet("X[__i0, __i1]")},
        external_edges=True,
    )
    # Second state: output = X * input2 (element-wise).
    state2.add_mapped_tasklet(
        "second_addition",
        map_ranges=[("__i0", "0:10"), ("__i1", "0:10")],
        code="__out = __in1 * __in2",
        inputs={"__in1": dace.Memlet("X[__i0, __i1]"),
                "__in2": dace.Memlet("input2[__i0, __i1]"),
        },
        outputs={"__out": dace.Memlet("output[__i0, __i1]")},
        external_edges=True,
    )
    # Validate the hand-built graph before simplifying it in place.
    sdfg.validate()
    sdfg.simplify()

    return sdfg


def make_main_sdfg(
    apply_fix: int | None = None,
):
    """Generate the failing SDFG.

    Essentially the computation:

    ```
    def comp(
            input1: dace.float64[10, 10],
            input2: dace.float64[10, 10],
            input3: dace.float64[10, 10],
    ) -> dace.float64[10, 10]:
        nested_output = nested_comp(input1, input2)
        output = np.zeros_like(nested_output)
        X = output.view()
        X = nested_output * input3

        return output
    ```

    The important thing is that the view in the above computation has the same
    name as the temporary that is used inside the nested computation.
    This SDFG passes validation and code generation, but the generated code
    fails to compile. The available fixes below, especially `2`, suggest that
    this is a bug in the code generator.

    Args:
        apply_fix: Which workaround to apply, or `None` to build the failing SDFG.
            - `1`: The nested SDFG is put into a separate state (the map that
              consumes its output goes into a second state).
            - `2`: Give the view in the top SDFG a different name.
            - `3`: Use an array instead of a view in the top SDFG.

    Returns:
        The validated top-level `dace.SDFG`.

    Raises:
        ValueError: If `apply_fix` is not `None`, `1`, `2` or `3`.
        RuntimeError: If the access node of the view cannot be found.
    """
    # Reject unknown values instead of silently building the failing SDFG.
    if apply_fix not in (None, 1, 2, 3):
        raise ValueError(f"Unknown fix `{apply_fix}`; expected None, 1, 2 or 3.")

    sdfg = dace.SDFG("main_SDFG")
    array_names = [("input1", False), ("input2", False), ("input3", False), ("output", False), ("nested_output", True)]
    array_descs = {}
    for array_name, is_transient in array_names:
        _, desc = sdfg.add_array(
            array_name,
            dtype=dtype,
            shape=shape,
            transient=is_transient,
        )
        array_descs[array_name] = desc

    # Fix 2 avoids the clash; otherwise reuse the nested SDFG's transient name.
    view_name = "not_X" if apply_fix == 2 else "X"

    # Now generate the view that we need (fix 3 uses a plain transient array).
    if apply_fix == 3:
        sdfg.add_array(
            view_name,
            dtype=dtype,
            shape=shape,
            transient=True,
        )
    else:
        sdfg.add_view(
            view_name,
            shape=shape,
            dtype=dtype,
        )

    state1 = sdfg.add_state("nested_host_state_init", is_start_block=True)
    nested_sdfg = make_nsdfg()

    # Parent data name -> connector name on the nested SDFG node.
    nested_inputs = {
        "input1": "input1",
        "input2": "input2",
    }
    nested_outputs = {
        "nested_output": "output",
    }
    nsdfg = state1.add_nested_sdfg(
        nested_sdfg,
        parent=sdfg,
        inputs=set(nested_inputs.values()),
        outputs=set(nested_outputs.values()),
    )

    for in_parent, in_nested in nested_inputs.items():
        state1.add_edge(
            state1.add_read(in_parent),
            None,
            nsdfg,
            in_nested,
            dace.Memlet.from_array(in_parent, array_descs[in_parent]),
        )
    nested_outputs_ac = []
    for out_parent, out_nested in nested_outputs.items():
        out_access = state1.add_access(out_parent)
        nested_outputs_ac.append(out_access)
        state1.add_edge(
            nsdfg,
            out_nested,
            out_access,
            None,
            dace.Memlet.from_array(out_parent, array_descs[out_parent]),
        )
    assert len(nested_outputs_ac) == 1

    if apply_fix == 1:
        # Fix 1: consume the nested SDFG's result in a second state.
        state2 = sdfg.add_state_after(state1, "second_main_state")
        nested_output_ac = state2.add_access("nested_output")
    else:
        state2 = state1
        nested_output_ac = nested_outputs_ac[0]

    # view_name = nested_output * input3 (element-wise).
    state2.add_mapped_tasklet(
        "second_addition_in_map",
        map_ranges=[("__i0", "0:10"), ("__i1", "0:10")],
        code="__out = __in1 * __in2",
        inputs={"__in1": dace.Memlet("nested_output[__i0, __i1]"),
                "__in2": dace.Memlet("input3[__i0, __i1]"),
        },
        input_nodes={"nested_output": nested_output_ac, "input3": state2.add_access("input3")},
        outputs={"__out": dace.Memlet(view_name + "[__i0, __i1]")},
        external_edges=True,
    )

    # No access node for `view_name` is added explicitly, so the one present
    # was created by `add_mapped_tasklet` (external_edges=True); look it up.
    view_access_node = next(
        (
            node
            for node in state2.nodes()
            if isinstance(node, dace.nodes.AccessNode) and node.data == view_name
        ),
        None,
    )
    if view_access_node is None:
        raise RuntimeError(f"No access node for `{view_name}` found in the state.")

    # Connect the view to `output` through the `views` connector (for fix 3 the
    # node is a plain array; presumably this then acts as a copy -- confirm).
    state2.add_edge(
        view_access_node,
        "views",
        state2.add_write("output"),
        None,
        dace.Memlet.from_array("output", array_descs["output"]),
    )
    sdfg.validate()

    return sdfg


def can_be_compiled(
    sdfg: dace.SDFG,
) -> bool:
    """Report whether `sdfg` compiles.

    Warnings emitted during compilation are suppressed. A DaCe
    `CompilationError` or `CodegenError` yields `False`; any other exception
    propagates to the caller.
    """
    import warnings

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        try:
            # Keep a reference until return so the compiled object lives as
            # long as it did before.
            program = sdfg.compile()
        except (dace.codegen.exceptions.CompilationError, dace.codegen.exceptions.CodegenError):
            succeeded = False
        else:
            succeeded = True
        return succeeded

# Now test everything: without a fix compilation must fail, with each fix it
# must succeed. Explicit `raise` (instead of `assert`) keeps these checks
# active even when Python runs with `-O`, which strips assert statements.
for fix in [None, 1, 2, 3]:
    sdfg = make_main_sdfg(fix)
    compiles = can_be_compiled(sdfg)

    if fix is None:
        # No fix, so we expect compilation to fail.
        if compiles:
            raise AssertionError("It seems the bug has vanished; did you fix it?")
    elif not compiles:
        raise AssertionError(f"Expected fix `{fix}` to circumvent the problem.")