dask/dask-expr

DataFrame subclass lost in `groupby.agg` with `split_out` set.

TomAugspurger opened this issue · 1 comment

Describe the issue:

As part of geopandas/dask-geopandas#285, we found that dask-expr will lose the type of a pandas DataFrame subclass in groupby.agg if (and only if?) the split_out parameter is used.

Minimal Complete Verifiable Example:

Given this file:

# file: test.py
import dask
import dask.dataframe.backends
import pandas as pd

import dask_expr as dx
from dask.dataframe.dispatch import make_meta_dispatch, meta_nonempty
from dask.dataframe.core import get_parallel_type


dask.config.set(scheduler="single-threaded")

class MySeries(pd.Series):
    @property
    def _constructor(self):
        return MySeries

    @property
    def _constructor_expanddim(self):
        return MyDataFrame


class MyDataFrame(pd.DataFrame):
    @property
    def _constructor(self):
        return MyDataFrame

    @property
    def _constructor_sliced(self):
        return MySeries


class MyIndex(pd.Index): ...


class MyDaskSeries(dx.Series):
    _partition_type = MySeries


class MyDaskDataFrame(dx.DataFrame):
    _partition_type = MyDataFrame


class MyDaskIndex(dx.Index):
    _partition_type = MyIndex


# Unclear if any of get_parallel_type and make_meta_dispatch are needed.
# Reproduces with or without them.
@get_parallel_type.register(MyDataFrame)
def get_parallel_type_dataframe(df):
    # Map the pandas subclass to its dask collection type, like the Series/Index cases.
    return MyDaskDataFrame


@get_parallel_type.register(MySeries)
def get_parallel_type_series(s):
    return MyDaskSeries


@get_parallel_type.register(MyIndex)
def get_parallel_type_index(ind):
    return MyDaskIndex


@make_meta_dispatch.register(MyDataFrame)
def make_meta_dataframe(df, index=None):
    return df.head(0)


@make_meta_dispatch.register(MySeries)
def make_meta_series(s, index=None):
    return s.head(0)


@make_meta_dispatch.register(MyIndex)
def make_meta_index(ind, index=None):
    return ind[:0]


@meta_nonempty.register(MyDataFrame)
def make_meta_nonempty_dataframe(x):
    return MyDataFrame(dask.dataframe.backends.meta_nonempty_dataframe(x))


df = dx.from_dict(
    {"a": [1, 1, 2, 2], "b": [1, 2, 3, 4]}, npartitions=4, constructor=MyDataFrame
)
a = df.groupby("a").agg("first")
b = df.groupby("a").agg("first", split_out=2)

print("split-out=None", type(a.compute()))
print("split-out=2   ", type(b.compute()))

Running that produces:

$ python test.py
split-out=None <class '__main__.MyDataFrame'>
split-out=2    <class 'pandas.core.frame.DataFrame'>

I would expect the type there to be __main__.MyDataFrame regardless of split_out.
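For comparison, plain pandas keeps the subclass through the same groupby aggregation, which is the baseline this expectation rests on. A minimal sketch that reuses the MySeries/MyDataFrame definitions from the script and does not involve dask at all:

```python
import pandas as pd


class MySeries(pd.Series):
    @property
    def _constructor(self):
        return MySeries

    @property
    def _constructor_expanddim(self):
        return MyDataFrame


class MyDataFrame(pd.DataFrame):
    @property
    def _constructor(self):
        return MyDataFrame

    @property
    def _constructor_sliced(self):
        return MySeries


pdf = MyDataFrame({"a": [1, 1, 2, 2], "b": [1, 2, 3, 4]})
# Same aggregation as the dask example, run eagerly in pandas.
result = pdf.groupby("a").agg("first")
print(type(result).__name__)
```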

Anything else we need to know?:

Environment:

dask               2024.4.1
dask-expr          1.0.11

Edit: I made one addition to the script: a @meta_nonempty.register(MyDataFrame) implementation. Before adding it, I noticed that the inputs to DecomposableGroupbyAggregation.combine and DecomposableGroupbyAggregation.aggregate were regular pandas DataFrames instead of the subclass.
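Why that registration matters can be sketched with functools.singledispatch as an analogy (assumption: dask's dispatch resolves implementations by walking the class MRO in a similar way). The nonempty function below is hypothetical, not dask's meta_nonempty: with only a pd.DataFrame implementation registered, a MyDataFrame argument falls back to it and comes back as a plain DataFrame; registering a subclass-specific implementation changes that.

```python
from functools import singledispatch

import pandas as pd


class MyDataFrame(pd.DataFrame):
    @property
    def _constructor(self):
        return MyDataFrame


# Hypothetical stand-in for a meta_nonempty-style dispatch.
@singledispatch
def nonempty(x):
    raise TypeError(f"no implementation for {type(x).__name__}")


@nonempty.register(pd.DataFrame)
def _nonempty_dataframe(df):
    # Builds fresh fake rows with the base constructor; the subclass is lost.
    return pd.DataFrame({col: [1, 2] for col in df.columns})


meta = MyDataFrame({"a": pd.Series([], dtype="int64")})
before = type(nonempty(meta))  # resolved via the pd.DataFrame implementation


@nonempty.register(MyDataFrame)
def _nonempty_mydataframe(df):
    # Re-wrap the generic result in the subclass, as the edited script does.
    return MyDataFrame(_nonempty_dataframe(df))


after = type(nonempty(meta))  # now resolved via the MyDataFrame implementation
print(before.__name__, after.__name__)
```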

Registering that meta_nonempty does keep it as MyDataFrame initially. I added print statements to those methods to show type(inputs[0]) and type(_concat(inputs)), and got:

combine <class '__main__.MyDataFrame'> <class '__main__.MyDataFrame'>
aggregate <class '__main__.MyDataFrame'> <class '__main__.MyDataFrame'>
aggregate <class '__main__.MyDataFrame'> <class '__main__.MyDataFrame'>
aggregate <class 'pandas.core.frame.DataFrame'> <class 'pandas.core.frame.DataFrame'>
aggregate <class 'pandas.core.frame.DataFrame'> <class 'pandas.core.frame.DataFrame'>
split-out=2    <class 'pandas.core.frame.DataFrame'>

So initially we're OK, but by the time we do the final aggregate we've lost the subclass.
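The two printed types match on every line, which is consistent with concatenation simply following its inputs. Assuming _concat ends in pd.concat for pandas-backed partitions, this plain-pandas sketch shows that the result's type follows the input pd.concat samples (here, the first one), so once a plain DataFrame reaches aggregate, concatenation cannot bring the subclass back:

```python
import pandas as pd


class MyDataFrame(pd.DataFrame):
    @property
    def _constructor(self):
        return MyDataFrame


sub1 = MyDataFrame({"a": [1], "b": [10]})
sub2 = MyDataFrame({"a": [2], "b": [20]})
plain = pd.DataFrame({"a": [3], "b": [30]})

# All inputs are the subclass: the result keeps it.
all_sub = pd.concat([sub1, sub2])
# A plain DataFrame comes first: the result is a plain DataFrame.
plain_first = pd.concat([plain, sub1])
print(type(all_sub).__name__, type(plain_first).__name__)
```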

This is a shuffle issue (also present in the current implementation, if I am not mistaken).

df.shuffle("a") will lose your type; that is what we do under the hood if split_out != 1. shuffle_method="tasks" keeps it, while "disk" and "p2p" lose it.
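The difference between shuffle methods can be illustrated with two hypothetical functions, neither of which is dask's implementation. split_and_concat stays within pandas operations on the original objects, like a task-based shuffle, so the subclass survives. roundtrip_via_arrays stands in for any path that serializes partitions to a generic format and rebuilds them with the base pd.DataFrame constructor; that this is why "disk" and "p2p" lose the type is an assumption.

```python
import pandas as pd


class MyDataFrame(pd.DataFrame):
    @property
    def _constructor(self):
        return MyDataFrame


def split_and_concat(partitions, on, npartitions):
    """Hypothetical task-style shuffle: split with pandas indexing, then concat."""
    buckets = [[] for _ in range(npartitions)]
    for part in partitions:
        # Assign each row to an output partition by hashing the key column.
        target = pd.util.hash_pandas_object(part[on], index=False) % npartitions
        for i in range(npartitions):
            # Boolean row selection goes through _constructor, keeping the subclass.
            buckets[i].append(part[(target == i).to_numpy()])
    return [pd.concat(bucket) for bucket in buckets]


def roundtrip_via_arrays(partitions):
    """Hypothetical stand-in for serializing partitions and rebuilding them."""
    rebuilt = []
    for part in partitions:
        columns = {name: part[name].to_numpy() for name in part.columns}
        # The base constructor knows nothing about MyDataFrame.
        rebuilt.append(pd.DataFrame(columns, index=part.index.to_numpy()))
    return rebuilt


parts = [
    MyDataFrame({"a": [1, 2], "b": [1, 3]}),
    MyDataFrame({"a": [1, 2], "b": [2, 4]}),
]
task_like = split_and_concat(parts, "a", 2)
serialized_like = roundtrip_via_arrays(parts)
print({type(p).__name__ for p in task_like}, {type(p).__name__ for p in serialized_like})
```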

I can patch that so that your resulting DataFrame has the correct type, but I don't know if we can guarantee that whatever you add to the subclass survives a shuffle unless you override the shuffle-specific methods.
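A minimal sketch of the kind of patch described, assuming the expected type is available from the collection's meta. restore_type is a hypothetical helper, not dask-expr's fix, and the crs attribute declared in _metadata is likewise hypothetical. Re-wrapping restores the class, but per-instance state does not survive the trip through base pandas:

```python
import pandas as pd


class MyDataFrame(pd.DataFrame):
    # Hypothetical extra state; pandas propagates _metadata names via __finalize__,
    # but only when the subclass is present throughout.
    _metadata = ["crs"]

    @property
    def _constructor(self):
        return MyDataFrame


def restore_type(result, meta):
    """Hypothetical post-shuffle fix-up: re-wrap a plain result in meta's type."""
    if type(result) is type(meta):
        return result
    # Recovers the class only; attributes such as crs are not copied over.
    return type(meta)(result)


meta = MyDataFrame({"a": pd.Series([], dtype="int64")})
meta.crs = "EPSG:4326"
plain_result = pd.DataFrame({"a": [1, 2]})  # what a type-losing shuffle hands back
fixed = restore_type(plain_result, meta)
print(type(fixed).__name__, getattr(fixed, "crs", None))
```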