BUG: groupby.agg/transform casts UDF results #40790

Merged (24 commits, May 3, 2021)

Commits
477d813  BUG: groupby.agg/transform downcasts UDF results (rhshadrach, Apr 2, 2021)
d932c93  Merge branch 'master' of https://github.com/pandas-dev/pandas into do… (rhshadrach, Apr 10, 2021)
f2069a7  Reverted behavior change when input and output are the same kind (rhshadrach, Apr 10, 2021)
35c789f  Patch via maybe_convert_objects (rhshadrach, Apr 10, 2021)
93fa089  Merge branch 'master' of https://github.com/pandas-dev/pandas into do… (rhshadrach, Apr 22, 2021)
1cb216e  fixups (rhshadrach, Apr 22, 2021)
0cafcee  whatsnew (rhshadrach, Apr 22, 2021)
785ac9d  dtype test fixes (rhshadrach, Apr 23, 2021)
737a366  Merge branch 'master' of https://github.com/pandas-dev/pandas into do… (rhshadrach, Apr 23, 2021)
0b00aa7  Merge branch 'master' of https://github.com/pandas-dev/pandas into do… (rhshadrach, Apr 24, 2021)
de0f7b5  fixup (rhshadrach, Apr 24, 2021)
e95bb49  Merge branch 'dont_cast_udfs' of https://github.com/rhshadrach/pandas… (rhshadrach, Apr 24, 2021)
4ef6794  Fixup (rhshadrach, Apr 24, 2021)
4f97288  Add GH issue to TODOs (rhshadrach, Apr 24, 2021)
ad7d990  Added docs to user guide, agg docstring (rhshadrach, Apr 25, 2021)
11529e3  Updated docs (rhshadrach, Apr 25, 2021)
0ca49f6  Merge branch 'dont_cast_udfs' of https://github.com/rhshadrach/pandas… (rhshadrach, Apr 25, 2021)
a0a2640  Fixup (rhshadrach, Apr 27, 2021)
eb1943a  Fixup (rhshadrach, Apr 27, 2021)
180bc23  docsting fixup (rhshadrach, Apr 29, 2021)
47d97ae  Merge branch 'master' of https://github.com/pandas-dev/pandas into do… (rhshadrach, Apr 29, 2021)
4a0978e  Add versionchanged (rhshadrach, May 1, 2021)
2b38e5c  Merge branch 'master' of https://github.com/pandas-dev/pandas into do… (rhshadrach, May 1, 2021)
6b80c10  Added versionchanged (rhshadrach, May 1, 2021)
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.3.0.rst
@@ -753,6 +753,7 @@ Groupby/resample/rolling
- Bug in :class:`core.window.ewm.ExponentialMovingWindow` when calling ``__getitem__`` would not retain ``com``, ``span``, ``alpha`` or ``halflife`` attributes (:issue:`40164`)
- :class:`core.window.ewm.ExponentialMovingWindow` now raises a ``NotImplementedError`` when specifying ``times`` with ``adjust=False`` due to an incorrect calculation (:issue:`40098`)
- Bug in :meth:`Series.asfreq` and :meth:`DataFrame.asfreq` dropping rows when the index is not sorted (:issue:`39805`)
- Bug in :meth:`DataFrameGroupBy.aggregate`, :meth:`SeriesGroupBy.aggregate`, :meth:`DataFrameGroupBy.transform`, and :meth:`SeriesGroupBy.transform` would possibly change the result dtype when ``func`` is callable (:issue:`21240`)

Reshaping
^^^^^^^^^
12 changes: 2 additions & 10 deletions pandas/core/groupby/generic.py
@@ -50,7 +50,6 @@
)

from pandas.core.dtypes.cast import (
- find_common_type,
maybe_cast_result_dtype,
maybe_downcast_numeric,
)
@@ -61,7 +60,6 @@
is_dict_like,
is_integer_dtype,
is_interval_dtype,
- is_numeric_dtype,
is_scalar,
needs_i8_conversion,
)
@@ -562,8 +560,9 @@ def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):

def _transform_general(self, func, *args, **kwargs):
"""
- Transform with a non-str `func`.
+ Transform with a callable func`.
"""
+ assert callable(func)
klass = type(self._selected_obj)

results = []
@@ -584,13 +583,6 @@ def _transform_general(self, func, *args, **kwargs):
result = self._set_result_index_ordered(concatenated)
else:
result = self.obj._constructor(dtype=np.float64)
- # we will only try to coerce the result type if
- # we have a numeric dtype, as these are *always* user-defined funcs
- # the cython take a different path (and casting)
- if is_numeric_dtype(result.dtype):
-     common_dtype = find_common_type([self._selected_obj.dtype, result.dtype])
-     if common_dtype is result.dtype:
-         result = maybe_downcast_numeric(result, self._selected_obj.dtype)

result.name = self._selected_obj.name
return result
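A minimal sketch of what dropping the downcast in `_transform_general` means in practice (assuming pandas with this PR applied; the frame and lambda are made up, mirroring the updated `test_groupby_transform_with_int` expectation further down):

```python
import pandas as pd

df = pd.DataFrame({"A": [1, 1, 2], "B": [2, 3, 4]})

# The UDF produces float64 values; previously the result was downcast
# back to the input column's int64 dtype, now the float64 result is kept.
result = df.groupby("A")["B"].transform(lambda x: x * 2 / 2)
print(result.dtype)  # float64 with this change (int64 before)
```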
3 changes: 0 additions & 3 deletions pandas/core/groupby/groupby.py
@@ -1232,9 +1232,6 @@ def _python_agg_general(self, func, *args, **kwargs):
assert result is not None
key = base.OutputKey(label=name, position=idx)

- if is_numeric_dtype(obj.dtype):
-     result = maybe_downcast_numeric(result, obj.dtype)
-
if self.grouper._filter_empty_groups:
mask = counts.ravel() > 0
Member:
@rhshadrach can I get your help in this nearby piece of code? In all existing tests, when we get here, we have mask.all(). I'm trying to come up with a case where this doesn't hold (or prove that it must always hold). Any thoughts?

Member Author (rhshadrach):
Looks like this is now removed - guessing mask.all() always held.

Member:
Yep.

Next thing to ask your help with: IIRC you've done a lot of work in core.apply, which DataFrameGroupBy.aggregate uses. I'd like to make SeriesGroupBy.aggregate and DataFrameGroupBy.aggregate share more code (or at least be more obviously similar). Can I get your thoughts on how to achieve this (and whether it's worth the effort)?

Member Author (rhshadrach):
Once we make aggregate always aggregate (PoC implemented in #40275), we can greatly simplify these methods. However, in order to do that we need to separate the apply/agg paths: currently apply uses agg for lists/dicts, and agg also uses apply for UDFs. I made a couple of attempts to do this but kept running into issues with changing behavior without having a clear way to deprecate. This was the motivation for #41112. I plan to start working on that, assuming that's a good approach, in 1.3.


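For the `_python_agg_general` change, a similar hedged sketch (the data are invented; it mirrors the updated expectation in `test_groupby.py` below):

```python
import pandas as pd

df = pd.DataFrame(
    {"key": ["bar", "bar", "foo", "bar", "foo", "bar"], "val": [1, 2, 3, 4, 5, 6]}
)

# The UDF returns a float per group; without the maybe_downcast_numeric
# call the float64 result is no longer cast back to the input's int64.
agged = df.groupby("key")["val"].agg(lambda x: float(len(x)))
print(agged.dtype)  # float64 with this change (previously downcast)
```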
52 changes: 49 additions & 3 deletions pandas/tests/groupby/aggregate/test_aggregate.py
@@ -235,10 +235,10 @@ def test_aggregate_item_by_item(df):

# GH5782
# odd comparisons can result here, so cast to make easy
- exp = Series(np.array([foo] * K), index=list("BCD"), dtype=np.float64, name="foo")
+ exp = Series(np.array([foo] * K), index=list("BCD"), name="foo")
tm.assert_series_equal(result.xs("foo"), exp)

- exp = Series(np.array([bar] * K), index=list("BCD"), dtype=np.float64, name="bar")
+ exp = Series(np.array([bar] * K), index=list("BCD"), name="bar")
tm.assert_almost_equal(result.xs("bar"), exp)

def aggfun(ser):
@@ -442,6 +442,48 @@ def test_bool_agg_dtype(op):
assert is_integer_dtype(result)


@pytest.mark.parametrize(
"keys, agg_index",
[
(["a"], Index([1], name="a")),
(["a", "b"], MultiIndex([[1], [2]], [[0], [0]], names=["a", "b"])),
],
)
@pytest.mark.parametrize("input", [True, 1, 1.0])
@pytest.mark.parametrize("dtype", [bool, int, float])
@pytest.mark.parametrize("method", ["apply", "aggregate", "transform"])
def test_callable_result_dtype_frame(keys, agg_index, input, dtype, method):
# GH 21240
df = DataFrame({"a": [1], "b": [2], "c": [input]})
op = getattr(df.groupby(keys)[["c"]], method)
result = op(lambda x: x.astype(dtype).iloc[0])
expected_index = pd.RangeIndex(0, 1) if method == "transform" else agg_index
expected = DataFrame({"c": [df["c"].iloc[0]]}, index=expected_index).astype(dtype)
if method == "apply":
expected.columns.names = [0]
tm.assert_frame_equal(result, expected)


@pytest.mark.parametrize(
"keys, agg_index",
[
(["a"], Index([1], name="a")),
(["a", "b"], MultiIndex([[1], [2]], [[0], [0]], names=["a", "b"])),
],
)
@pytest.mark.parametrize("input", [True, 1, 1.0])
@pytest.mark.parametrize("dtype", [bool, int, float])
@pytest.mark.parametrize("method", ["apply", "aggregate", "transform"])
def test_callable_result_dtype_series(keys, agg_index, input, dtype, method):
# GH 21240
df = DataFrame({"a": [1], "b": [2], "c": [input]})
op = getattr(df.groupby(keys)["c"], method)
result = op(lambda x: x.astype(dtype).iloc[0])
expected_index = pd.RangeIndex(0, 1) if method == "transform" else agg_index
expected = Series([df["c"].iloc[0]], index=expected_index, name="c").astype(dtype)
tm.assert_series_equal(result, expected)


def test_order_aggregate_multiple_funcs():
# GH 25692
df = DataFrame({"A": [1, 1, 2, 2], "B": [1, 2, 3, 4]})
@@ -849,7 +891,11 @@ def test_multiindex_custom_func(func):
data = [[1, 4, 2], [5, 7, 1]]
df = DataFrame(data, columns=MultiIndex.from_arrays([[1, 1, 2], [3, 4, 3]]))
result = df.groupby(np.array([0, 1])).agg(func)
- expected_dict = {(1, 3): {0: 1, 1: 5}, (1, 4): {0: 4, 1: 7}, (2, 3): {0: 2, 1: 1}}
+ expected_dict = {
+     (1, 3): {0: 1.0, 1: 5.0},
+     (1, 4): {0: 4.0, 1: 7.0},
+     (2, 3): {0: 2.0, 1: 1.0},
+ }
expected = DataFrame(expected_dict)
tm.assert_frame_equal(result, expected)

3 changes: 3 additions & 0 deletions pandas/tests/groupby/aggregate/test_cython.py
@@ -196,6 +196,9 @@ def test_cython_agg_empty_buckets(op, targop, observed):

g = df.groupby(pd.cut(df[0], grps), observed=observed)
expected = g.agg(lambda x: targop(x))
+ if observed and op not in ("min", "max"):
+     # TODO: cython_agg_general with mean/var should be float64
+     expected = expected.astype("int64")
tm.assert_frame_equal(result, expected)


2 changes: 1 addition & 1 deletion pandas/tests/groupby/test_categorical.py
@@ -1597,7 +1597,7 @@ def test_aggregate_categorical_with_isnan():
index = MultiIndex.from_arrays([[1, 1], [1, 2]], names=("A", "B"))
expected = DataFrame(
data={
"numerical_col": [1.0, 0.0],
"numerical_col": [1, 0],
"object_col": [0, 0],
"categorical_col": [0, 0],
},
5 changes: 4 additions & 1 deletion pandas/tests/groupby/test_function.py
@@ -440,6 +440,9 @@ def test_median_empty_bins(observed):

result = df.groupby(bins, observed=observed).median()
expected = df.groupby(bins, observed=observed).agg(lambda x: x.median())
+ if observed:
+     # TODO: groupby(..).median should be float64
+     expected = expected.astype("int64")
tm.assert_frame_equal(result, expected)


@@ -616,7 +619,7 @@ def test_ops_general(op, targop):
df = DataFrame(np.random.randn(1000))
labels = np.random.randint(0, 50, size=1000).astype(float)

- result = getattr(df.groupby(labels), op)().astype(float)
+ result = getattr(df.groupby(labels), op)()
expected = df.groupby(labels).agg(targop)
tm.assert_frame_equal(result, expected)

6 changes: 2 additions & 4 deletions pandas/tests/groupby/test_groupby.py
@@ -302,10 +302,8 @@ def f(x):
return float(len(x))

agged = grouped.agg(f)
- expected = Series([4, 2], index=["bar", "foo"])
-
- tm.assert_series_equal(agged, expected, check_dtype=False)
- assert issubclass(agged.dtype.type, np.dtype(dtype).type)
+ expected = Series([4.0, 2.0], index=["bar", "foo"])
+ tm.assert_series_equal(agged, expected)


def test_indices_concatenation_order():
8 changes: 4 additions & 4 deletions pandas/tests/groupby/transform/test_transform.py
@@ -242,7 +242,7 @@ def test_transform_bug():
# transforming on a datetime column
df = DataFrame({"A": Timestamp("20130101"), "B": np.arange(5)})
result = df.groupby("A")["B"].transform(lambda x: x.rank(ascending=False))
- expected = Series(np.arange(5, 0, step=-1), name="B")
+ expected = Series(np.arange(5, 0, step=-1), name="B", dtype="float64")
tm.assert_series_equal(result, expected)


@@ -493,7 +493,7 @@ def test_groupby_transform_with_int():
)
with np.errstate(all="ignore"):
result = df.groupby("A").transform(lambda x: (x - x.mean()) / x.std())
expected = DataFrame({"B": np.nan, "C": [-1, 0, 1, -1, 0, 1]})
expected = DataFrame({"B": np.nan, "C": [-1.0, 0.0, 1.0, -1.0, 0.0, 1.0]})
tm.assert_frame_equal(result, expected)

# int that needs float conversion
@@ -509,9 +509,9 @@ def test_groupby_transform_with_int():
expected = DataFrame({"B": np.nan, "C": concat([s1, s2])})
tm.assert_frame_equal(result, expected)

- # int downcasting
+ # int doesn't get downcasted
result = df.groupby("A").transform(lambda x: x * 2 / 2)
- expected = DataFrame({"B": 1, "C": [2, 3, 4, 10, 5, -1]})
+ expected = DataFrame({"B": 1.0, "C": [2.0, 3.0, 4.0, 10.0, 5.0, -1.0]})
tm.assert_frame_equal(result, expected)


5 changes: 5 additions & 0 deletions pandas/tests/resample/test_datetime_index.py
@@ -1204,6 +1204,9 @@ def test_resample_median_bug_1688():

result = df.resample("T").apply(lambda x: x.mean())
exp = df.asfreq("T")
+ if dtype == "float32":
+     # TODO: fastpath for apply comes back at float64
+     exp = exp.astype("float64")
tm.assert_frame_equal(result, exp)

result = df.resample("T").median()
@@ -1684,6 +1687,8 @@ def f(data, add_arg):
df = DataFrame({"A": 1, "B": 2}, index=date_range("2017", periods=10))
result = df.groupby("A").resample("D").agg(f, multiplier)
expected = df.groupby("A").resample("D").mean().multiply(multiplier)
+ # TODO: resample(...).mean should be a float64
+ expected = expected.astype("float64")
tm.assert_frame_equal(result, expected)


2 changes: 1 addition & 1 deletion pandas/tests/resample/test_resampler_grouper.py
@@ -289,7 +289,7 @@ def test_apply_columns_multilevel():
agg_dict = {col: (np.sum if col[3] == "one" else np.mean) for col in df.columns}
result = df.resample("H").apply(lambda x: agg_dict[x.name](x))
expected = DataFrame(
- np.array([0] * 4).reshape(2, 2),
+ 2 * [[0, 0.0]],
index=date_range(start="2017-01-01", freq="1H", periods=2),
columns=pd.MultiIndex.from_tuples(
[("A", "a", "", "one"), ("B", "b", "i", "two")]
2 changes: 1 addition & 1 deletion pandas/tests/resample/test_timedelta.py
@@ -162,7 +162,7 @@ def test_resample_with_timedelta_yields_no_empty_groups():
result = df.loc["1s":, :].resample("3s").apply(lambda x: len(x))

expected = DataFrame(
- [[768.0] * 4] * 12 + [[528.0] * 4],
+ [[768] * 4] * 12 + [[528] * 4],
index=timedelta_range(start="1s", periods=13, freq="3s"),
)
tm.assert_frame_equal(result, expected)
1 change: 1 addition & 0 deletions pandas/tests/reshape/test_crosstab.py
@@ -559,6 +559,7 @@ def test_crosstab_with_numpy_size(self):
expected = DataFrame(
expected_data, index=expected_index, columns=expected_column
)
expected["All"] = expected["All"].astype("int64")
tm.assert_frame_equal(result, expected)

def test_crosstab_duplicate_names(self):
1 change: 0 additions & 1 deletion pandas/tests/reshape/test_pivot.py
@@ -988,7 +988,6 @@ def test_margins_dtype(self):

tm.assert_frame_equal(expected, result)

- @pytest.mark.xfail(reason="GH#17035 (len of floats is casted back to floats)")
def test_margins_dtype_len(self):
mi_val = list(product(["bar", "foo"], ["one", "two"])) + [("All", "")]
mi = MultiIndex.from_tuples(mi_val, names=("A", "B"))