From 085a289e73ab15870709c8cfe3dd43e344acb3ca Mon Sep 17 00:00:00 2001 From: Dmitry Chigarev Date: Wed, 16 Dec 2020 17:45:29 +0300 Subject: [PATCH] FIX-#2482: improved handling non-str 'by' Signed-off-by: Dmitry Chigarev --- modin/backends/pandas/query_compiler.py | 4 +-- .../functions/groupby_function.py | 6 ++-- modin/pandas/dataframe.py | 12 +++---- modin/pandas/groupby.py | 20 +++++++++--- modin/pandas/test/test_groupby.py | 31 +++++++++++++++++++ 5 files changed, 59 insertions(+), 14 deletions(-) diff --git a/modin/backends/pandas/query_compiler.py b/modin/backends/pandas/query_compiler.py index 6f62405a735..bce56da3901 100644 --- a/modin/backends/pandas/query_compiler.py +++ b/modin/backends/pandas/query_compiler.py @@ -29,7 +29,7 @@ from modin.backends.base.query_compiler import BaseQueryCompiler from modin.error_message import ErrorMessage -from modin.utils import try_cast_to_pandas, wrap_udf_function +from modin.utils import try_cast_to_pandas, wrap_udf_function, hashable from modin.data_management.functions import ( Function, FoldFunction, @@ -2555,7 +2555,7 @@ def is_reduce_fn(fn, deep_level=0): else: if not isinstance(by, list): by = [by] - internal_by = [o for o in by if isinstance(o, str) and o in self.columns] + internal_by = [o for o in by if hashable(o) and o in self.columns] internal_qc = ( [self.getitem_column_array(internal_by)] if len(internal_by) else [] ) diff --git a/modin/data_management/functions/groupby_function.py b/modin/data_management/functions/groupby_function.py index fef5a8a8ecf..b94bdcec173 100644 --- a/modin/data_management/functions/groupby_function.py +++ b/modin/data_management/functions/groupby_function.py @@ -14,7 +14,7 @@ import pandas from .mapreducefunction import MapReduceFunction -from modin.utils import try_cast_to_pandas +from modin.utils import try_cast_to_pandas, hashable class GroupbyReduceFunction(MapReduceFunction): @@ -113,7 +113,9 @@ def caller( numeric_only=True, **kwargs, ): - if not isinstance(by, (type(query_compiler), str)): + if not (isinstance(by, (type(query_compiler)) or hashable(by))) or isinstance( + by, pandas.Grouper + ): by = try_cast_to_pandas(by, squeeze=True) default_func = ( (lambda grp: grp.agg(map_func)) diff --git a/modin/pandas/dataframe.py b/modin/pandas/dataframe.py index 8ebf85a4ae1..67f5f4c035d 100644 --- a/modin/pandas/dataframe.py +++ b/modin/pandas/dataframe.py @@ -365,7 +365,7 @@ def groupby( if callable(by): by = self.index.map(by) - elif isinstance(by, str): + elif hashable(by) and not isinstance(by, pandas.Grouper): drop = by in self.columns idx_name = by if self._query_compiler.has_multiindex( @@ -374,7 +374,7 @@ def groupby( # In this case we pass the string value of the name through to the # partitions. This is more efficient than broadcasting the values. pass - else: + elif level is None: by = self.__getitem__(by)._query_compiler elif isinstance(by, Series): drop = by._parent is self @@ -384,7 +384,7 @@ def groupby( # fastpath for multi column groupby if axis == 0 and all( ( - (isinstance(o, str) and (o in self)) + (hashable(o) and (o in self)) or isinstance(o, Series) or (is_list_like(o) and len(o) == len(self.axes[axis])) ) @@ -395,7 +395,7 @@ def groupby( internal_by, external_by = [], [] for current_by in by: - if isinstance(current_by, str): + if hashable(current_by): internal_by.append(current_by) elif isinstance(current_by, Series): if current_by._parent is self: @@ -414,7 +414,7 @@ def groupby( else: mismatch = len(by) != len(self.axes[axis]) if mismatch and all( - isinstance(obj, str) + hashable(obj) and ( obj in self or obj in self._query_compiler.get_index_names(axis) ) @@ -424,7 +424,7 @@ def groupby( # we default to pandas in this case. pass elif mismatch and any( - isinstance(obj, str) and obj not in self.columns for obj in by + hashable(obj) and obj not in self.columns for obj in by ): names = [o.name if isinstance(o, Series) else o for o in by] raise KeyError(next(x for x in names if x not in self)) diff --git a/modin/pandas/groupby.py b/modin/pandas/groupby.py index 0328701addb..2abeaa375b8 100644 --- a/modin/pandas/groupby.py +++ b/modin/pandas/groupby.py @@ -32,7 +32,12 @@ from collections.abc import Iterable from modin.error_message import ErrorMessage -from modin.utils import _inherit_docstrings, try_cast_to_pandas, wrap_udf_function +from modin.utils import ( + _inherit_docstrings, + try_cast_to_pandas, + wrap_udf_function, + hashable, +) from modin.backends.base.query_compiler import BaseQueryCompiler from modin.config import IsExperimental from .series import Series @@ -79,7 +84,7 @@ def __init__( not isinstance(by, type(self._query_compiler)) and axis == 0 and all( - (isinstance(obj, str) and obj in self._query_compiler.columns) + (hashable(obj) and obj in self._query_compiler.columns) or isinstance(obj, type(self._query_compiler)) or is_list_like(obj) for obj in self._by @@ -324,7 +329,7 @@ def __getitem__(self, key): if ( self._is_multi_by and isinstance(self._by, list) - and not all(isinstance(o, str) for o in self._by) + and not all(hashable(o) and o in self._df for o in self._by) ): raise NotImplementedError( "Column lookups on GroupBy with arbitrary Series in by" @@ -809,7 +814,14 @@ def _index_grouped(self): # aware. ErrorMessage.catch_bugs_and_request_email(self._axis == 1) ErrorMessage.default_to_pandas("Groupby with multiple columns") - if isinstance(by, list) and all(isinstance(o, str) for o in by): + if isinstance(by, list) and all( + hashable(o) + and ( + o in self._df + or o in self._df._query_compiler.get_index_names(self._axis) + ) + for o in by + ): pandas_df = self._df._query_compiler.getitem_column_array( by ).to_pandas() diff --git a/modin/pandas/test/test_groupby.py b/modin/pandas/test/test_groupby.py index 54c78563357..b116ffadf69 100644 --- a/modin/pandas/test/test_groupby.py +++ b/modin/pandas/test/test_groupby.py @@ -1494,3 +1494,34 @@ def test_multi_column_groupby_different_partitions( by, as_index=as_index ) eval_general(md_grp, pd_grp, func_to_apply) + + +@pytest.mark.parametrize( + "by", + [ + 0, + 1.5, + "str", + pandas.Timestamp("2020-02-02"), + [None], + [0, "str"], + [None, 0], + [pandas.Timestamp("2020-02-02"), 1.5], + ], +) +@pytest.mark.parametrize("as_index", [True, False]) +def test_not_str_by(by, as_index): + data = {f"col{i}": np.arange(5) for i in range(5)} + columns = pandas.Index([0, 1.5, "str", pandas.Timestamp("2020-02-02"), None]) + + md_df, pd_df = create_test_dfs(data, columns=columns) + md_grp, pd_grp = md_df.groupby(by, as_index=as_index), pd_df.groupby( + by, as_index=as_index + ) + + modin_groupby_equals_pandas(md_grp, pd_grp) + df_equals(md_grp.sum(), pd_grp.sum()) + df_equals(md_grp.size(), pd_grp.size()) + df_equals(md_grp.agg(lambda df: df.mean()), pd_grp.agg(lambda df: df.mean())) + df_equals(md_grp.dtypes, pd_grp.dtypes) + df_equals(md_grp.first(), pd_grp.first())