diff --git a/doc/source/whatsnew/v1.2.0.rst b/doc/source/whatsnew/v1.2.0.rst index bce6a735b7b07..8864469eaf858 100644 --- a/doc/source/whatsnew/v1.2.0.rst +++ b/doc/source/whatsnew/v1.2.0.rst @@ -342,6 +342,7 @@ Other ^^^^^ - Bug in :meth:`DataFrame.replace` and :meth:`Series.replace` incorrectly raising ``AssertionError`` instead of ``ValueError`` when invalid parameter combinations are passed (:issue:`36045`) - Bug in :meth:`DataFrame.replace` and :meth:`Series.replace` with numeric values and string ``to_replace`` (:issue:`34789`) +- Bug in :meth:`Series.transform` would give incorrect results or raise when the argument ``func`` was dictionary (:issue:`35811`) - .. --------------------------------------------------------------------------- diff --git a/pandas/core/aggregation.py b/pandas/core/aggregation.py index 7ca68d8289bd5..8b74fe01d0dc0 100644 --- a/pandas/core/aggregation.py +++ b/pandas/core/aggregation.py @@ -18,9 +18,10 @@ Union, ) -from pandas._typing import AggFuncType, FrameOrSeries, Label +from pandas._typing import AggFuncType, Axis, FrameOrSeries, Label from pandas.core.dtypes.common import is_dict_like, is_list_like +from pandas.core.dtypes.generic import ABCDataFrame, ABCSeries from pandas.core.base import SpecificationError import pandas.core.common as com @@ -384,3 +385,98 @@ def validate_func_kwargs( if not columns: raise TypeError(no_arg_message) return columns, func + + +def transform( + obj: FrameOrSeries, func: AggFuncType, axis: Axis, *args, **kwargs, +) -> FrameOrSeries: + """ + Transform a DataFrame or Series + + Parameters + ---------- + obj : DataFrame or Series + Object to compute the transform on. + func : string, function, list, or dictionary + Function(s) to compute the transform with. + axis : {0 or 'index', 1 or 'columns'} + Axis along which the function is applied: + + * 0 or 'index': apply function to each column. + * 1 or 'columns': apply function to each row. + + Returns + ------- + DataFrame or Series + Result of applying ``func`` along the given axis of the + Series or DataFrame. + + Raises + ------ + ValueError + If the transform function fails or does not transform. + """ + from pandas.core.reshape.concat import concat + + is_series = obj.ndim == 1 + + if obj._get_axis_number(axis) == 1: + assert not is_series + return transform(obj.T, func, 0, *args, **kwargs).T + + if isinstance(func, list): + if is_series: + func = {com.get_callable_name(v) or v: v for v in func} + else: + func = {col: func for col in obj} + + if isinstance(func, dict): + if not is_series: + cols = sorted(set(func.keys()) - set(obj.columns)) + if len(cols) > 0: + raise SpecificationError(f"Column(s) {cols} do not exist") + + if any(isinstance(v, dict) for v in func.values()): + # GH 15931 - deprecation of renaming keys + raise SpecificationError("nested renamer is not supported") + + results = {} + for name, how in func.items(): + colg = obj._gotitem(name, ndim=1) + try: + results[name] = transform(colg, how, 0, *args, **kwargs) + except Exception as e: + if str(e) == "Function did not transform": + raise e + + # combine results + if len(results) == 0: + raise ValueError("Transform function failed") + return concat(results, axis=1) + + # func is either str or callable + try: + if isinstance(func, str): + result = obj._try_aggregate_string_function(func, *args, **kwargs) + else: + f = obj._get_cython_func(func) + if f and not args and not kwargs: + result = getattr(obj, f)() + else: + try: + result = obj.apply(func, args=args, **kwargs) + except Exception: + result = func(obj, *args, **kwargs) + except Exception: + raise ValueError("Transform function failed") + + # Functions that transform may return empty Series/DataFrame + # when the dtype is not appropriate + if isinstance(result, (ABCSeries, ABCDataFrame)) and result.empty: + raise ValueError("Transform function failed") + if not isinstance(result, (ABCSeries, ABCDataFrame)) or not result.index.equals( + obj.index + ): + raise ValueError("Function did not transform") + + return result diff --git a/pandas/core/base.py b/pandas/core/base.py index 1926803d8f04b..a688302b99724 100644 --- a/pandas/core/base.py +++ b/pandas/core/base.py @@ -4,7 +4,7 @@ import builtins import textwrap -from typing import Any, Dict, FrozenSet, List, Optional, Union +from typing import Any, Callable, Dict, FrozenSet, List, Optional, Union import numpy as np @@ -560,7 +560,7 @@ def _aggregate_multiple_funcs(self, arg, _axis): ) from err return result - def _get_cython_func(self, arg: str) -> Optional[str]: + def _get_cython_func(self, arg: Callable) -> Optional[str]: """ if we define an internal function for this argument, return it """ diff --git a/pandas/core/frame.py b/pandas/core/frame.py index b03593ad8afe1..1e5360f39a75e 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -45,6 +45,7 @@ from pandas._libs import algos as libalgos, lib, properties from pandas._libs.lib import no_default from pandas._typing import ( + AggFuncType, ArrayLike, Axes, Axis, @@ -116,7 +117,7 @@ from pandas.core import algorithms, common as com, nanops, ops from pandas.core.accessor import CachedAccessor -from pandas.core.aggregation import reconstruct_func, relabel_result +from pandas.core.aggregation import reconstruct_func, relabel_result, transform from pandas.core.arrays import Categorical, ExtensionArray from pandas.core.arrays.datetimelike import DatetimeLikeArrayMixin as DatetimeLikeArray from pandas.core.arrays.sparse import SparseFrameAccessor @@ -7461,15 +7462,16 @@ def _aggregate(self, arg, axis=0, *args, **kwargs): agg = aggregate @doc( - NDFrame.transform, + _shared_docs["transform"], klass=_shared_doc_kwargs["klass"], axis=_shared_doc_kwargs["axis"], ) - def transform(self, func, axis=0, *args, **kwargs) -> DataFrame: - axis = self._get_axis_number(axis) - if axis == 1: - return self.T.transform(func, *args, **kwargs).T - return super().transform(func, *args, **kwargs) + def transform( + self, func: AggFuncType, axis: Axis = 0, *args, **kwargs + ) -> DataFrame: + result = transform(self, func, axis, *args, **kwargs) + assert isinstance(result, DataFrame) + return result def apply(self, func, axis=0, raw=False, result_type=None, args=(), **kwds): """ diff --git a/pandas/core/generic.py b/pandas/core/generic.py index fffd2e068ebcf..9ed9db801d0a8 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -10648,80 +10648,6 @@ def ewm( times=times, ) - @doc(klass=_shared_doc_kwargs["klass"], axis="") - def transform(self, func, *args, **kwargs): - """ - Call ``func`` on self producing a {klass} with transformed values. - - Produced {klass} will have same axis length as self. - - Parameters - ---------- - func : function, str, list or dict - Function to use for transforming the data. If a function, must either - work when passed a {klass} or when passed to {klass}.apply. - - Accepted combinations are: - - - function - - string function name - - list of functions and/or function names, e.g. ``[np.exp, 'sqrt']`` - - dict of axis labels -> functions, function names or list of such. - {axis} - *args - Positional arguments to pass to `func`. - **kwargs - Keyword arguments to pass to `func`. - - Returns - ------- - {klass} - A {klass} that must have the same length as self. - - Raises - ------ - ValueError : If the returned {klass} has a different length than self. - - See Also - -------- - {klass}.agg : Only perform aggregating type operations. - {klass}.apply : Invoke function on a {klass}. - - Examples - -------- - >>> df = pd.DataFrame({{'A': range(3), 'B': range(1, 4)}}) - >>> df - A B - 0 0 1 - 1 1 2 - 2 2 3 - >>> df.transform(lambda x: x + 1) - A B - 0 1 2 - 1 2 3 - 2 3 4 - - Even though the resulting {klass} must have the same length as the - input {klass}, it is possible to provide several input functions: - - >>> s = pd.Series(range(3)) - >>> s - 0 0 - 1 1 - 2 2 - dtype: int64 - >>> s.transform([np.sqrt, np.exp]) - sqrt exp - 0 0.000000 1.000000 - 1 1.000000 2.718282 - 2 1.414214 7.389056 - """ - result = self.agg(func, *args, **kwargs) - if is_scalar(result) or len(result) != len(self): - raise ValueError("transforms cannot produce aggregated results") - - return result - # ---------------------------------------------------------------------- # Misc methods diff --git a/pandas/core/series.py b/pandas/core/series.py index 6cbd93135a2ca..632b93cdcf24b 100644 --- a/pandas/core/series.py +++ b/pandas/core/series.py @@ -25,6 +25,7 @@ from pandas._libs import lib, properties, reshape, tslibs from pandas._libs.lib import no_default from pandas._typing import ( + AggFuncType, ArrayLike, Axis, DtypeObj, @@ -89,6 +90,7 @@ from pandas.core.indexes.timedeltas import TimedeltaIndex from pandas.core.indexing import check_bool_indexer from pandas.core.internals import SingleBlockManager +from pandas.core.shared_docs import _shared_docs from pandas.core.sorting import ensure_key_mapped from pandas.core.strings import StringMethods from pandas.core.tools.datetimes import to_datetime @@ -4081,14 +4083,16 @@ def aggregate(self, func=None, axis=0, *args, **kwargs): agg = aggregate @doc( - NDFrame.transform, + _shared_docs["transform"], klass=_shared_doc_kwargs["klass"], axis=_shared_doc_kwargs["axis"], ) - def transform(self, func, axis=0, *args, **kwargs): - # Validate the axis parameter - self._get_axis_number(axis) - return super().transform(func, *args, **kwargs) + def transform( + self, func: AggFuncType, axis: Axis = 0, *args, **kwargs + ) -> FrameOrSeriesUnion: + from pandas.core.aggregation import transform + + return transform(self, func, axis, *args, **kwargs) def apply(self, func, convert_dtype=True, args=(), **kwds): """ diff --git a/pandas/core/shared_docs.py b/pandas/core/shared_docs.py index 0aaccb47efc44..244ee3aa298db 100644 --- a/pandas/core/shared_docs.py +++ b/pandas/core/shared_docs.py @@ -257,3 +257,72 @@ 1 b B E 3 2 c B E 5 """ + +_shared_docs[ + "transform" +] = """\ +Call ``func`` on self producing a {klass} with transformed values. + +Produced {klass} will have same axis length as self. + +Parameters +---------- +func : function, str, list or dict + Function to use for transforming the data. If a function, must either + work when passed a {klass} or when passed to {klass}.apply. + + Accepted combinations are: + + - function + - string function name + - list of functions and/or function names, e.g. ``[np.exp, 'sqrt']`` + - dict of axis labels -> functions, function names or list of such. +{axis} +*args + Positional arguments to pass to `func`. +**kwargs + Keyword arguments to pass to `func`. + +Returns +------- +{klass} + A {klass} that must have the same length as self. + +Raises +------ +ValueError : If the returned {klass} has a different length than self. + +See Also +-------- +{klass}.agg : Only perform aggregating type operations. +{klass}.apply : Invoke function on a {klass}. + +Examples +-------- +>>> df = pd.DataFrame({{'A': range(3), 'B': range(1, 4)}}) +>>> df + A B +0 0 1 +1 1 2 +2 2 3 +>>> df.transform(lambda x: x + 1) + A B +0 1 2 +1 2 3 +2 3 4 + +Even though the resulting {klass} must have the same length as the +input {klass}, it is possible to provide several input functions: + +>>> s = pd.Series(range(3)) +>>> s +0 0 +1 1 +2 2 +dtype: int64 +>>> s.transform([np.sqrt, np.exp]) + sqrt exp +0 0.000000 1.000000 +1 1.000000 2.718282 +2 1.414214 7.389056 +""" diff --git a/pandas/tests/frame/apply/test_frame_transform.py b/pandas/tests/frame/apply/test_frame_transform.py index 3a345215482ed..346e60954fc13 100644 --- a/pandas/tests/frame/apply/test_frame_transform.py +++ b/pandas/tests/frame/apply/test_frame_transform.py @@ -1,72 +1,203 @@ import operator +import re import numpy as np import pytest -import pandas as pd +from pandas import DataFrame, MultiIndex import pandas._testing as tm +from pandas.core.base import SpecificationError +from pandas.core.groupby.base import transformation_kernels from pandas.tests.frame.common import zip_frames -def test_agg_transform(axis, float_frame): - other_axis = 1 if axis in {0, "index"} else 0 +def test_transform_ufunc(axis, float_frame): + # GH 35964 + with np.errstate(all="ignore"): + f_sqrt = np.sqrt(float_frame) + result = float_frame.transform(np.sqrt, axis=axis) + expected = f_sqrt + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("op", transformation_kernels) +def test_transform_groupby_kernel(axis, float_frame, op): + # GH 35964 + if op == "cumcount": + pytest.xfail("DataFrame.cumcount does not exist") + if op == "tshift": + pytest.xfail("Only works on time index and is deprecated") + if axis == 1 or axis == "columns": + pytest.xfail("GH 36308: groupby.transform with axis=1 is broken") + + args = [0.0] if op == "fillna" else [] + if axis == 0 or axis == "index": + ones = np.ones(float_frame.shape[0]) + else: + ones = np.ones(float_frame.shape[1]) + expected = float_frame.groupby(ones, axis=axis).transform(op, *args) + result = float_frame.transform(op, axis, *args) + tm.assert_frame_equal(result, expected) + +@pytest.mark.parametrize( + "ops, names", [([np.sqrt], ["sqrt"]), ([np.abs, np.sqrt], ["absolute", "sqrt"])] +) +def test_transform_list(axis, float_frame, ops, names): + # GH 35964 + other_axis = 1 if axis in {0, "index"} else 0 with np.errstate(all="ignore"): + expected = zip_frames([op(float_frame) for op in ops], axis=other_axis) + if axis in {0, "index"}: + expected.columns = MultiIndex.from_product([float_frame.columns, names]) + else: + expected.index = MultiIndex.from_product([float_frame.index, names]) + result = float_frame.transform(ops, axis=axis) + tm.assert_frame_equal(result, expected) - f_abs = np.abs(float_frame) - f_sqrt = np.sqrt(float_frame) - # ufunc - result = float_frame.transform(np.sqrt, axis=axis) - expected = f_sqrt.copy() - tm.assert_frame_equal(result, expected) - - result = float_frame.transform(np.sqrt, axis=axis) - tm.assert_frame_equal(result, expected) - - # list-like - expected = f_sqrt.copy() - if axis in {0, "index"}: - expected.columns = pd.MultiIndex.from_product( - [float_frame.columns, ["sqrt"]] - ) - else: - expected.index = pd.MultiIndex.from_product([float_frame.index, ["sqrt"]]) - result = float_frame.transform([np.sqrt], axis=axis) - tm.assert_frame_equal(result, expected) - - # multiple items in list - # these are in the order as if we are applying both - # functions per series and then concatting - expected = zip_frames([f_abs, f_sqrt], axis=other_axis) - if axis in {0, "index"}: - expected.columns = pd.MultiIndex.from_product( - [float_frame.columns, ["absolute", "sqrt"]] - ) - else: - expected.index = pd.MultiIndex.from_product( - [float_frame.index, ["absolute", "sqrt"]] - ) - result = float_frame.transform([np.abs, "sqrt"], axis=axis) - tm.assert_frame_equal(result, expected) +def test_transform_dict(axis, float_frame): + # GH 35964 + if axis == 0 or axis == "index": + e = float_frame.columns[0] + expected = float_frame[[e]].transform(np.abs) + else: + e = float_frame.index[0] + expected = float_frame.iloc[[0]].transform(np.abs) + result = float_frame.transform({e: np.abs}, axis=axis) + tm.assert_frame_equal(result, expected) -def test_transform_and_agg_err(axis, float_frame): - # cannot both transform and agg - msg = "transforms cannot produce aggregated results" - with pytest.raises(ValueError, match=msg): - float_frame.transform(["max", "min"], axis=axis) +@pytest.mark.parametrize("use_apply", [True, False]) +def test_transform_udf(axis, float_frame, use_apply): + # GH 35964 + # transform uses UDF either via apply or passing the entire DataFrame + def func(x): + # transform is using apply iff x is not a DataFrame + if use_apply == isinstance(x, DataFrame): + # Force transform to fallback + raise ValueError + return x + 1 - msg = "cannot combine transform and aggregation operations" - with pytest.raises(ValueError, match=msg): - with np.errstate(all="ignore"): - float_frame.transform(["max", "sqrt"], axis=axis) + result = float_frame.transform(func, axis=axis) + expected = float_frame + 1 + tm.assert_frame_equal(result, expected) @pytest.mark.parametrize("method", ["abs", "shift", "pct_change", "cumsum", "rank"]) def test_transform_method_name(method): # GH 19760 - df = pd.DataFrame({"A": [-1, 2]}) + df = DataFrame({"A": [-1, 2]}) result = df.transform(method) expected = operator.methodcaller(method)(df) tm.assert_frame_equal(result, expected) + + +def test_transform_and_agg_err(axis, float_frame): + # GH 35964 + # cannot both transform and agg + msg = "Function did not transform" + with pytest.raises(ValueError, match=msg): + float_frame.transform(["max", "min"], axis=axis) + + msg = "Function did not transform" + with pytest.raises(ValueError, match=msg): + float_frame.transform(["max", "sqrt"], axis=axis) + + +def test_agg_dict_nested_renaming_depr(): + df = DataFrame({"A": range(5), "B": 5}) + + # nested renaming + msg = r"nested renamer is not supported" + with pytest.raises(SpecificationError, match=msg): + # mypy identifies the argument as an invalid type + df.transform({"A": {"foo": "min"}, "B": {"bar": "max"}}) + + +def test_transform_reducer_raises(all_reductions): + # GH 35964 + op = all_reductions + df = DataFrame({"A": [1, 2, 3]}) + msg = "Function did not transform" + with pytest.raises(ValueError, match=msg): + df.transform(op) + with pytest.raises(ValueError, match=msg): + df.transform([op]) + with pytest.raises(ValueError, match=msg): + df.transform({"A": op}) + with pytest.raises(ValueError, match=msg): + df.transform({"A": [op]}) + + +# mypy doesn't allow adding lists of different types +# https://github.com/python/mypy/issues/5492 +@pytest.mark.parametrize("op", [*transformation_kernels, lambda x: x + 1]) +def test_transform_bad_dtype(op): + # GH 35964 + df = DataFrame({"A": 3 * [object]}) # DataFrame that will fail on most transforms + if op in ("backfill", "shift", "pad", "bfill", "ffill"): + pytest.xfail("Transform function works on any datatype") + msg = "Transform function failed" + with pytest.raises(ValueError, match=msg): + df.transform(op) + with pytest.raises(ValueError, match=msg): + df.transform([op]) + with pytest.raises(ValueError, match=msg): + df.transform({"A": op}) + with pytest.raises(ValueError, match=msg): + df.transform({"A": [op]}) + + +@pytest.mark.parametrize("op", transformation_kernels) +def test_transform_partial_failure(op): + # GH 35964 + wont_fail = ["ffill", "bfill", "fillna", "pad", "backfill", "shift"] + if op in wont_fail: + pytest.xfail("Transform kernel is successful on all dtypes") + if op == "cumcount": + pytest.xfail("transform('cumcount') not implemented") + if op == "tshift": + pytest.xfail("Only works on time index; deprecated") + + # Using object makes most transform kernels fail + df = DataFrame({"A": 3 * [object], "B": [1, 2, 3]}) + + expected = df[["B"]].transform([op]) + result = df.transform([op]) + tm.assert_equal(result, expected) + + expected = df[["B"]].transform({"B": op}) + result = df.transform({"B": op}) + tm.assert_equal(result, expected) + + expected = df[["B"]].transform({"B": [op]}) + result = df.transform({"B": [op]}) + tm.assert_equal(result, expected) + + +@pytest.mark.parametrize("use_apply", [True, False]) +def test_transform_passes_args(use_apply): + # GH 35964 + # transform uses UDF either via apply or passing the entire DataFrame + expected_args = [1, 2] + expected_kwargs = {"c": 3} + + def f(x, a, b, c): + # transform is using apply iff x is not a DataFrame + if use_apply == isinstance(x, DataFrame): + # Force transform to fallback + raise ValueError + assert [a, b] == expected_args + assert c == expected_kwargs["c"] + return x + + DataFrame([1]).transform(f, 0, *expected_args, **expected_kwargs) + + +def test_transform_missing_columns(axis): + # GH 35964 + df = DataFrame({"A": [1, 2], "B": [3, 4]}) + match = re.escape("Column(s) ['C'] do not exist") + with pytest.raises(SpecificationError, match=match): + df.transform({"C": "cumsum"}) diff --git a/pandas/tests/series/apply/test_series_apply.py b/pandas/tests/series/apply/test_series_apply.py index b948317f32062..827f466e23106 100644 --- a/pandas/tests/series/apply/test_series_apply.py +++ b/pandas/tests/series/apply/test_series_apply.py @@ -209,8 +209,8 @@ def test_transform(self, string_series): f_abs = np.abs(string_series) # ufunc - expected = f_sqrt.copy() result = string_series.apply(np.sqrt) + expected = f_sqrt.copy() tm.assert_series_equal(result, expected) # list-like @@ -219,6 +219,9 @@ def test_transform(self, string_series): expected.columns = ["sqrt"] tm.assert_frame_equal(result, expected) + result = string_series.apply(["sqrt"]) + tm.assert_frame_equal(result, expected) + # multiple items in list # these are in the order as if we are applying both functions per # series and then concatting diff --git a/pandas/tests/series/apply/test_series_transform.py b/pandas/tests/series/apply/test_series_transform.py index 8bc3d2dc4d0db..0842674da2a7d 100644 --- a/pandas/tests/series/apply/test_series_transform.py +++ b/pandas/tests/series/apply/test_series_transform.py @@ -1,50 +1,90 @@ import numpy as np import pytest -import pandas as pd +from pandas import DataFrame, Series, concat import pandas._testing as tm +from pandas.core.base import SpecificationError +from pandas.core.groupby.base import transformation_kernels -def test_transform(string_series): - # transforming functions - +def test_transform_ufunc(string_series): + # GH 35964 with np.errstate(all="ignore"): f_sqrt = np.sqrt(string_series) - f_abs = np.abs(string_series) - # ufunc - result = string_series.transform(np.sqrt) - expected = f_sqrt.copy() - tm.assert_series_equal(result, expected) + # ufunc + result = string_series.transform(np.sqrt) + expected = f_sqrt.copy() + tm.assert_series_equal(result, expected) - # list-like - result = string_series.transform([np.sqrt]) - expected = f_sqrt.to_frame().copy() - expected.columns = ["sqrt"] - tm.assert_frame_equal(result, expected) - result = string_series.transform([np.sqrt]) - tm.assert_frame_equal(result, expected) +@pytest.mark.parametrize("op", transformation_kernels) +def test_transform_groupby_kernel(string_series, op): + # GH 35964 + if op == "cumcount": + pytest.xfail("Series.cumcount does not exist") + if op == "tshift": + pytest.xfail("Only works on time index and is deprecated") + + args = [0.0] if op == "fillna" else [] + ones = np.ones(string_series.shape[0]) + expected = string_series.groupby(ones).transform(op, *args) + result = string_series.transform(op, 0, *args) + tm.assert_series_equal(result, expected) - result = string_series.transform(["sqrt"]) - tm.assert_frame_equal(result, expected) - # multiple items in list - # these are in the order as if we are applying both functions per - # series and then concatting - expected = pd.concat([f_sqrt, f_abs], axis=1) - result = string_series.transform(["sqrt", "abs"]) - expected.columns = ["sqrt", "abs"] +@pytest.mark.parametrize( + "ops, names", [([np.sqrt], ["sqrt"]), ([np.abs, np.sqrt], ["absolute", "sqrt"])] +) +def test_transform_list(string_series, ops, names): + # GH 35964 + with np.errstate(all="ignore"): + expected = concat([op(string_series) for op in ops], axis=1) + expected.columns = names + result = string_series.transform(ops) tm.assert_frame_equal(result, expected) -def test_transform_and_agg_error(string_series): +def test_transform_dict(string_series): + # GH 35964 + with np.errstate(all="ignore"): + expected = concat([np.sqrt(string_series), np.abs(string_series)], axis=1) + expected.columns = ["foo", "bar"] + result = string_series.transform({"foo": np.sqrt, "bar": np.abs}) + tm.assert_frame_equal(result, expected) + + +def test_transform_udf(axis, string_series): + # GH 35964 + # via apply + def func(x): + if isinstance(x, Series): + raise ValueError + return x + 1 + + result = string_series.transform(func) + expected = string_series + 1 + tm.assert_series_equal(result, expected) + + # via map Series -> Series + def func(x): + if not isinstance(x, Series): + raise ValueError + return x + 1 + + result = string_series.transform(func) + expected = string_series + 1 + tm.assert_series_equal(result, expected) + + +def test_transform_wont_agg(string_series): + # GH 35964 # we are trying to transform with an aggregator - msg = "transforms cannot produce aggregated results" + msg = "Function did not transform" with pytest.raises(ValueError, match=msg): string_series.transform(["min", "max"]) - msg = "cannot combine transform and aggregation operations" + msg = "Function did not transform" with pytest.raises(ValueError, match=msg): with np.errstate(all="ignore"): string_series.transform(["sqrt", "max"]) @@ -52,8 +92,74 @@ def test_transform_and_agg_error(string_series): def test_transform_none_to_type(): # GH34377 - df = pd.DataFrame({"a": [None]}) - - msg = "DataFrame constructor called with incompatible data and dtype" - with pytest.raises(TypeError, match=msg): + df = DataFrame({"a": [None]}) + msg = "Transform function failed" + with pytest.raises(ValueError, match=msg): df.transform({"a": int}) + + +def test_transform_reducer_raises(all_reductions): + # GH 35964 + op = all_reductions + s = Series([1, 2, 3]) + msg = "Function did not transform" + with pytest.raises(ValueError, match=msg): + s.transform(op) + with pytest.raises(ValueError, match=msg): + s.transform([op]) + with pytest.raises(ValueError, match=msg): + s.transform({"A": op}) + with pytest.raises(ValueError, match=msg): + s.transform({"A": [op]}) + + +# mypy doesn't allow adding lists of different types +# https://github.com/python/mypy/issues/5492 +@pytest.mark.parametrize("op", [*transformation_kernels, lambda x: x + 1]) +def test_transform_bad_dtype(op): + # GH 35964 + s = Series(3 * [object]) # Series that will fail on most transforms + if op in ("backfill", "shift", "pad", "bfill", "ffill"): + pytest.xfail("Transform function works on any datatype") + msg = "Transform function failed" + with pytest.raises(ValueError, match=msg): + s.transform(op) + with pytest.raises(ValueError, match=msg): + s.transform([op]) + with pytest.raises(ValueError, match=msg): + s.transform({"A": op}) + with pytest.raises(ValueError, match=msg): + s.transform({"A": [op]}) + + +@pytest.mark.parametrize("use_apply", [True, False]) +def test_transform_passes_args(use_apply): + # GH 35964 + # transform uses UDF either via apply or passing the entire Series + expected_args = [1, 2] + expected_kwargs = {"c": 3} + + def f(x, a, b, c): + # transform is using apply iff x is not a Series + if use_apply == isinstance(x, Series): + # Force transform to fallback + raise ValueError + assert [a, b] == expected_args + assert c == expected_kwargs["c"] + return x + + Series([1]).transform(f, 0, *expected_args, **expected_kwargs) + + +def test_transform_axis_1_raises(): + # GH 35964 + msg = "No axis named 1 for object type Series" + with pytest.raises(ValueError, match=msg): + Series([1]).transform("sum", axis=1) + + +def test_transform_nested_renamer(): + # GH 35964 + match = "nested renamer is not supported" + with pytest.raises(SpecificationError, match=match): + Series([1]).transform({"A": {"B": ["sum"]}})