Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

POC: aggregate always aggregates #40275

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions pandas/core/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,7 +674,9 @@ def agg(self):
 result = result.T if result is not None else result

 if result is None:
-    result = self.obj.apply(self.orig_f, axis, args=self.args, **self.kwargs)
+    results, res_index = self.apply_series_generator()
+    result = self.obj._constructor_sliced(results)
+    result.index = res_index

 return result

Expand Down Expand Up @@ -1018,10 +1020,7 @@ def agg(self):
 # we cannot FIRST try the vectorized evaluation, because
 # then .agg and .apply would have different semantics if the
 # operation is actually defined on the Series, e.g. str
-try:
-    result = self.obj.apply(f, *args, **kwargs)
-except (ValueError, AttributeError, TypeError):
-    result = f(self.obj, *args, **kwargs)
+result = f(self.obj, *args, **kwargs)

 return result

Expand Down
3 changes: 1 addition & 2 deletions pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -10285,9 +10285,8 @@ def _agg_by_level(self, name, axis=0, level=0, skipna=True, **kwargs):
 grouped = self.groupby(level=level, axis=axis, sort=False)
 if hasattr(grouped, name) and skipna:
     return getattr(grouped, name)(**kwargs)
-axis = self._get_axis_number(axis)
 method = getattr(type(self), name)
-applyf = lambda x: method(x, axis=axis, skipna=skipna, **kwargs)
+applyf = lambda x: method(x, skipna=skipna, **kwargs)
 return grouped.aggregate(applyf)

@final
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/groupby/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1010,7 +1010,7 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs)
 if result is None:

     # grouper specific aggregations
