diff --git a/modin/pandas/groupby.py b/modin/pandas/groupby.py index 4d2708d0853..9a04f8f7636 100644 --- a/modin/pandas/groupby.py +++ b/modin/pandas/groupby.py @@ -92,8 +92,6 @@ def __init__( self._squeeze = squeeze self._kwargs.update(kwargs) - _index_grouped_cache = None - def __getattr__(self, key): """ Alter regular attribute access, looks up the name in the columns. @@ -164,9 +162,17 @@ def __bytes__(self): def tshift(self): return self._default_to_pandas(lambda df: df.tshift) + _groups_cache = no_default + + # TODO: since python 3.9: + # @cached_property @property def groups(self): - return self._index_grouped + if self._groups_cache is not no_default: + return self._groups_cache + + self._groups_cache = self._compute_index_grouped(numerical=False) + return self._groups_cache def min(self, **kwargs): return self._wrap_aggregation( @@ -256,9 +262,17 @@ def cumsum(self, axis=0, *args, **kwargs): lambda df: df.cumsum(axis, *args, **kwargs) ) + _indices_cache = no_default + + # TODO: since python 3.9: + # @cached_property @property def indices(self): - return self._index_grouped + if self._indices_cache is not no_default: + return self._indices_cache + + self._indices_cache = self._compute_index_grouped(numerical=True) + return self._indices_cache def pct_change(self): return self._default_to_pandas(lambda df: df.pct_change()) @@ -554,7 +568,7 @@ def get_group(self, name, obj=None): return self._default_to_pandas(lambda df: df.get_group(name, obj=obj)) def __len__(self): - return len(self._index_grouped) + return len(self.indices) def all(self, **kwargs): return self._wrap_aggregation( @@ -787,14 +801,15 @@ def _iter(self): """ from .dataframe import DataFrame - group_ids = self._index_grouped.keys() + indices = self.indices + group_ids = indices.keys() if self._axis == 0: return ( ( k, DataFrame( query_compiler=self._query_compiler.getitem_row_array( - self._index.get_indexer_for(self._index_grouped[k].unique()) + indices[k] ) ), ) @@ -806,89 +821,94 @@ def _iter(self): k, DataFrame( query_compiler=self._query_compiler.getitem_column_array( - self._index_grouped[k].unique() + indices[k], numeric=True ) ), ) for k in (sorted(group_ids) if self._sort else group_ids) ) - @property - def _index_grouped(self): + def _compute_index_grouped(self, numerical=False): """ Construct an index of group IDs. + Parameters + ---------- + numerical : bool, default: False + Whether a group indices should be positional (True) or label-based (False). + Returns ------- dict - A dict of {group name -> group labels} values. + A dict of {group name -> group indices} values. See Also -------- pandas.core.groupby.GroupBy.groups """ - if self._index_grouped_cache is None: - # Splitting level-by and column-by since we serialize them in a different ways - by = None - level = [] - if self._level is not None: - level = self._level - if not isinstance(level, list): - level = [level] - elif isinstance(self._by, list): - by = [] - for o in self._by: - if hashable(o) and o in self._query_compiler.get_index_names( - self._axis - ): - level.append(o) - else: - by.append(o) - else: - by = self._by + # Splitting level-by and column-by since we serialize them in a different ways + by = None + level = [] + if self._level is not None: + level = self._level + if not isinstance(level, list): + level = [level] + elif isinstance(self._by, list): + by = [] + for o in self._by: + if hashable(o) and o in self._query_compiler.get_index_names( + self._axis + ): + level.append(o) + else: + by.append(o) + else: + by = self._by - is_multi_by = self._is_multi_by or (by is not None and len(level) > 0) + is_multi_by = self._is_multi_by or (by is not None and len(level) > 0) - if hasattr(self._by, "columns") and is_multi_by: - by = list(self._by.columns) + if hasattr(self._by, "columns") and is_multi_by: + by = list(self._by.columns) - if is_multi_by: - # Because we are doing a collect (to_pandas) here and then groupby, we - # end up using pandas implementation. Add the warning so the user is - # aware. - ErrorMessage.catch_bugs_and_request_email(self._axis == 1) - ErrorMessage.default_to_pandas("Groupby with multiple columns") - if isinstance(by, list) and all( - is_label(self._df, o, self._axis) for o in by - ): - pandas_df = self._df._query_compiler.getitem_column_array( - by - ).to_pandas() - else: - by = try_cast_to_pandas(by, squeeze=True) - pandas_df = self._df._to_pandas() - by = wrap_into_list(by, level) - self._index_grouped_cache = pandas_df.groupby(by=by).groups + if is_multi_by: + # Because we are doing a collect (to_pandas) here and then groupby, we + # end up using pandas implementation. Add the warning so the user is + # aware. + ErrorMessage.catch_bugs_and_request_email(self._axis == 1) + ErrorMessage.default_to_pandas("Groupby with multiple columns") + if isinstance(by, list) and all( + is_label(self._df, o, self._axis) for o in by + ): + pandas_df = self._df._query_compiler.getitem_column_array( + by + ).to_pandas() else: - if isinstance(self._by, type(self._query_compiler)): - by = self._by.to_pandas().squeeze().values - elif self._by is None: - index = self._query_compiler.get_axis(self._axis) - levels_to_drop = [ - i - for i, name in enumerate(index.names) - if name not in level and i not in level - ] - by = index.droplevel(levels_to_drop) - if isinstance(by, pandas.MultiIndex): - by = by.reorder_levels(level) - else: - by = self._by - if self._axis == 0: - self._index_grouped_cache = self._index.groupby(by) - else: - self._index_grouped_cache = self._columns.groupby(by) - return self._index_grouped_cache + by = try_cast_to_pandas(by, squeeze=True) + pandas_df = self._df._to_pandas() + by = wrap_into_list(by, level) + groupby_obj = pandas_df.groupby(by=by) + return groupby_obj.indices if numerical else groupby_obj.groups + else: + if isinstance(self._by, type(self._query_compiler)): + by = self._by.to_pandas().squeeze().values + elif self._by is None: + index = self._query_compiler.get_axis(self._axis) + levels_to_drop = [ + i + for i, name in enumerate(index.names) + if name not in level and i not in level + ] + by = index.droplevel(levels_to_drop) + if isinstance(by, pandas.MultiIndex): + by = by.reorder_levels(level) + else: + by = self._by + axis_labels = self._query_compiler.get_axis(self._axis) + if numerical: + # Since we want positional indices of the groups, we want to group + # on a `RangeIndex`, not on the actual index labels + axis_labels = pandas.RangeIndex(len(axis_labels)) + return axis_labels.groupby(by) def _wrap_aggregation( self, qc_method, default_func, drop=True, numeric_only=True, **kwargs @@ -1107,14 +1127,15 @@ def _iter(self): generator Generator expression of GroupBy object broken down into tuples for iteration. """ - group_ids = self._index_grouped.keys() + indices = self.indices + group_ids = indices.keys() if self._axis == 0: return ( ( k, Series( query_compiler=self._query_compiler.getitem_row_array( - self._index.get_indexer_for(self._index_grouped[k].unique()) + indices[k] ) ), ) @@ -1126,7 +1147,7 @@ def _iter(self): k, Series( query_compiler=self._query_compiler.getitem_column_array( - self._index_grouped[k].unique() + indices[k], numeric=True ) ), ) diff --git a/modin/pandas/test/test_groupby.py b/modin/pandas/test/test_groupby.py index 5df5dde1afd..c015d306dba 100644 --- a/modin/pandas/test/test_groupby.py +++ b/modin/pandas/test/test_groupby.py @@ -29,6 +29,7 @@ modin_df_almost_equals_pandas, generate_multiindex, test_groupby_data, + dict_equals, ) from modin.config import NPartitions @@ -36,6 +37,13 @@ def modin_groupby_equals_pandas(modin_groupby, pandas_groupby): + eval_general( + modin_groupby, pandas_groupby, lambda grp: grp.indices, comparator=dict_equals + ) + eval_general( + modin_groupby, pandas_groupby, lambda grp: grp.groups, comparator=dict_equals + ) + for g1, g2 in itertools.zip_longest(modin_groupby, pandas_groupby): assert g1[0] == g2[0] df_equals(g1[1], g2[1]) diff --git a/modin/pandas/test/utils.py b/modin/pandas/test/utils.py index 511e66b1708..6d3734a20ab 100644 --- a/modin/pandas/test/utils.py +++ b/modin/pandas/test/utils.py @@ -15,12 +15,14 @@ import numpy as np import math import pandas +import itertools from pandas.testing import ( assert_series_equal, assert_frame_equal, assert_index_equal, assert_extension_array_equal, ) +from pandas.core.dtypes.common import is_list_like from modin.config.envvars import NPartitions import modin.pandas as pd from modin.utils import to_pandas, try_cast_to_pandas @@ -1349,3 +1351,14 @@ def _make_default_file(filename=None, nrows=NROWS, ncols=2, force=True, **kwargs return filename return _make_default_file, filenames + + +def dict_equals(dict1, dict2): + """Check whether two dictionaries are equal and raise an ``AssertionError`` if they aren't.""" + for key1, key2 in itertools.zip_longest(sorted(dict1), sorted(dict2)): + assert (key1 == key2) or (np.isnan(key1) and np.isnan(key2)) + value1, value2 = dict1[key1], dict2[key2] + if is_list_like(value1): + np.testing.assert_array_equal(value1, value2) + else: + assert value1 == value2