Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX-#2482: improved handling non-str 'by' #2548

Merged
merged 1 commit into from
Dec 18, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions modin/backends/pandas/query_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

from modin.backends.base.query_compiler import BaseQueryCompiler
from modin.error_message import ErrorMessage
from modin.utils import try_cast_to_pandas, wrap_udf_function
from modin.utils import try_cast_to_pandas, wrap_udf_function, hashable
from modin.data_management.functions import (
Function,
FoldFunction,
Expand Down Expand Up @@ -2555,7 +2555,7 @@ def is_reduce_fn(fn, deep_level=0):
else:
if not isinstance(by, list):
by = [by]
internal_by = [o for o in by if isinstance(o, str) and o in self.columns]
internal_by = [o for o in by if hashable(o) and o in self.columns]
internal_qc = (
[self.getitem_column_array(internal_by)] if len(internal_by) else []
)
Expand Down
6 changes: 4 additions & 2 deletions modin/data_management/functions/groupby_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import pandas

from .mapreducefunction import MapReduceFunction
from modin.utils import try_cast_to_pandas
from modin.utils import try_cast_to_pandas, hashable


class GroupbyReduceFunction(MapReduceFunction):
Expand Down Expand Up @@ -113,7 +113,9 @@ def caller(
numeric_only=True,
**kwargs,
):
if not isinstance(by, (type(query_compiler), str)):
if not (isinstance(by, (type(query_compiler)) or hashable(by))) or isinstance(
by, pandas.Grouper
):
by = try_cast_to_pandas(by, squeeze=True)
default_func = (
(lambda grp: grp.agg(map_func))
Expand Down
12 changes: 6 additions & 6 deletions modin/pandas/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ def groupby(

if callable(by):
by = self.index.map(by)
elif isinstance(by, str):
elif hashable(by) and not isinstance(by, pandas.Grouper):
drop = by in self.columns
idx_name = by
if self._query_compiler.has_multiindex(
Expand All @@ -374,7 +374,7 @@ def groupby(
# In this case we pass the string value of the name through to the
# partitions. This is more efficient than broadcasting the values.
pass
else:
elif level is None:
by = self.__getitem__(by)._query_compiler
elif isinstance(by, Series):
drop = by._parent is self
Expand All @@ -384,7 +384,7 @@ def groupby(
# fastpath for multi column groupby
if axis == 0 and all(
(
(isinstance(o, str) and (o in self))
(hashable(o) and (o in self))
or isinstance(o, Series)
or (is_list_like(o) and len(o) == len(self.axes[axis]))
)
Expand All @@ -395,7 +395,7 @@ def groupby(
internal_by, external_by = [], []

for current_by in by:
if isinstance(current_by, str):
if hashable(current_by):
internal_by.append(current_by)
elif isinstance(current_by, Series):
if current_by._parent is self:
Expand All @@ -414,7 +414,7 @@ def groupby(
else:
mismatch = len(by) != len(self.axes[axis])
if mismatch and all(
isinstance(obj, str)
hashable(obj)
and (
obj in self or obj in self._query_compiler.get_index_names(axis)
)
Expand All @@ -424,7 +424,7 @@ def groupby(
# we default to pandas in this case.
pass
elif mismatch and any(
isinstance(obj, str) and obj not in self.columns for obj in by
hashable(obj) and obj not in self.columns for obj in by
):
names = [o.name if isinstance(o, Series) else o for o in by]
raise KeyError(next(x for x in names if x not in self))
Expand Down
20 changes: 16 additions & 4 deletions modin/pandas/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,12 @@
from collections.abc import Iterable

from modin.error_message import ErrorMessage
from modin.utils import _inherit_docstrings, try_cast_to_pandas, wrap_udf_function
from modin.utils import (
_inherit_docstrings,
try_cast_to_pandas,
wrap_udf_function,
hashable,
)
from modin.backends.base.query_compiler import BaseQueryCompiler
from modin.config import IsExperimental
from .series import Series
Expand Down Expand Up @@ -79,7 +84,7 @@ def __init__(
not isinstance(by, type(self._query_compiler))
and axis == 0
and all(
(isinstance(obj, str) and obj in self._query_compiler.columns)
(hashable(obj) and obj in self._query_compiler.columns)
or isinstance(obj, type(self._query_compiler))
or is_list_like(obj)
for obj in self._by
Expand Down Expand Up @@ -324,7 +329,7 @@ def __getitem__(self, key):
if (
self._is_multi_by
and isinstance(self._by, list)
and not all(isinstance(o, str) for o in self._by)
and not all(hashable(o) and o in self._df for o in self._by)
):
raise NotImplementedError(
"Column lookups on GroupBy with arbitrary Series in by"
Expand Down Expand Up @@ -809,7 +814,14 @@ def _index_grouped(self):
# aware.
ErrorMessage.catch_bugs_and_request_email(self._axis == 1)
ErrorMessage.default_to_pandas("Groupby with multiple columns")
if isinstance(by, list) and all(isinstance(o, str) for o in by):
if isinstance(by, list) and all(
hashable(o)
and (
o in self._df
or o in self._df._query_compiler.get_index_names(self._axis)
)
for o in by
):
pandas_df = self._df._query_compiler.getitem_column_array(
by
).to_pandas()
Expand Down
31 changes: 31 additions & 0 deletions modin/pandas/test/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -1494,3 +1494,34 @@ def test_multi_column_groupby_different_partitions(
by, as_index=as_index
)
eval_general(md_grp, pd_grp, func_to_apply)


@pytest.mark.parametrize(
    "by",
    [
        0,
        1.5,
        "str",
        pandas.Timestamp("2020-02-02"),
        [None],
        [0, "str"],
        [None, 0],
        [pandas.Timestamp("2020-02-02"), 1.5],
    ],
)
@pytest.mark.parametrize("as_index", [True, False])
def test_not_str_by(by, as_index):
    """Compare Modin and pandas groupby results for non-string ``by`` keys.

    ``by`` is either a single column label or a list of labels drawn from a
    column index that mixes int, float, str, Timestamp and None; ``as_index``
    is forwarded unchanged to both ``groupby`` calls.
    """
    # Column labels deliberately cover several hashable, non-str types.
    labels = pandas.Index([0, 1.5, "str", pandas.Timestamp("2020-02-02"), None])
    values = {f"col{i}": np.arange(5) for i in range(5)}

    md_df, pd_df = create_test_dfs(values, columns=labels)
    md_grp = md_df.groupby(by, as_index=as_index)
    pd_grp = pd_df.groupby(by, as_index=as_index)

    # The groupby objects themselves must agree before comparing aggregations.
    modin_groupby_equals_pandas(md_grp, pd_grp)

    df_equals(md_grp.sum(), pd_grp.sum())
    df_equals(md_grp.size(), pd_grp.size())
    df_equals(
        md_grp.agg(lambda df: df.mean()),
        pd_grp.agg(lambda df: df.mean()),
    )
    df_equals(md_grp.dtypes, pd_grp.dtypes)
    df_equals(md_grp.first(), pd_grp.first())