BUG: groupby.agg/transform casts UDF results (pandas-dev#40790)
rhshadrach authored and yeshsurya committed May 6, 2021
1 parent d9b5d0b commit e9fd2cf
Showing 19 changed files with 221 additions and 57 deletions.
2 changes: 1 addition & 1 deletion doc/source/user_guide/gotchas.rst
@@ -178,7 +178,7 @@ To test for membership in the values, use the method :meth:`~pandas.Series.isin`
For ``DataFrames``, likewise, ``in`` applies to the column axis,
testing for membership in the list of column names.

.. _udf-mutation:
.. _gotchas.udf-mutation:

Mutating with User Defined Function (UDF) methods
-------------------------------------------------
31 changes: 29 additions & 2 deletions doc/source/user_guide/groupby.rst
@@ -739,6 +739,26 @@ optimized Cython implementations:
Of course ``sum`` and ``mean`` are implemented on pandas objects, so the above
code would work even without the special versions via dispatching (see below).

.. _groupby.aggregate.udfs:

Aggregations with User-Defined Functions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Users can also provide their own functions for custom aggregations. When aggregating
with a User-Defined Function (UDF), the UDF should not mutate the provided ``Series``;
see :ref:`gotchas.udf-mutation` for more information.

.. ipython:: python

   animals.groupby("kind")[["height"]].agg(lambda x: set(x))

The resulting dtype will reflect that of the aggregating function. If the results from different groups have
different dtypes, then a common dtype will be determined in the same way as ``DataFrame`` construction.

.. ipython:: python

   animals.groupby("kind")[["height"]].agg(lambda x: x.astype(int).sum())

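As a rough illustration of the dtype rule described above, the sketch below assumes pandas 1.3 or later and uses a small made-up frame (``df`` and its values are hypothetical, not the ``animals`` frame from the guide). The first aggregation returns an int for every group; the second returns an int for some groups and a float for others, so a common dtype has to be chosen.

```python
import pandas as pd

# Hypothetical data, for illustration only.
df = pd.DataFrame(
    {
        "kind": ["cat", "dog", "cat", "dog", "bird"],
        "height": [9.1, 6.0, 9.5, 34.0, 2.0],
    }
)
grouped = df.groupby("kind")["height"]

# Every group returns a Python int, so the result has an integer dtype
# even though the input column is float64.
all_int = grouped.agg(lambda x: int(x.sum()))

# Groups with more than one row return an int, single-row groups return a
# float; a common dtype is then chosen, much as pd.Series([18, 2.0]) would.
mixed = grouped.agg(lambda x: int(x.sum()) if len(x) > 1 else float(x.sum()))

print(all_int.dtype, mixed.dtype)
```

The branch on ``len(x)`` is only a convenient way to make groups disagree on their return type; real UDFs usually hit this case through missing values or conditional logic.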
.. _groupby.transform:

Transformation
@@ -759,7 +779,11 @@ as the one being grouped. The transform function must:
* (Optionally) operates on the entire group chunk. If this is supported, a
fast path is used starting from the *second* chunk.

For example, suppose we wished to standardize the data within each group:
Similar to :ref:`groupby.aggregate.udfs`, the resulting dtype will reflect that of the
transformation function. If the results from different groups have different dtypes, then
a common dtype will be determined in the same way as ``DataFrame`` construction.
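A minimal sketch of the same rule for ``transform``, assuming pandas 1.3 or later; the Series ``s`` and the grouping ``keys`` are made up for the example. The UDF returns an integer scalar per group, which is broadcast over the group, and the result keeps the integer dtype instead of being cast back to the float64 input dtype.

```python
import pandas as pd

# Hypothetical float64 input and grouping keys.
s = pd.Series([1.5, 2.5, 3.5, 4.5])
keys = [1, 1, 2, 2]

# The UDF returns a Python int per group; the value is broadcast to the
# group's shape and the integer dtype is preserved.
as_int = s.groupby(keys).transform(lambda x: int(x.max()))

print(as_int.dtype)
print(as_int.tolist())
```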

Suppose we wished to standardize the data within each group:

.. ipython:: python
@@ -1065,13 +1089,16 @@ that is itself a series, and possibly upcast the result to a DataFrame:
s
s.apply(f)
.. note::

   ``apply`` can act as a reducer, transformer, *or* filter function, depending on exactly what is passed to it.
   Depending on the path taken and exactly what you are grouping, the grouped column(s) may be included in
   the output and may also be used to set the indices.

Similar to :ref:`groupby.aggregate.udfs`, the resulting dtype will reflect that of the
apply function. If the results from different groups have different dtypes, then
a common dtype will be determined in the same way as ``DataFrame`` construction.


Numba Accelerated Routines
--------------------------
30 changes: 30 additions & 0 deletions doc/source/whatsnew/v1.3.0.rst
@@ -298,6 +298,36 @@ Preserve dtypes in :meth:`~pandas.DataFrame.combine_first`
combined.dtypes
Group by methods agg and transform no longer change return dtype for callables
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Previously the methods :meth:`.DataFrameGroupBy.aggregate`,
:meth:`.SeriesGroupBy.aggregate`, :meth:`.DataFrameGroupBy.transform`, and
:meth:`.SeriesGroupBy.transform` might cast the result dtype when the argument ``func``
is callable, possibly leading to undesirable results (:issue:`21240`). The cast would
occur if the result is numeric and casting back to the input dtype does not change any
values as measured by ``np.allclose``. Now no such casting occurs.

.. ipython:: python

    df = pd.DataFrame({'key': [1, 1], 'a': [True, False], 'b': [True, True]})
    df

*pandas 1.2.x*

.. code-block:: ipython

    In [5]: df.groupby('key').agg(lambda x: x.sum())
    Out[5]:
            a  b
    key
    1    True  2

*pandas 1.3.0*

.. ipython:: python

    df.groupby('key').agg(lambda x: x.sum())

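To see why the 1.2.x output above shows ``True`` for column ``a`` but ``2`` for column ``b``, the old rule can be approximated as follows. This is a simplified sketch, not pandas' actual implementation (the removed code went through ``maybe_downcast_numeric``); ``would_cast_back`` is a hypothetical helper that only encodes the "numeric result, and casting back does not change values per ``np.allclose``" condition described in the note.

```python
import numpy as np


def would_cast_back(result, input_dtype):
    """Hypothetical helper approximating the pre-1.3 casting condition."""
    result = np.asarray(result)
    # The old rule only applied to numeric results.
    if not np.issubdtype(result.dtype, np.number):
        return False
    casted = result.astype(input_dtype)
    # Compare as floats so that bool and int values can be checked together.
    return bool(np.allclose(result.astype(float), casted.astype(float)))


# Column 'a': sum of [True, False] is 1, and 1 survives a round trip via bool.
print(would_cast_back([1], bool))
# Column 'b': sum of [True, True] is 2, which collapses to True (1) as bool.
print(would_cast_back([2], bool))
```

Under this approximation the sum for ``a`` would have been cast back to ``bool`` while the sum for ``b`` would not, which matches the mixed output shown for pandas 1.2.x.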
Try operating inplace when setting values with ``loc`` and ``iloc``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
2 changes: 1 addition & 1 deletion pandas/core/frame.py
@@ -8552,7 +8552,7 @@ def apply(
Notes
-----
Functions that mutate the passed object can produce unexpected
behavior or errors and are not supported. See :ref:`udf-mutation`
behavior or errors and are not supported. See :ref:`gotchas.udf-mutation`
for more details.
Examples
41 changes: 25 additions & 16 deletions pandas/core/groupby/generic.py
@@ -44,10 +44,6 @@
doc,
)

from pandas.core.dtypes.cast import (
find_common_type,
maybe_downcast_numeric,
)
from pandas.core.dtypes.common import (
ensure_int64,
is_bool,
@@ -226,7 +222,16 @@ def _selection_name(self):
... )
minimum maximum
1 1 2
2 3 4"""
2 3 4
.. versionchanged:: 1.3.0
The resulting dtype will reflect the return value of the aggregating function.
>>> s.groupby([1, 1, 2, 2]).agg(lambda x: x.astype(float).min())
1 1.0
2 3.0
dtype: float64"""
)

@Appender(
@@ -566,8 +571,9 @@ def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):

def _transform_general(self, func, *args, **kwargs):
"""
Transform with a non-str `func`.
Transform with a callable `func`.
"""
assert callable(func)
klass = type(self._selected_obj)

results = []
@@ -589,13 +595,6 @@ def _transform_general(self, func, *args, **kwargs):
result = self._set_result_index_ordered(concatenated)
else:
result = self.obj._constructor(dtype=np.float64)
# we will only try to coerce the result type if
# we have a numeric dtype, as these are *always* user-defined funcs
# the cython take a different path (and casting)
if is_numeric_dtype(result.dtype):
common_dtype = find_common_type([self._selected_obj.dtype, result.dtype])
if common_dtype is result.dtype:
result = maybe_downcast_numeric(result, self._selected_obj.dtype)

result.name = self._selected_obj.name
return result
@@ -625,7 +624,7 @@ def filter(self, func, dropna=True, *args, **kwargs):
Notes
-----
Functions that mutate the passed object can produce unexpected
behavior or errors and are not supported. See :ref:`udf-mutation`
behavior or errors and are not supported. See :ref:`gotchas.udf-mutation`
for more details.
Examples
@@ -1006,7 +1005,17 @@ class DataFrameGroupBy(GroupBy[DataFrame]):
``['column', 'aggfunc']`` to make it clearer what the arguments are.
As usual, the aggregation can be a callable or a string alias.
See :ref:`groupby.aggregate.named` for more."""
See :ref:`groupby.aggregate.named` for more.
.. versionchanged:: 1.3.0
The resulting dtype will reflect the return value of the aggregating function.
>>> df.groupby("A")[["B"]].agg(lambda x: x.astype(float).min())
B
A
1 1.0
2 3.0"""
)

@doc(_agg_template, examples=_agg_examples_doc, klass="DataFrame")
@@ -1533,7 +1542,7 @@ def filter(self, func, dropna=True, *args, **kwargs):
which group you are working on.
Functions that mutate the passed object can produce unexpected
behavior or errors and are not supported. See :ref:`udf-mutation`
behavior or errors and are not supported. See :ref:`gotchas.udf-mutation`
for more details.
Examples
67 changes: 50 additions & 17 deletions pandas/core/groupby/groupby.py
@@ -158,14 +158,19 @@ class providing the base-class of operations.
side-effects, as they will take effect twice for the first
group.
.. versionchanged:: 1.3.0
The resulting dtype will reflect the return value of the passed ``func``,
see the examples below.
Examples
--------
{examples}
""",
"dataframe_examples": """
>>> df = pd.DataFrame({'A': 'a a b'.split(),
... 'B': [1,2,3],
... 'C': [4,6, 5]})
... 'C': [4,6,5]})
>>> g = df.groupby('A')
Notice that ``g`` has two groups, ``a`` and ``b``.
@@ -183,13 +188,17 @@ class providing the base-class of operations.
Example 2: The function passed to `apply` takes a DataFrame as
its argument and returns a Series. `apply` combines the result for
each group together into a new DataFrame:
each group together into a new DataFrame.
.. versionchanged:: 1.3.0
>>> g[['B', 'C']].apply(lambda x: x.max() - x.min())
B C
The resulting dtype will reflect the return value of the passed ``func``.
>>> g[['B', 'C']].apply(lambda x: x.astype(float).max() - x.min())
B C
A
a 1 2
b 0 0
a 1.0 2.0
b 0.0 0.0
Example 3: The function passed to `apply` takes a DataFrame as
its argument and returns a scalar. `apply` combines the result for
@@ -210,12 +219,16 @@ class providing the base-class of operations.
Example 1: The function passed to `apply` takes a Series as
its argument and returns a Series. `apply` combines the result for
each group together into a new Series:
each group together into a new Series.
.. versionchanged:: 1.3.0
>>> g.apply(lambda x: x*2 if x.name == 'b' else x/2)
The resulting dtype will reflect the return value of the passed ``func``.
>>> g.apply(lambda x: x*2 if x.name == 'a' else x/2)
a 0.0
a 0.5
b 4.0
a 2.0
b 1.0
dtype: float64
Example 2: The function passed to `apply` takes a Series as
@@ -367,12 +380,17 @@ class providing the base-class of operations.
in the subframe. If f also supports application to the entire subframe,
then a fast path is used starting from the second chunk.
* f must not mutate groups. Mutation is not supported and may
produce unexpected results. See :ref:`udf-mutation` for more details.
produce unexpected results. See :ref:`gotchas.udf-mutation` for more details.
When using ``engine='numba'``, there will be no "fall back" behavior internally.
The group data and group index will be passed as numpy arrays to the JITed
user defined function, and no alternative execution attempts will be tried.
.. versionchanged:: 1.3.0
The resulting dtype will reflect the return value of the passed ``func``,
see the examples below.
Examples
--------
@@ -402,6 +420,20 @@ class providing the base-class of operations.
3 3 8.0
4 4 6.0
5 3 8.0
.. versionchanged:: 1.3.0
The resulting dtype will reflect the return value of the passed ``func``,
for example:
>>> grouped[['C', 'D']].transform(lambda x: x.astype(int).max())
C D
0 5 8
1 5 9
2 5 8
3 5 9
4 5 8
5 5 9
"""

_agg_template = """
@@ -469,12 +501,16 @@ class providing the base-class of operations.
When using ``engine='numba'``, there will be no "fall back" behavior internally.
The group data and group index will be passed as numpy arrays to the JITed
user defined function, and no alternative execution attempts will be tried.
{examples}
Functions that mutate the passed object can produce unexpected
behavior or errors and are not supported. See :ref:`udf-mutation`
behavior or errors and are not supported. See :ref:`gotchas.udf-mutation`
for more details.
"""
.. versionchanged:: 1.3.0
The resulting dtype will reflect the return value of the passed ``func``,
see the examples below.
{examples}"""


@final
@@ -1232,9 +1268,6 @@ def _python_agg_general(self, func, *args, **kwargs):
assert result is not None
key = base.OutputKey(label=name, position=idx)

if is_numeric_dtype(obj.dtype):
result = maybe_downcast_numeric(result, obj.dtype)

if self.grouper._filter_empty_groups:
mask = counts.ravel() > 0

2 changes: 1 addition & 1 deletion pandas/core/series.py
@@ -4190,7 +4190,7 @@ def apply(
Notes
-----
Functions that mutate the passed object can produce unexpected
behavior or errors and are not supported. See :ref:`udf-mutation`
behavior or errors and are not supported. See :ref:`gotchas.udf-mutation`
for more details.
Examples
Expand Down
4 changes: 2 additions & 2 deletions pandas/core/shared_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
`agg` is an alias for `aggregate`. Use the alias.
Functions that mutate the passed object can produce unexpected
behavior or errors and are not supported. See :ref:`udf-mutation`
behavior or errors and are not supported. See :ref:`gotchas.udf-mutation`
for more details.
A passed user-defined-function will be passed a Series for evaluation.
@@ -303,7 +303,7 @@
Notes
-----
Functions that mutate the passed object can produce unexpected
behavior or errors and are not supported. See :ref:`udf-mutation`
behavior or errors and are not supported. See :ref:`gotchas.udf-mutation`
for more details.
Examples