Skip to content

Commit

Permalink
FIX-#2254: handling dict functions at groupby.agg improved (#2267)
Browse files Browse the repository at this point in the history
Signed-off-by: Dmitry Chigarev <dmitry.chigarev@intel.com>
  • Loading branch information
dchigarev authored Oct 22, 2020
1 parent b514d6f commit 64b94f5
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 13 deletions.
43 changes: 31 additions & 12 deletions modin/pandas/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import pandas
import pandas.core.groupby
from pandas.core.dtypes.common import is_list_like
from pandas.core.aggregation import reconstruct_func
import pandas.core.common as com

from modin.error_message import ErrorMessage
Expand Down Expand Up @@ -301,29 +302,40 @@ def aggregate(self, func=None, *args, **kwargs):
# so we throw a different message
raise NotImplementedError("axis other than 0 is not supported")
if isinstance(func, dict) or func is None:
if func is None:
func = {}
else:
if any(i not in self._df.columns for i in func.keys()):
from pandas.core.base import SpecificationError

raise SpecificationError("nested renamer is not supported")
def _reconstruct_func(func, **kwargs):
    """Normalize the aggregation spec and stringify known callables.

    Delegates to pandas' ``reconstruct_func`` and then post-processes the
    returned function mapping: every callable whose ``__name__`` matches an
    attribute of this groupby object is replaced by that name, so it can be
    dispatched as a string aggregation. Any other value (strings, lists,
    callables with an unknown or missing ``__name__``) is kept as is.

    Parameters
    ----------
    func : dict or None
        Aggregation spec as passed to ``aggregate``.
    **kwargs
        Forwarded unchanged to ``reconstruct_func``.

    Returns
    -------
    tuple
        ``(relabeling_required, func, new_columns, order)`` exactly as
        produced by ``reconstruct_func``, except that ``func`` has had the
        callable-to-name substitution applied.
    """
    relabeling_required, func, new_columns, order = reconstruct_func(
        func, **kwargs
    )
    # Attribute names of this groupby object, gathered once instead of
    # calling ``dir(self)`` for every entry of the dict.
    known_names = set(dir(self))

    def _name_if_known(agg):
        # Not every callable carries ``__name__`` (e.g. ``functools.partial``
        # objects or instances defining ``__call__``), so look it up
        # defensively rather than raising AttributeError.
        name = getattr(agg, "__name__", None) if callable(agg) else None
        return name if name in known_names else agg

    # We convert to the string version of the function for simplicity.
    func = {k: _name_if_known(v) for k, v in func.items()}
    return relabeling_required, func, new_columns, order

relabeling_required, func_dict, new_columns, order = _reconstruct_func(
func, **kwargs
)

if any(i not in self._df.columns for i in func_dict.keys()):
from pandas.core.base import SpecificationError

raise SpecificationError("nested renamer is not supported")
if isinstance(self._by, type(self._query_compiler)):
by = list(self._by.columns)
else:
by = self._by
# We convert to the string version of the function for simplicity.
func_dict = {
k: v if not callable(v) or v.__name__ not in dir(self) else v.__name__
for k, v in func.items()
}

subset_cols = list(func_dict.keys()) + (
list(self._by.columns)
if isinstance(self._by, type(self._query_compiler))
and all(c in self._df.columns for c in self._by.columns)
else []
)
return type(self._df)(
result = type(self._df)(
query_compiler=self._df[subset_cols]._query_compiler.groupby_dict_agg(
by=by,
func_dict=func_dict,
Expand All @@ -332,6 +344,13 @@ def aggregate(self, func=None, *args, **kwargs):
drop=self._drop,
)
)

if relabeling_required:
result = result.iloc[:, order]
result.columns = new_columns

return result

if is_list_like(func):
return self._default_to_pandas(
lambda df, *args, **kwargs: df.aggregate(func, *args, **kwargs),
Expand Down
1 change: 0 additions & 1 deletion modin/pandas/test/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -1122,7 +1122,6 @@ def test_groupby_multiindex():
df_equals(modin_df.groupby(by=by).count(), pandas_df.groupby(by=by).count())


@pytest.mark.skip("See Modin issue #2254 for details")
def test_agg_func_None_rename():
pandas_df = pandas.DataFrame(
{
Expand Down

0 comments on commit 64b94f5

Please sign in to comment.