-    if self.grouper.nkeys > 1:
+    if not self._obj_with_exclusions.empty or self.grouper.nkeys > 1:
         return self._python_agg_general(func, *args, **kwargs)
     elif args or kwargs:
         result = self._aggregate_frame(func, *args, **kwargs)
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -1092,7 +1092,7 @@ def _agg_general(

 # apply a non-cython aggregation
 if result is None:
-    result = self.aggregate(lambda x: npfunc(x, axis=self.axis))
+    result = self.aggregate(lambda x: npfunc(x))
 return result.__finalize__(self.obj, method="groupby")

def _cython_agg_general(
Expand Down
11 changes: 1 addition & 10 deletions pandas/core/groupby/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,24 +760,15 @@ def _aggregate_series_pure_python(self, obj: Series, func: F):

counts = np.zeros(ngroups, dtype=int)
result = np.empty(ngroups, dtype="O")
initialized = False

splitter = get_splitter(obj, group_index, ngroups, axis=0)

for label, group in splitter:

# Each step of this loop corresponds to
# libreduction._BaseGrouper._apply_to_group
res = func(group)
res = libreduction.extract_result(res)

if not initialized:
# We only do this validation on the first iteration
libreduction.check_result_array(res, 0)
initialized = True

counts[label] = group.shape[0]
result[label] = res
result[label] = func(group)

result = lib.maybe_convert_objects(result, try_float=False)
result = maybe_cast_result(result, obj, numeric_only=True)
Expand Down
19 changes: 15 additions & 4 deletions pandas/tests/apply/test_series_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,17 +298,27 @@ def test_demo():
tm.assert_series_equal(result, expected)


def test_agg_apply_evaluate_lambdas_the_same(string_series):
def test_agg_apply_evaluate_lambdas(string_series):
# test that we are evaluating row-by-row first
# before vectorized evaluation
expected = string_series.astype(str)

result = string_series.apply(lambda x: str(x))
expected = string_series.agg(lambda x: str(x))
tm.assert_series_equal(result, expected)

result = string_series.apply(str)
expected = string_series.agg(str)
tm.assert_series_equal(result, expected)

# GH 35725
# Agg always aggs - applies the function to the entire Series
expected = str(string_series)

result = string_series.agg(lambda x: str(x))
assert result == expected

result = string_series.agg(str)
assert result == expected


def test_with_nested_series(datetime_series):
# GH 2316
Expand All @@ -318,7 +328,8 @@ def test_with_nested_series(datetime_series):
tm.assert_frame_equal(result, expected)

result = datetime_series.agg(lambda x: Series([x, x ** 2], index=["x", "x^2"]))
tm.assert_frame_equal(result, expected)
expected = Series([datetime_series, datetime_series ** 2], index=["x", "x^2"])
tm.assert_series_equal(result, expected)


def test_replicate_describe(string_series):
Expand Down
13 changes: 5 additions & 8 deletions pandas/tests/groupby/aggregate/test_aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,10 @@ def test_agg_regression1(tsframe):

def test_agg_must_agg(df):
grouped = df.groupby("A")["C"]

msg = "Must produce aggregated value"
with pytest.raises(Exception, match=msg):
grouped.agg(lambda x: x.describe())
with pytest.raises(Exception, match=msg):
grouped.agg(lambda x: x.index[:2])
result = grouped.agg(lambda x: x.describe())
expected = Series({name: group.describe() for name, group in grouped}, name="C")
expected.index.name = "A"
tm.assert_series_equal(result, expected)


def test_agg_ser_multi_key(df):
Expand Down Expand Up @@ -127,9 +125,8 @@ def test_groupby_aggregation_multi_level_column():
data=lst,
columns=MultiIndex.from_tuples([("A", 0), ("A", 1), ("B", 0), ("B", 1)]),
)

result = df.groupby(level=1, axis=1).sum()
expected = DataFrame({0: [2.0, 1, 1, 1], 1: [1, 0, 1, 1]})
expected = DataFrame({0: [2, 1, 1, 1], 1: [1, 0, 1, 1]})

tm.assert_frame_equal(result, expected)

Expand Down
3 changes: 1 addition & 2 deletions pandas/tests/groupby/aggregate/test_other.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,9 +605,8 @@ def test_agg_lambda_with_timezone():
)
result = df.groupby("tag").agg({"date": lambda e: e.head(1)})
expected = DataFrame(
[pd.Timestamp("2018-01-01", tz="UTC")],
{"date": [df["date"].iloc[:1]]},
index=Index([1], name="tag"),
columns=["date"],
)
tm.assert_frame_equal(result, expected)

Expand Down
10 changes: 4 additions & 6 deletions pandas/tests/groupby/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,9 @@ def test_basic(dtype):
agged = grouped.agg(lambda x: group_constants[x.name] + x.mean())
assert agged[1] == 21

# corner cases
msg = "Must produce aggregated value"
# exception raised is type Exception
with pytest.raises(Exception, match=msg):
grouped.aggregate(lambda x: x * 2)
result = grouped.aggregate(lambda x: x * 2)
expected = Series({name: group * 2 for name, group in grouped})
tm.assert_series_equal(result, expected)


def test_groupby_nonobject_dtype(mframe, df_mixed_floats):
Expand Down Expand Up @@ -1026,7 +1024,7 @@ def test_groupby_with_hier_columns():
result = df.groupby(level=0).apply(lambda x: x.mean())
tm.assert_index_equal(result.columns, columns)

result = df.groupby(level=0, axis=1).agg(lambda x: x.mean(1))
result = df.groupby(level=0, axis=1).agg(lambda x: x.mean())
tm.assert_index_equal(result.columns, Index(["A", "B"]))
tm.assert_index_equal(result.index, df.index)

Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/test_multilevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def test_frame_group_ops(

 def aggf(x):
     pieces.append(x)
-    return getattr(x, op)(skipna=skipna, axis=axis)
+    return getattr(x, op)(skipna=skipna)

 leftside = grouped.agg(aggf)
 rightside = getattr(frame, op)(level=level, axis=axis, skipna=skipna)
Expand Down