Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG/CLN: Decouple Series/DataFrame.transform #35964

Merged
merged 13 commits into from
Sep 12, 2020
100 changes: 100 additions & 0 deletions pandas/core/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from pandas._typing import AggFuncType, FrameOrSeries, Label

from pandas.core.dtypes.common import is_dict_like, is_list_like
from pandas.core.dtypes.generic import ABCDataFrame, ABCSeries

from pandas.core.base import SpecificationError
import pandas.core.common as com
Expand Down Expand Up @@ -384,3 +385,102 @@ def validate_func_kwargs(
if not columns:
raise TypeError(no_arg_message)
return columns, func


def transform(
obj: FrameOrSeries,
func: Union[str, List, Dict, Callable],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we have an alias for this anywhere?

axis: int,
*args,
**kwargs,
) -> FrameOrSeries:
"""
Transform a DataFrame or Series

Parameters
----------
obj : DataFrame or Series
Object to compute the transform on.
func : string, function, list, or dictionary
Function(s) to compute the transform with.
axis : {0 or 'index', 1 or 'columns'}
Axis along which the function is applied:

* 0 or 'index': apply function to each column.
* 1 or 'columns': apply function to each row.

Returns
-------
DataFrame or Series
Result of applying ``func`` along the given axis of the
Series or DataFrame.

Raises
------
ValueError
If the transform function fails or does not transform.
"""
is_series = obj.ndim == 1

if obj._get_axis_number(axis) == 1:
assert not is_series
jreback marked this conversation as resolved.
Show resolved Hide resolved
return transform(obj.T, func, 0, *args, **kwargs).T

if isinstance(func, list):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should probably use is_list_like here

if is_series:
func = {com.get_callable_name(v) or v: v for v in func}
else:
func = {col: func for col in obj}

if isinstance(func, dict):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should probably use is_dict_like here

if not is_series:
cols = sorted(set(func.keys()) - set(obj.columns))
if len(cols) > 0:
raise SpecificationError(f"Column(s) {cols} do not exist")

if any(isinstance(v, dict) for v in func.values()):
# GH 15931 - deprecation of renaming keys
raise SpecificationError("nested renamer is not supported")

results = {}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you type results

for name, how in func.items():
colg = obj._gotitem(name, ndim=1)
try:
results[name] = transform(colg, how, 0, *args, **kwargs)
except Exception as e:
if str(e) == "Function did not transform":
raise e

# combine results
if len(results) == 0:
raise ValueError("Transform function failed")
from pandas.core.reshape.concat import concat
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

put this at the top of the function


return concat(results, axis=1)

# func is either str or callable
try:
jreback marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(func, str):
result = obj._try_aggregate_string_function(func, *args, **kwargs)
else:
f = obj._get_cython_func(func)
if f and not args and not kwargs:
result = getattr(obj, f)()
else:
try:
result = obj.apply(func, args=args, **kwargs)
except Exception:
result = func(obj, *args, **kwargs)
except Exception:
raise ValueError("Transform function failed")

# Functions that transform may return empty Series/DataFrame
# when the dtype is not appropriate
if isinstance(result, (ABCSeries, ABCDataFrame)) and result.empty:
raise ValueError("Transform function failed")
if not isinstance(result, (ABCSeries, ABCDataFrame)) or not result.index.equals(
obj.index
):
raise ValueError("Function did not transform")

return result
4 changes: 2 additions & 2 deletions pandas/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import builtins
import textwrap
from typing import Any, Dict, FrozenSet, List, Optional, Union
from typing import Any, Callable, Dict, FrozenSet, List, Optional, Union

import numpy as np

Expand Down Expand Up @@ -560,7 +560,7 @@ def _aggregate_multiple_funcs(self, arg, _axis):
) from err
return result

def _get_cython_func(self, arg: str) -> Optional[str]:
def _get_cython_func(self, arg: Callable) -> Optional[str]:
"""
if we define an internal function for this argument, return it
"""
Expand Down
9 changes: 4 additions & 5 deletions pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@

from pandas.core import algorithms, common as com, nanops, ops
from pandas.core.accessor import CachedAccessor
from pandas.core.aggregation import reconstruct_func, relabel_result
from pandas.core.aggregation import reconstruct_func, relabel_result, transform
from pandas.core.arrays import Categorical, ExtensionArray
from pandas.core.arrays.datetimelike import DatetimeLikeArrayMixin as DatetimeLikeArray
from pandas.core.arrays.sparse import SparseFrameAccessor
Expand Down Expand Up @@ -7463,10 +7463,9 @@ def _aggregate(self, arg, axis=0, *args, **kwargs):
axis=_shared_doc_kwargs["axis"],
)
def transform(self, func, axis=0, *args, **kwargs) -> "DataFrame":
axis = self._get_axis_number(axis)
if axis == 1:
return self.T.transform(func, *args, **kwargs).T
return super().transform(func, *args, **kwargs)
result = transform(self, func, axis, *args, **kwargs)
assert isinstance(result, DataFrame)
return result

