Skip to content

Commit

Permalink
FIX-#3755: return positional indices for '.indices' GroupBy property (#3758)
Browse files Browse the repository at this point in the history

Signed-off-by: Dmitry Chigarev <dmitry.chigarev@intel.com>
  • Loading branch information
dchigarev authored Dec 3, 2021
1 parent a4ed727 commit e9c06f2
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 72 deletions.
165 changes: 93 additions & 72 deletions modin/pandas/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,6 @@ def __init__(
self._squeeze = squeeze
self._kwargs.update(kwargs)

_index_grouped_cache = None

def __getattr__(self, key):
"""
Alter regular attribute access, looks up the name in the columns.
Expand Down Expand Up @@ -164,9 +162,17 @@ def __bytes__(self):
def tshift(self):
return self._default_to_pandas(lambda df: df.tshift)

# Sentinel-initialized cache for the `groups` property; `no_default`
# means "not computed yet".
_groups_cache = no_default

# TODO: since python 3.9:
# @cached_property
@property
def groups(self):
    """
    Get a dict of {group name -> group labels} values.

    Returns
    -------
    dict
        Label-based group indices, as computed by
        ``self._compute_index_grouped(numerical=False)``.
    """
    # Lazily compute once and memoize the result.
    if self._groups_cache is not no_default:
        return self._groups_cache

    self._groups_cache = self._compute_index_grouped(numerical=False)
    return self._groups_cache

def min(self, **kwargs):
return self._wrap_aggregation(
Expand Down Expand Up @@ -256,9 +262,17 @@ def cumsum(self, axis=0, *args, **kwargs):
lambda df: df.cumsum(axis, *args, **kwargs)
)

# Sentinel-initialized cache for the `indices` property; `no_default`
# means "not computed yet".
_indices_cache = no_default

# TODO: since python 3.9:
# @cached_property
@property
def indices(self):
    """
    Get a dict of {group name -> group positional indices} values.

    Returns
    -------
    dict
        Positional group indices, as computed by
        ``self._compute_index_grouped(numerical=True)``.
    """
    # Lazily compute once and memoize the result.
    if self._indices_cache is not no_default:
        return self._indices_cache

    self._indices_cache = self._compute_index_grouped(numerical=True)
    return self._indices_cache

def pct_change(self):
    """Compute pct_change for each group by falling back to pandas."""

    def op(df):
        return df.pct_change()

    return self._default_to_pandas(op)
Expand Down Expand Up @@ -554,7 +568,7 @@ def get_group(self, name, obj=None):
return self._default_to_pandas(lambda df: df.get_group(name, obj=obj))

def __len__(self):
    # Number of groups, taken from the (cached) positional indices mapping.
    return len(self.indices)

def all(self, **kwargs):
return self._wrap_aggregation(
Expand Down Expand Up @@ -787,14 +801,15 @@ def _iter(self):
"""
from .dataframe import DataFrame

group_ids = self._index_grouped.keys()
indices = self.indices
group_ids = indices.keys()
if self._axis == 0:
return (
(
k,
DataFrame(
query_compiler=self._query_compiler.getitem_row_array(
self._index.get_indexer_for(self._index_grouped[k].unique())
indices[k]
)
),
)
Expand All @@ -806,89 +821,94 @@ def _iter(self):
k,
DataFrame(
query_compiler=self._query_compiler.getitem_column_array(
self._index_grouped[k].unique()
indices[k], numeric=True
)
),
)
for k in (sorted(group_ids) if self._sort else group_ids)
)

