Error in CompositeFusion (auto optimize)
philip-paul-mueller commented
The introduction of the new `MapFusion` implementation revealed some bugs (most likely in `CompositeFusion`). Since the new transformation has a compatibility mode, that mode must be disabled first to reproduce them; see the patch below.
The error surfaces in `tests/npbench/polybench/correlation_test.py`, and the main cause is most likely the use of a dynamic memlet. If the update is changed to `stddev = np.where(stddev <= 0.1, 1.0, stddev)`, the error goes away.
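The two forms of the update can be compared in plain NumPy. This is a sketch that assumes the original kernel performs a masked in-place assignment of the form `stddev[stddev <= 0.1] = 1.0`; the kernel itself is not shown here, so that line is an assumption. Both forms compute the same values. The difference that matters for DaCe is that the masked form writes only some elements (a conditional write, presumably expressed through a dynamic memlet), while `np.where` writes every element.

```python
import numpy as np


def update_masked(stddev: np.ndarray) -> np.ndarray:
    # Assumed form of the original update: a masked, in-place write.
    # Only the selected elements are written, i.e. a conditional write.
    result = stddev.copy()
    result[result <= 0.1] = 1.0
    return result


def update_where(stddev: np.ndarray) -> np.ndarray:
    # Workaround from the issue: np.where yields a value for every element,
    # so the whole array is written unconditionally.
    return np.where(stddev <= 0.1, 1.0, stddev)


stddev = np.array([0.05, 0.5, 0.1, 2.0])
print(update_masked(stddev).tolist())  # [1.0, 0.5, 1.0, 2.0]
print(update_where(stddev).tolist())   # [1.0, 0.5, 1.0, 2.0]
```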
After map fusion has run, the SDFG has the following structure (where `x` is `stddev`):

[Figure: SDFG after map fusion; `x` denotes `stddev`]
The reason is that the before/after dependency is no longer expressed through explicit, direct data flow, but only through the dependency between the maps, which the transformation fails to pick up.
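The following is a hypothetical, simplified model (not DaCe's API) of why a check that only looks at direct data flow can miss such an ordering constraint. `MapScope`, `has_direct_edge` and `has_dependency` are illustrative names introduced here; a real check would also have to compare the accessed subsets, which this sketch ignores.

```python
from dataclasses import dataclass, field


@dataclass
class MapScope:
    # Hypothetical, simplified model of a map scope: only the names of the
    # data containers it reads and writes are tracked, not the subsets.
    name: str
    reads: set = field(default_factory=set)
    writes: set = field(default_factory=set)


def has_direct_edge(edges: set, first: MapScope, second: MapScope) -> bool:
    # Naive check: only an explicit data-flow edge orders the two maps.
    return (first.name, second.name) in edges


def has_dependency(first: MapScope, second: MapScope) -> bool:
    # Order-aware check: any read-after-write, write-after-read or
    # write-after-write conflict forces `second` to run after `first`,
    # whether or not an edge connects them directly.
    raw = first.writes & second.reads
    war = first.reads & second.writes
    waw = first.writes & second.writes
    return bool(raw or war or waw)


# `x` plays the role of `stddev`: one map updates it, a later map reads it,
# but no edge connects the two maps directly.
update = MapScope("update", reads={"x"}, writes={"x"})
consumer = MapScope("consumer", reads={"x", "data"}, writes={"data"})
edges = set()

print(has_direct_edge(edges, update, consumer))  # False
print(has_dependency(update, consumer))          # True
```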
From 213c36d5ba337bfc1541941b0bc459f94270dcf5 Mon Sep 17 00:00:00 2001
From: Philip Mueller <philip.mueller@cscs.ch>
Date: Fri, 6 Sep 2024 14:54:20 +0200
Subject: [PATCH] Disable strict dataflow in auto optimizer.
---
dace/transformation/auto/auto_optimize.py | 7 +++++--
1 file changed, 5 insertions(+), 2 deletions(-)
diff --git a/dace/transformation/auto/auto_optimize.py b/dace/transformation/auto/auto_optimize.py
index 09ff481e3..db0ba5ef2 100644
--- a/dace/transformation/auto/auto_optimize.py
+++ b/dace/transformation/auto/auto_optimize.py
@@ -52,6 +52,9 @@ def greedy_fuse(graph_or_subgraph: GraphViewType,
:param permutations_only: Disallow splitting of maps during MultiExpansion stage
:param expand_reductions: Expand all reduce nodes before fusion
"""
+ validate_all = True
+ validate = True
+ strict_dataflow = False
debugprint = config.Config.get_bool('debugprint')
if isinstance(graph_or_subgraph, ControlFlowRegion):
if isinstance(graph_or_subgraph, SDFG):
@@ -61,7 +64,7 @@ def greedy_fuse(graph_or_subgraph: GraphViewType,
# We have to use `strict_dataflow` because it is known that `CompositeFusion`
# has problems otherwise.
graph_or_subgraph.apply_transformations_repeated(
- MapFusion(strict_dataflow=True),
+ MapFusion(strict_dataflow=strict_dataflow),
validate_all=validate_all,
)
@@ -82,7 +85,7 @@ def greedy_fuse(graph_or_subgraph: GraphViewType,
if isinstance(graph_or_subgraph, SDFGState):
sdfg = graph_or_subgraph.parent
sdfg.apply_transformations_repeated(
- MapFusion(strict_dataflow=True),
+ MapFusion(strict_dataflow=strict_dataflow),
validate_all=validate_all,
)
graph = graph_or_subgraph
--
2.46.0