Some Dask-SQL tests are extremely slow with dask-expr
rjzamora opened this issue · 5 comments
@charlesbluca shared a Dask-SQL test/snippet that seems to "hang" when query-planning is enabled in dask.dataframe. It turns out the operation does eventually finish, but graph materialization is extremely slow in dask-expr for this particular expression graph.
It is certainly possible that Dask-SQL is producing an expression graph that is more complicated than necessary. However, it is definitely not complicated enough to warrant such an extreme slowdown.
Reproducer:
Original dask-sql reproducer
# Environment: mamba create -n dask-sql-hang-repro -c dask/label/dev dask-sql=2024.3.1
import pandas as pd
from dask_sql import Context
user_table_1 = pd.DataFrame({"user_id": [2, 1, 2, 3], "b": [3, 3, 1, 3]})
c = Context()
return_df = c.sql(
    """
    SELECT
        user_id,
        b,
        ROW_NUMBER() OVER (PARTITION BY user_id ORDER BY b) AS "O1",
        FIRST_VALUE(user_id*10 - b) OVER (PARTITION BY user_id ORDER BY b) AS "O2",
        LAST_VALUE(user_id*10 - b) OVER (PARTITION BY user_id ORDER BY b ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "O4",
        SUM(user_id) OVER (PARTITION BY user_id ORDER BY b) AS "O5",
        AVG(user_id) OVER (PARTITION BY user_id ORDER BY b) AS "O6",
        COUNT(*) OVER (PARTITION BY user_id ORDER BY b) AS "O7",
        COUNT(b) OVER (PARTITION BY user_id ORDER BY b) AS "O7b",
        MAX(b) OVER (PARTITION BY user_id ORDER BY b) AS "O8",
        MIN(b) OVER (PARTITION BY user_id ORDER BY b) AS "O9"
    FROM user_table_1
    """,
    dataframes={"user_table_1": user_table_1},
)
# Materializing the graph seems to "hang"
# (takes ~1 ms with query-planning disabled)
len(return_df.dask)
# Computing "works" for some reason, but is very slow (~18s)
# (takes ~100 ms with query-planning disabled)
return_df.compute()
Dask-only reproducer
import pandas as pd
import dask.dataframe as dd

def my_func(group, operand_col, new_col):
    windowed_group = group.expanding(min_periods=1)
    return group.assign(**{new_col: windowed_group[operand_col].mean()})

df = pd.DataFrame({"a": [2, 1, 2, 3], "b": [3, 3, 1, 3]})
ddf = dd.from_pandas(df, npartitions=2)

# starts getting slow around N=10
N = 10
for i in range(N):
    group_column = f"group_{i}"
    operand_column = f"operand_{i}"
    new_column = f"mean_{i}"

    # create and assign temporary columns
    ddf = ddf.assign(**{group_column: 1})
    ddf = ddf.assign(**{operand_column: ddf["a"] + ddf["b"]})
    meta = ddf._meta.assign(**{new_column: 0.0})

    # apply the function
    ddf = ddf.groupby([group_column], dropna=False).apply(
        my_func,
        operand_column,
        new_column,
        meta=meta,
    )

    # drop the temporary columns
    ddf = ddf.drop(columns=[group_column, operand_column]).reset_index(drop=True)

len(ddf.dask)  # >100X slower with query-planning enabled
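For reference, the "with query-planning disabled" timings above assume the legacy implementation is selected up front. A minimal sketch of how to run that comparison, assuming the standard dataframe.query-planning config option (it must be set before dask.dataframe is first imported):
import dask
dask.config.set({"dataframe.query-planning": False})  # must happen before the first dask.dataframe import
import dask.dataframe as dd  # now uses the legacy (non-expr) implementation
# alternatively, set DASK_DATAFRAME__QUERY_PLANNING=False in the environment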
Other Notes:
- Calling ddf._depth() takes about 60 s and returns 100 (so not a terribly complex graph).
- Calling ddf.pprint() also seems to "hang", so it's a bit hard to inspect the expression graph.
Known "Remedies":
As far as I can tell, the graph-materialization hang mostly goes away if Expr.lower_once is cached. For example, everything is considerably faster when I hack in a simple caching pattern:
import functools

class Expr:
    ...

    @functools.cached_property
    def _lower_once_impl(self):
        ...

    def lower_once(self):
        return self._lower_once_impl
cc @fjetter @phofl - Seems like it makes sense to cache lowering behavior. WDYT?
> Calling ddf._depth() takes about 60 s and returns 100 (so not a terribly complex graph).
100 depth is a pretty complex graph since we're talking about expressions here. This is much more than I would naively assume given the "simple" SQL statement above.
> Seems like it makes sense to cache lowering behavior. WDYT?
Without more investigation I'm -1 for introducing such a catch-all cache. Historically, most of these endless runtimes could be traced back to a minor bug or were dealt with by introducing more targeted caching.
FWIW I cannot reproduce the above since there doesn't appear to be a valid dask-sql package for OSX ARM
> 100 depth is a pretty complex graph since we're talking about expressions here. This is much more than I would naively assume given the "simple" SQL statement above.
Yes, if I remember correctly, dask-sql does an excessive amount of column renaming when the SQL query is mapped onto the dask.dataframe API. There is no doubt that the same logic can be expressed in a much simpler expression graph, but I'm assuming it would be a lot of work to change that. Therefore, these are the kinds of expression graphs dask-sql needs to produce for now. @charlesbluca may have thoughts on this.
> Without more investigation I'm -1 for introducing such a catch-all cache.
Right, the "hack" I described above is just meant as a demonstration that caching seems to mitigate whatever the underlying issue is. We are technically caching the "lowered" version of every expression in _instances
anyway, so there are certainly more efficient solutions to avoid "re-lowing" the same expression many times (either targeted or general).
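To make the "targeted" option concrete, here is a minimal sketch (toy code, not the dask-expr API) of memoizing the lowering step by a deterministic expression key, analogous to keying on an expression's name; the class and function names here are illustrative assumptions:
_lower_cache = {}

class ToyExpr:
    # Illustrative stand-in for an expression node with a deterministic name.
    def __init__(self, name, deps=()):
        self.name = name
        self.deps = tuple(deps)

def lower_once(expr):
    # Targeted cache keyed on the deterministic name, so a shared
    # subexpression is lowered once instead of once per dependent.
    if expr.name not in _lower_cache:
        lowered_deps = [lower_once(d) for d in expr.deps]
        _lower_cache[expr.name] = ToyExpr(expr.name + ":lowered", lowered_deps)
    return _lower_cache[expr.name]

# Two parents sharing one child: the child is lowered only once.
child = ToyExpr("child")
parent_a = ToyExpr("a", deps=(child,))
parent_b = ToyExpr("b", deps=(child,))
lower_once(parent_a)
lower_once(parent_b)
print(sorted(_lower_cache))  # ['a', 'b', 'child'] -- 'child' lowered once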
Update: @charlesbluca shared a dask-only reproducer, and I added it to the top-level description. I guess it makes sense that this pattern would create a bloated/repetitive graph.
I submitted #1059 to add basic caching. It is pretty clear that downstream libraries (like dask-sql) may produce "deep" expression graphs, and "diamond-like" branches in a deep graph will essentially multiply the size of the graph in the absence of caching.
For the Python-only reproducer above, the depth of the expression graph is 71, but every time multiple expressions depend on the same expression (e.g. Assign(Resetindex, group_9, 1)), all of those dependents will re-lower that same expression. Each of those lowering paths is then further multiplied by similar patterns deeper in the graph.
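To illustrate the multiplying effect with a toy sketch (not dask code): each "diamond", where two parents depend on the same child, doubles the number of paths a naive recursive traversal takes through that child, so a stack of diamonds grows the traversal exponentially, while memoizing visited nodes keeps it proportional to the number of distinct nodes.
class Node:
    def __init__(self, deps=()):
        self.deps = tuple(deps)

def build_diamonds(depth):
    # Each level adds two parents that both depend on the previous root,
    # plus a new root depending on both parents: a stack of "diamonds".
    node = Node()
    for _ in range(depth):
        shared = node
        node = Node(deps=(Node(deps=(shared,)), Node(deps=(shared,))))
    return node

def count_visits(node):
    # Naive recursive walk (no caching): shared nodes are revisited
    # once per path leading to them.
    return 1 + sum(count_visits(d) for d in node.deps)

def count_distinct(node, seen=None):
    # Same walk with memoization on node identity: each node visited once.
    seen = set() if seen is None else seen
    if id(node) in seen:
        return len(seen)
    seen.add(id(node))
    for d in node.deps:
        count_distinct(d, seen)
    return len(seen)

root = build_diamonds(20)
print(count_visits(root))    # 4_194_301 visits without caching (takes a few seconds)
print(count_distinct(root))  # 61 distinct nodes with caching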