def apply(self, func, axis=0, raw=False, result_type=None, args=(), **kwds):
"""
Expand Down
6 changes: 1 addition & 5 deletions pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -10729,11 +10729,7 @@ def transform(self, func, *args, **kwargs):
1 1.000000 2.718282
2 1.414214 7.389056
"""
result = self.agg(func, *args, **kwargs)
if is_scalar(result) or len(result) != len(self):
raise ValueError("transforms cannot produce aggregated results")

return result
raise NotImplementedError
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we remove this? or is the doc-string used?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe just move the doc-string to shared_docs?


# ----------------------------------------------------------------------
# Misc methods
Expand Down
6 changes: 3 additions & 3 deletions pandas/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -4085,9 +4085,9 @@ def aggregate(self, func=None, axis=0, *args, **kwargs):
axis=_shared_doc_kwargs["axis"],
)
def transform(self, func, axis=0, *args, **kwargs):
# Validate the axis parameter
self._get_axis_number(axis)
return super().transform(func, *args, **kwargs)
from pandas.core.aggregation import transform

return transform(self, func, axis, *args, **kwargs)

def apply(self, func, convert_dtype=True, args=(), **kwds):
"""
Expand Down
46 changes: 1 addition & 45 deletions pandas/tests/frame/apply/test_frame_apply.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from collections import OrderedDict
from datetime import datetime
from itertools import chain
import operator
import warnings

import numpy as np
Expand All @@ -14,6 +13,7 @@
import pandas._testing as tm
from pandas.core.apply import frame_apply
from pandas.core.base import SpecificationError
from pandas.tests.frame.common import zip_frames


@pytest.fixture
Expand Down Expand Up @@ -1058,25 +1058,6 @@ def test_consistency_for_boxed(self, box, int_frame_const_col):
tm.assert_frame_equal(result, expected)


def zip_frames(frames, axis=1):
"""
take a list of frames, zip them together under the
assumption that these all have the first frames' index/columns.

Returns
-------
new_frame : DataFrame
"""
if axis == 1:
columns = frames[0].columns
zipped = [f.loc[:, c] for c in columns for f in frames]
return pd.concat(zipped, axis=1)
else:
index = frames[0].index
zipped = [f.loc[i, :] for i in index for f in frames]
return pd.DataFrame(zipped)


class TestDataFrameAggregate:
def test_agg_transform(self, axis, float_frame):
other_axis = 1 if axis in {0, "index"} else 0
Expand All @@ -1087,10 +1068,7 @@ def test_agg_transform(self, axis, float_frame):
f_sqrt = np.sqrt(float_frame)

# ufunc
result = float_frame.transform(np.sqrt, axis=axis)
expected = f_sqrt.copy()
tm.assert_frame_equal(result, expected)

result = float_frame.apply(np.sqrt, axis=axis)
tm.assert_frame_equal(result, expected)

Expand All @@ -1110,9 +1088,6 @@ def test_agg_transform(self, axis, float_frame):
)
tm.assert_frame_equal(result, expected)

result = float_frame.transform([np.sqrt], axis=axis)
tm.assert_frame_equal(result, expected)

# multiple items in list
# these are in the order as if we are applying both
# functions per series and then concatting
Expand All @@ -1128,38 +1103,19 @@ def test_agg_transform(self, axis, float_frame):
)
tm.assert_frame_equal(result, expected)

result = float_frame.transform([np.abs, "sqrt"], axis=axis)
tm.assert_frame_equal(result, expected)

def test_transform_and_agg_err(self, axis, float_frame):
# cannot both transform and agg
msg = "transforms cannot produce aggregated results"
with pytest.raises(ValueError, match=msg):
float_frame.transform(["max", "min"], axis=axis)

msg = "cannot combine transform and aggregation operations"
with pytest.raises(ValueError, match=msg):
with np.errstate(all="ignore"):
float_frame.agg(["max", "sqrt"], axis=axis)

with pytest.raises(ValueError, match=msg):
with np.errstate(all="ignore"):
float_frame.transform(["max", "sqrt"], axis=axis)

df = pd.DataFrame({"A": range(5), "B": 5})

def f():
with np.errstate(all="ignore"):
df.agg({"A": ["abs", "sum"], "B": ["mean", "max"]}, axis=axis)

@pytest.mark.parametrize("method", ["abs", "shift", "pct_change", "cumsum", "rank"])
def test_transform_method_name(self, method):
# GH 19760
df = pd.DataFrame({"A": [-1, 2]})
result = df.transform(method)
expected = operator.methodcaller(method)(df)
tm.assert_frame_equal(result, expected)

def test_demo(self):
# demonstration tests
df = pd.DataFrame({"A": range(5), "B": 5})
Expand Down
Loading