Skip to content

Commit

Permalink
FIX-#2463: Added test with callable functions as aggregate argument (#2503)
Browse files Browse the repository at this point in the history

Signed-off-by: Gregory Shimansky <gregory.shimansky@intel.com>
  • Loading branch information
gshimansky authored Dec 2, 2020
1 parent 372422b commit 299ba18
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 4 deletions.
15 changes: 14 additions & 1 deletion modin/pandas/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from pandas.core.dtypes.common import is_list_like
from pandas.core.aggregation import reconstruct_func
import pandas.core.common as com
from types import BuiltinFunctionType

from modin.error_message import ErrorMessage
from modin.utils import _inherit_docstrings, try_cast_to_pandas
Expand Down Expand Up @@ -358,6 +359,13 @@ def aggregate(self, func=None, *args, **kwargs):
# so we throw a different message
raise NotImplementedError("axis other than 0 is not supported")

if (
callable(func)
and isinstance(func, BuiltinFunctionType)
and func.__name__ in dir(self)
):
func = func.__name__

relabeling_required = False
if isinstance(func, dict) or func is None:

Expand Down Expand Up @@ -389,6 +397,12 @@ def _reconstruct_func(func, **kwargs):
*args,
**kwargs,
)
elif callable(func):
return self._apply_agg_function(
lambda grp, *args, **kwargs: grp.aggregate(func, *args, **kwargs),
*args,
**kwargs,
)
elif isinstance(func, str):
# Using "getattr" here masks possible AttributeError which we throw
# in __getattr__, so we should call __getattr__ directly instead.
Expand All @@ -398,7 +412,6 @@ def _reconstruct_func(func, **kwargs):

result = self._apply_agg_function(
func,
drop=self._as_index,
*args,
**kwargs,
)
Expand Down
23 changes: 20 additions & 3 deletions modin/pandas/test/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,10 +151,16 @@ def test_mixed_dtypes_groupby(as_index):
eval_skew(modin_groupby, pandas_groupby)

agg_functions = [
lambda df: df.sum(),
"min",
min,
"max",
max,
sum,
{"col2": "sum"},
{"col2": sum},
{"col2": "max", "col4": "sum", "col5": "min"},
{"col2": max, "col4": sum, "col5": "min"},
]
for func in agg_functions:
eval_agg(modin_groupby, pandas_groupby, func)
Expand Down Expand Up @@ -353,7 +359,7 @@ def maybe_get_columns(df, by):
eval_var(modin_groupby, pandas_groupby)
eval_skew(modin_groupby, pandas_groupby)

agg_functions = ["min", "max"]
agg_functions = [lambda df: df.sum(), "min", "max", min, sum]
for func in agg_functions:
eval_agg(modin_groupby, pandas_groupby, func)
eval_aggregate(modin_groupby, pandas_groupby, func)
Expand Down Expand Up @@ -486,8 +492,11 @@ def test_single_group_row_groupby():
eval_std(modin_groupby, pandas_groupby)

agg_functions = [
lambda df: df.sum(),
"min",
"max",
max,
sum,
{"col2": "sum"},
{"col2": "max", "col4": "sum", "col5": "min"},
]
Expand Down Expand Up @@ -606,7 +615,15 @@ def test_large_row_groupby(is_by_category):
# eval_prod(modin_groupby, pandas_groupby) causes overflows
eval_std(modin_groupby, pandas_groupby)

agg_functions = ["min", "max", {"A": "sum"}, {"A": "max", "B": "sum", "C": "min"}]
agg_functions = [
lambda df: df.sum(),
"min",
"max",
min,
sum,
{"A": "sum"},
{"A": "max", "B": "sum", "C": "min"},
]
for func in agg_functions:
eval_agg(modin_groupby, pandas_groupby, func)
eval_aggregate(modin_groupby, pandas_groupby, func)
Expand Down Expand Up @@ -863,7 +880,7 @@ def test_series_groupby(by, as_index_series_or_dataframe):
eval_var(modin_groupby, pandas_groupby)
eval_skew(modin_groupby, pandas_groupby)

agg_functions = ["min", "max"]
agg_functions = [lambda df: df.sum(), "min", "max", max, sum]
for func in agg_functions:
eval_agg(modin_groupby, pandas_groupby, func)
eval_aggregate(modin_groupby, pandas_groupby, func)
Expand Down

0 comments on commit 299ba18

Please sign in to comment.