DataFrame subclass lost in `groupby.agg` with `split_out` set.
TomAugspurger opened this issue · 1 comment
Describe the issue:
As part of geopandas/dask-geopandas#285, we found that dask-expr loses the type of a pandas DataFrame subclass in `groupby.agg` if (and only if?) the `split_out` parameter is used.
Minimal Complete Verifiable Example:
Given this file:
```python
# file: test.py
import dask
import dask.dataframe.backends
import pandas as pd
import dask_expr as dx
from dask.dataframe.dispatch import make_meta_dispatch, meta_nonempty
from dask.dataframe.core import get_parallel_type

dask.config.set(scheduler="single-threaded")


class MySeries(pd.Series):
    @property
    def _constructor(self):
        return MySeries

    @property
    def _constructor_expanddim(self):
        return MyDataFrame


class MyDataFrame(pd.DataFrame):
    @property
    def _constructor(self):
        return MyDataFrame

    @property
    def _constructor_sliced(self):
        return MySeries


class MyIndex(pd.Index): ...


class MyDaskSeries(dx.Series):
    _partition_type = MySeries


class MyDaskDataFrame(dx.DataFrame):
    _partition_type = MyDataFrame


class MyDaskIndex(dx.Index):
    _partition_type = MyIndex


# Unclear if any of get_parallel_type and make_meta_dispatch are needed.
# Reproduces with or without them.
@get_parallel_type.register(MyDataFrame)
def get_parallel_type_dataframe(df):
    return MyDaskDataFrame


@get_parallel_type.register(MySeries)
def get_parallel_type_series(s):
    return MyDaskSeries


@get_parallel_type.register(MyIndex)
def get_parallel_type_index(ind):
    return MyDaskIndex


@make_meta_dispatch.register(MyDataFrame)
def make_meta_dataframe(df, index=None):
    return df.head(0)


@make_meta_dispatch.register(MySeries)
def make_meta_series(s, index=None):
    return s.head(0)


@make_meta_dispatch.register(MyIndex)
def make_meta_index(ind, index=None):
    return ind[:0]


@meta_nonempty.register(MyDataFrame)
def make_meta_nonempty_dataframe(x):
    return MyDataFrame(dask.dataframe.backends.meta_nonempty_dataframe(x))


df = dx.from_dict(
    {"a": [1, 1, 2, 2], "b": [1, 2, 3, 4]}, npartitions=4, constructor=MyDataFrame
)

a = df.groupby("a").agg("first")
b = df.groupby("a").agg("first", split_out=2)

print("split-out=None", type(a.compute()))
print("split-out=2   ", type(b.compute()))
```
Running that produces:

```
$ python test.py
split-out=None <class '__main__.MyDataFrame'>
split-out=2    <class 'pandas.core.frame.DataFrame'>
```
I would expect the type there to be `__main__.MyDataFrame` regardless of `split_out`.
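For comparison, plain pandas preserves the subclass through `groupby.agg`, since results are generally rebuilt via `_constructor`. A minimal check, reusing the `MyDataFrame` class from the script above:

```python
# Plain pandas keeps the subclass through groupby.agg; this is the
# behavior I would expect the dask collection to mirror.
pdf = MyDataFrame({"a": [1, 1, 2, 2], "b": [1, 2, 3, 4]})
print(type(pdf.groupby("a").agg("first")))  # <class '__main__.MyDataFrame'>
```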
Anything else we need to know?:
Environment:

- dask 2024.4.1
- dask-expr 1.0.11
Edit: I made one addition to the script: registering a `@meta_nonempty.register(MyDataFrame)` handler. I noticed that in `DecomposableGroupbyAggregation.combine` and `DecomposableGroupbyAggregation.aggregate` the types were regular pandas DataFrames instead of the subclass. Registering that `meta_nonempty` does keep it as `MyDataFrame` initially. I put some print statements in those methods to print `type(inputs[0])` and `type(_concat(inputs))` and get:
```
combine   <class '__main__.MyDataFrame'> <class '__main__.MyDataFrame'>
aggregate <class '__main__.MyDataFrame'> <class '__main__.MyDataFrame'>
aggregate <class '__main__.MyDataFrame'> <class '__main__.MyDataFrame'>
aggregate <class 'pandas.core.frame.DataFrame'> <class 'pandas.core.frame.DataFrame'>
aggregate <class 'pandas.core.frame.DataFrame'> <class 'pandas.core.frame.DataFrame'>
split-out=2 <class 'pandas.core.frame.DataFrame'>
```
So initially we're OK, but by the time we do the final `aggregate` we've lost the subclass.
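One more data point (a quick check of my own, reusing the subclasses from the script): pandas' `concat` itself preserves the subclass when `_constructor` is defined, so the loss above must happen before the inputs reach `_concat`, consistent with the plain-DataFrame `inputs[0]` in the last three calls:

```python
# pd.concat builds its result via the first object's _constructor,
# so concatenating subclass frames keeps the subclass.
left = MyDataFrame({"a": [1], "b": [2]})
right = MyDataFrame({"a": [3], "b": [4]})
print(type(pd.concat([left, right])))  # <class '__main__.MyDataFrame'>
```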
This is a shuffle issue (and it's also present in the current implementation, if I'm not mistaken?). `df.shuffle("a")` will lose your type, and that's what we do under the hood if `split_out != 1`. `shuffle_method="tasks"` keeps it; disk and p2p lose it.

I can patch that so that your resulting DataFrame will have the correct type, but I don't know if we can guarantee that we keep whatever you might add to the subclass through shuffles without you overriding the shuffle-specific methods.
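A minimal sketch of the check described above, reusing `df` from the reproducer (it assumes dask-expr's `shuffle` accepts the same `shuffle_method` keyword as `dask.dataframe`):

```python
# Per the comment above: the task-based shuffle keeps the subclass,
# while the disk (and p2p) shuffle paths fall back to plain pandas.
shuffled_tasks = df.shuffle("a", shuffle_method="tasks")
print(type(shuffled_tasks.compute()))  # expected: <class '__main__.MyDataFrame'>

shuffled_disk = df.shuffle("a", shuffle_method="disk")
print(type(shuffled_disk.compute()))  # expected: <class 'pandas.core.frame.DataFrame'>
```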