dask/dask-expr

Some Dask-SQL tests are extremely slow with dask-expr

rjzamora opened this issue · 5 comments

@charlesbluca shared a Dask-SQL test/snippet that seems to "hang" when query planning is enabled in dask.dataframe. It turns out the operation does eventually finish, but graph materialization is extremely slow in dask-expr for this particular expression graph.

It is certainly possible that Dask-SQL is producing an expression graph that is more complicated than necessary. However, it is definitely not complicated enough to warrant such an extreme slowdown.

Reproducer:

Original dask-sql reproducer
# Environment: mamba create -n dask-sql-hang-repro -c dask/label/dev dask-sql=2024.3.1

import pandas as pd
from dask_sql import Context

user_table_1 = pd.DataFrame({"user_id": [2, 1, 2, 3], "b": [3, 3, 1, 3]})

c = Context()

return_df = c.sql(
    """
SELECT
    user_id,
    b,
    ROW_NUMBER() OVER (PARTITION BY user_id ORDER BY b) AS "O1",
    FIRST_VALUE(user_id*10 - b) OVER (PARTITION BY user_id ORDER BY b) AS "O2",
    LAST_VALUE(user_id*10 - b) OVER (PARTITION BY user_id ORDER BY b ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "O4",
    SUM(user_id) OVER (PARTITION BY user_id ORDER BY b) AS "O5",
    AVG(user_id) OVER (PARTITION BY user_id ORDER BY b) AS "O6",
    COUNT(*) OVER (PARTITION BY user_id ORDER BY b) AS "O7",
    COUNT(b) OVER (PARTITION BY user_id ORDER BY b) AS "O7b",
    MAX(b) OVER (PARTITION BY user_id ORDER BY b) AS "O8",
    MIN(b) OVER (PARTITION BY user_id ORDER BY b) AS "O9"
FROM user_table_1
""", dataframes={"user_table_1": user_table_1}
)

# Materializing the graph seems to "hang"
# (takes ~1 ms with query-planning disabled)
len(return_df.dask)

# Computing "works" for some reason, but is very slow (~18s)
# (takes ~100 ms with query-planning disabled)
return_df.compute()

Dask-only reproducer:

import pandas as pd
import dask.dataframe as dd

def my_func(group, operand_col, new_col):
    windowed_group = group.expanding(min_periods=1)
    return group.assign(**{new_col: windowed_group[operand_col].mean()})

df = pd.DataFrame({"a": [2, 1, 2, 3], "b": [3, 3, 1, 3]})
ddf = dd.from_pandas(df, npartitions=2)

# starts getting slow around N=10
N = 10

for i in range(N):
    group_column = f"group_{i}"
    operand_column = f"operand_{i}"
    new_column = f"mean_{i}"

    # create and assign temporary columns
    ddf = ddf.assign(**{group_column: 1})
    ddf = ddf.assign(**{operand_column: ddf["a"] + ddf["b"]})

    meta = ddf._meta.assign(**{new_column: 0.0})

    # apply the function
    ddf = ddf.groupby([group_column], dropna=False).apply(
        my_func,
        operand_column,
        new_column,
        meta=meta,
    )

    # drop the temporary columns
    ddf = ddf.drop(columns=[group_column, operand_column]).reset_index(drop=True)
    
len(ddf.dask)  # >100X slower with query-planning enabled

Other Notes:

  • For the original dask-sql reproducer, calling ddf._depth() takes about 60 s and returns 100 (so not a terribly complex graph).
  • Calling ddf.pprint() also seems to "hang", so it's a bit hard to inspect the expression graph.

Known "Remedies":

As far as I can tell, the graph-materialization hang mostly goes away if Expr.lower_once is cached. For example, everything is considerably faster when I hack in a simple caching pattern:

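import functools

class Expr:
    ...

    # Body of the original lower_once, computed on first access and then
    # stored on the instance by cached_property
    @functools.cached_property
    def _lower_once_impl(self):
        ...

    def lower_once(self):
        # Return the cached lowering instead of recomputing it
        return self._lower_once_impl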

cc @fjetter @phofl - Seems like it makes sense to cache lowering behavior. WDYT?
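A self-contained toy version of the caching hack above may help show what the cache buys. `ToyExpr` is a hypothetical stand-in, not dask-expr's `Expr`, and its "lowering" only renames nodes; the point is just that `functools.cached_property` makes the lowering work run at most once per instance, so a shared operand (`base` below) is lowered once even though two parents depend on it.

```python
import functools


class ToyExpr:
    """Hypothetical expression node; not the real dask-expr Expr class."""

    lower_calls = 0  # counts how many times the lowering work actually runs

    def __init__(self, name, *operands):
        self.name = name
        self.operands = operands

    @functools.cached_property
    def _lower_once_impl(self):
        # The "expensive" part: runs on first access, then is cached on the instance.
        ToyExpr.lower_calls += 1
        lowered_operands = tuple(op.lower_once() for op in self.operands)
        return ToyExpr(self.name + "-lowered", *lowered_operands)

    def lower_once(self):
        return self._lower_once_impl


# A small "diamond": both `left` and `right` depend on `base`.
base = ToyExpr("read")
left = ToyExpr("assign", base)
right = ToyExpr("project", base)
top = ToyExpr("merge", left, right)

top.lower_once()
top.lower_once()  # second call is served from the cache
print(ToyExpr.lower_calls)  # 4: one per distinct node
```

Without the cache, a single `top.lower_once()` would lower `base` twice (five lowerings in total); in a deep graph with many stacked diamonds, that duplication compounds.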

> and returns 100 (so not a terribly complex graph).

100 depth is a pretty complex graph since we're talking about expressions here. This is much more than I would naively assume given the "simple" SQL statement above.

> Seems like it makes sense to cache lowering behavior. WDYT?

Without more investigation I'm -1 for introducing such a catch-all cache. Historically, most of these endless runtimes could be traced back to a minor bug or were dealt with by introducing more targeted caching.

FWIW, I cannot reproduce the above, since there doesn't appear to be a valid dask-sql package for OSX ARM.

> 100 depth is a pretty complex graph since we're talking about expressions here. This is much more than I would naively assume given the "simple" SQL statement above.

Yes. If I remember correctly, dask-sql does an excessive amount of column renaming when the SQL query is mapped onto the dask/dataframe API. There is no doubt that the same logic could be expressed in a much simpler expression graph, but I'm assuming it would be a lot of work to change that. Therefore, these are the kinds of expression graphs dask-sql needs to produce for now. @charlesbluca may have thoughts on this.

> Without more investigation I'm -1 for introducing such a catch-all cache.

Right, the "hack" I described above is just meant to demonstrate that caching seems to mitigate whatever the underlying issue is. We are technically caching the "lowered" version of every expression in _instances anyway, so there are certainly more efficient ways (either targeted or general) to avoid "re-lowering" the same expression many times.
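One "targeted" alternative to a per-instance property is a cache keyed by a deterministic expression name and scoped to a single lowering pass: structurally identical expressions then share one result, and the cache can be thrown away afterwards. The sketch below assumes each node can compute such a name once, at construction time. `Node`, `lower`, and the SHA-1 naming scheme are hypothetical and illustrative only; this is not necessarily how #1059 or dask-expr does it.

```python
import hashlib


class Node:
    """Hypothetical immutable expression node with a deterministic name."""

    def __init__(self, op, *operands):
        self.op = op
        self.operands = operands
        # Computed once here, from the operator and the operand names, so
        # structurally identical nodes end up with identical names.
        token = "|".join([op] + [o.name for o in operands])
        self.name = hashlib.sha1(token.encode()).hexdigest()[:12]


def lower(node, cache):
    """Lower `node`, reusing any result already stored under its name."""
    if node.name not in cache:
        lowered_operands = tuple(lower(o, cache) for o in node.operands)
        cache[node.name] = Node(node.op + "'", *lowered_operands)
    return cache[node.name]


# Two separately built, structurally identical subtrees.
a1 = Node("assign", Node("read"))
a2 = Node("assign", Node("read"))
root = Node("concat", a1, a2)

cache = {}  # lives only for this lowering pass
lowered = lower(root, cache)
print(len(cache))                                  # 3: read, assign, concat
print(lowered.operands[0] is lowered.operands[1])  # True
```

Keying on the name rather than on object identity is the design choice that lets `a1` and `a2` share a lowering, which a per-instance `cached_property` cannot do.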

Update: @charlesbluca shared a dask-only reproducer, and I added it to the top-level description. I guess it makes sense that this pattern would create a bloated/repetitive graph.

I submitted #1059 to add basic caching. It is pretty clear that downstream libraries (like dask-sql) may produce "deep" expression graphs, and "diamond-like" branches in such a deep graph will essentially multiply the size of the graph in the absence of caching.

For the Python-only reproducer above, the depth of the expression graph is 71, but every time multiple expressions depend on the same expression (e.g. Assign(Resetindex, group 9, 1)), each of those dependents re-lowers that same expression. Each of those lowering paths is then further multiplied by similar patterns deeper in the graph.
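The multiplication described above can be reproduced with a toy DAG of `n` stacked diamonds (a hypothetical shape, not the actual dask-expr graph). A traversal that re-lowers a shared operand once per dependent visits 2^(n+2) - 3 nodes, while one that lowers each node once visits only the 3n + 1 distinct nodes, even though the longest path is just 2n + 1 nodes.

```python
def build_diamond_chain(n):
    """Return ({node: operands}, root) for a DAG of n stacked diamonds."""
    graph = {"base": ()}
    prev = "base"
    for i in range(n):
        left, right, join = f"left{i}", f"right{i}", f"join{i}"
        graph[left] = (prev,)      # both branches depend on the same node...
        graph[right] = (prev,)
        graph[join] = (left, right)  # ...and are joined again
        prev = join
    return graph, prev


def count_naive_visits(graph, node):
    """Visits made when every dependent re-lowers its operands from scratch."""
    return 1 + sum(count_naive_visits(graph, op) for op in graph[node])


def count_cached_visits(graph, root):
    """Visits made when each node is lowered once and the result reused."""
    seen = set()

    def visit(node):
        if node in seen:
            return 0
        seen.add(node)
        return 1 + sum(visit(op) for op in graph[node])

    return visit(root)


for n in (5, 10, 15):
    graph, root = build_diamond_chain(n)
    print(n, count_naive_visits(graph, root), count_cached_visits(graph, root))
# 5 125 16
# 10 4093 31
# 15 131069 46
```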

[Figure: full expression graph]