def _compute_index_grouped(self, numerical=False):
    """
    Construct an index of group IDs.

    Parameters
    ----------
    numerical : bool, default: False
        Whether the group indices should be positional (True) or label-based (False).

    Returns
    -------
    dict
        A dict of {group name -> group indices} values.

    See Also
    --------
    pandas.core.groupby.GroupBy.groups
    """
    # Splitting level-by and column-by since we serialize them in different ways.
    by = None
    level = []
    if self._level is not None:
        level = self._level
        if not isinstance(level, list):
            level = [level]
    elif isinstance(self._by, list):
        by = []
        for o in self._by:
            if hashable(o) and o in self._query_compiler.get_index_names(self._axis):
                level.append(o)
            else:
                by.append(o)
    else:
        by = self._by

    is_multi_by = self._is_multi_by or (by is not None and len(level) > 0)

    if hasattr(self._by, "columns") and is_multi_by:
        by = list(self._by.columns)

    if is_multi_by:
        # Because we are doing a collect (to_pandas) here and then groupby, we
        # end up using the pandas implementation. Add the warning so the user
        # is aware.
        ErrorMessage.catch_bugs_and_request_email(self._axis == 1)
        ErrorMessage.default_to_pandas("Groupby with multiple columns")
        if isinstance(by, list) and all(
            is_label(self._df, o, self._axis) for o in by
        ):
            pandas_df = self._df._query_compiler.getitem_column_array(by).to_pandas()
        else:
            by = try_cast_to_pandas(by, squeeze=True)
            pandas_df = self._df._to_pandas()
        by = wrap_into_list(by, level)
        groupby_obj = pandas_df.groupby(by=by)
        return groupby_obj.indices if numerical else groupby_obj.groups
    else:
        if isinstance(self._by, type(self._query_compiler)):
            by = self._by.to_pandas().squeeze().values
        elif self._by is None:
            index = self._query_compiler.get_axis(self._axis)
            # Drop every index level we are NOT grouping on (matched either
            # by level name or by positional level number).
            levels_to_drop = [
                i
                for i, name in enumerate(index.names)
                if name not in level and i not in level
            ]
            by = index.droplevel(levels_to_drop)
            if isinstance(by, pandas.MultiIndex):
                by = by.reorder_levels(level)
        else:
            by = self._by
        axis_labels = self._query_compiler.get_axis(self._axis)
        if numerical:
            # Since we want positional indices of the groups, we want to group
            # on a `RangeIndex`, not on the actual index labels.
            axis_labels = pandas.RangeIndex(len(axis_labels))
        return axis_labels.groupby(by)

def _wrap_aggregation(
self, qc_method, default_func, drop=True, numeric_only=True, **kwargs
Expand Down Expand Up @@ -1107,14 +1127,15 @@ def _iter(self):
generator
Generator expression of GroupBy object broken down into tuples for iteration.
"""
group_ids = self._index_grouped.keys()
indices = self.indices
group_ids = indices.keys()
if self._axis == 0:
return (
(
k,
Series(
query_compiler=self._query_compiler.getitem_row_array(
self._index.get_indexer_for(self._index_grouped[k].unique())
indices[k]
)
),
)
Expand All @@ -1126,7 +1147,7 @@ def _iter(self):
k,
Series(
query_compiler=self._query_compiler.getitem_column_array(
self._index_grouped[k].unique()
indices[k], numeric=True
)
),
)
Expand Down
8 changes: 8 additions & 0 deletions modin/pandas/test/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,21 @@
modin_df_almost_equals_pandas,
generate_multiindex,
test_groupby_data,
dict_equals,
)
from modin.config import NPartitions

NPartitions.put(4)


def modin_groupby_equals_pandas(modin_groupby, pandas_groupby):
eval_general(
modin_groupby, pandas_groupby, lambda grp: grp.indices, comparator=dict_equals
)
eval_general(
modin_groupby, pandas_groupby, lambda grp: grp.groups, comparator=dict_equals
)

for g1, g2 in itertools.zip_longest(modin_groupby, pandas_groupby):
assert g1[0] == g2[0]
df_equals(g1[1], g2[1])
Expand Down
13 changes: 13 additions & 0 deletions modin/pandas/test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
import numpy as np
import math
import pandas
import itertools
from pandas.testing import (
assert_series_equal,
assert_frame_equal,
assert_index_equal,
assert_extension_array_equal,
)
from pandas.core.dtypes.common import is_list_like
from modin.config.envvars import NPartitions
import modin.pandas as pd
from modin.utils import to_pandas, try_cast_to_pandas
Expand Down Expand Up @@ -1349,3 +1351,14 @@ def _make_default_file(filename=None, nrows=NROWS, ncols=2, force=True, **kwargs
return filename

return _make_default_file, filenames


def dict_equals(dict1, dict2):
    """Check whether two dictionaries are equal and raise an ``AssertionError`` if they aren't."""
    # Compare lengths first so a missing/extra key fails with a clear
    # AssertionError instead of a KeyError from the value lookup below
    # (zip_longest pads the shorter side with None).
    assert len(dict1) == len(dict2), (
        f"Dictionaries have different lengths: {len(dict1)} != {len(dict2)}"
    )
    for key1, key2 in itertools.zip_longest(sorted(dict1), sorted(dict2)):
        keys_equal = key1 == key2
        if not keys_equal:
            # NaN never compares equal to itself, so treat two NaN keys as a
            # match. `np.isnan` raises TypeError for non-numeric keys — treat
            # that as a plain mismatch instead of letting TypeError escape.
            try:
                keys_equal = bool(np.isnan(key1) and np.isnan(key2))
            except TypeError:
                keys_equal = False
        assert keys_equal, f"Keys differ: {key1!r} != {key2!r}"
        value1, value2 = dict1[key1], dict2[key2]
        if is_list_like(value1):
            np.testing.assert_array_equal(value1, value2)
        else:
            assert value1 == value2

0 comments on commit e9c06f2

Please sign in to comment.