From d039b1ab8605a10f05194bdae33704cfdd358f44 Mon Sep 17 00:00:00 2001 From: J van Zundert Date: Sat, 25 Feb 2023 19:11:44 +0000 Subject: [PATCH] refactor(python): Add Series.cut, deprecate pl.cut (#7058) --- .../source/reference/series/modify_select.rst | 1 + py-polars/polars/internals/functions.py | 47 ++-------- py-polars/polars/internals/series/series.py | 93 +++++++++++++++++++ py-polars/tests/unit/test_functions.py | 36 +------ py-polars/tests/unit/test_series.py | 34 +++++++ 5 files changed, 142 insertions(+), 69 deletions(-) diff --git a/py-polars/docs/source/reference/series/modify_select.rst b/py-polars/docs/source/reference/series/modify_select.rst index 88756c063563..c72a769b6b2c 100644 --- a/py-polars/docs/source/reference/series/modify_select.rst +++ b/py-polars/docs/source/reference/series/modify_select.rst @@ -17,6 +17,7 @@ Manipulation/selection Series.clip_max Series.clip_min Series.clone + Series.cut Series.drop_nans Series.drop_nulls Series.explode diff --git a/py-polars/polars/internals/functions.py b/py-polars/polars/internals/functions.py index d8faa50f428a..ea32416a39ec 100644 --- a/py-polars/polars/internals/functions.py +++ b/py-polars/polars/internals/functions.py @@ -1,11 +1,12 @@ from __future__ import annotations import contextlib +import warnings from datetime import date, datetime, timedelta from typing import TYPE_CHECKING, Iterable, Sequence, overload from polars import internals as pli -from polars.datatypes import Categorical, Date, Float64, PolarsDataType +from polars.datatypes import Date, PolarsDataType from polars.utils import ( _datetime_to_pl_timestamp, _timedelta_to_pl_duration, @@ -505,6 +506,9 @@ def cut( """ Bin values into discrete values. + .. deprecated:: 0.16.8 + `pl.cut(series, ...)` has been deprecated; use `series.cut(...)` + Parameters ---------- s @@ -550,43 +554,12 @@ def cut( └──────┴─────────────┴──────────────┘ """ - var_nm = s.name - - cuts_df = pli.DataFrame( - [ - pli.Series( - name=break_point_label, values=bins, dtype=Float64 - ).extend_constant(float("inf"), 1) - ] - ) - - if labels: - if len(labels) != len(bins) + 1: - raise ValueError("expected more labels") - cuts_df = cuts_df.with_columns(pli.Series(name=category_label, values=labels)) - else: - cuts_df = cuts_df.with_columns( - pli.format( - "({}, {}]", - pli.col(break_point_label).shift_and_fill(1, float("-inf")), - pli.col(break_point_label), - ).alias(category_label) - ) - - cuts_df = cuts_df.with_columns(pli.col(category_label).cast(Categorical)) - - result = ( - s.cast(Float64) - .sort() - .to_frame() - .join_asof( - cuts_df, - left_on=var_nm, - right_on=break_point_label, - strategy="forward", - ) + warnings.warn( + "`pl.cut(series)` has been deprecated; use `series.cut()`", + category=DeprecationWarning, + stacklevel=2, ) - return result + return s.cut(bins, labels, break_point_label, category_label) @overload diff --git a/py-polars/polars/internals/series/series.py b/py-polars/polars/internals/series/series.py index 4f55f5bb66d6..fa08115f2f91 100644 --- a/py-polars/polars/internals/series/series.py +++ b/py-polars/polars/internals/series/series.py @@ -1416,6 +1416,99 @@ def to_dummies(self, separator: str = "_") -> pli.DataFrame: """ return pli.wrap_df(self._s.to_dummies(separator)) + def cut( + self, + bins: list[float], + labels: list[str] | None = None, + break_point_label: str = "break_point", + category_label: str = "category", + ) -> pli.DataFrame: + """ + Bin values into discrete values. + + Parameters + ---------- + bins + Bins to create. + labels + Labels to assign to the bins. If given the length of labels must be + len(bins) + 1. + break_point_label + Name given to the breakpoint column. + category_label + Name given to the category column. + + Returns + ------- + DataFrame + + Warnings + -------- + This functionality is experimental and may change without it being considered a + breaking change. + + Examples + -------- + >>> a = pl.Series("a", [v / 10 for v in range(-30, 30, 5)]) + >>> a.cut(bins=[-1, 1]) + shape: (12, 3) + ┌──────┬─────────────┬──────────────┐ + │ a ┆ break_point ┆ category │ + │ --- ┆ --- ┆ --- │ + │ f64 ┆ f64 ┆ cat │ + ╞══════╪═════════════╪══════════════╡ + │ -3.0 ┆ -1.0 ┆ (-inf, -1.0] │ + │ -2.5 ┆ -1.0 ┆ (-inf, -1.0] │ + │ -2.0 ┆ -1.0 ┆ (-inf, -1.0] │ + │ -1.5 ┆ -1.0 ┆ (-inf, -1.0] │ + │ ... ┆ ... ┆ ... │ + │ 1.0 ┆ 1.0 ┆ (-1.0, 1.0] │ + │ 1.5 ┆ inf ┆ (1.0, inf] │ + │ 2.0 ┆ inf ┆ (1.0, inf] │ + │ 2.5 ┆ inf ┆ (1.0, inf] │ + └──────┴─────────────┴──────────────┘ + + """ + var_nm = self.name + + cuts_df = pli.DataFrame( + [ + pli.Series( + name=break_point_label, values=bins, dtype=Float64 + ).extend_constant(float("inf"), 1) + ] + ) + + if labels: + if len(labels) != len(bins) + 1: + raise ValueError("expected more labels") + cuts_df = cuts_df.with_columns( + pli.Series(name=category_label, values=labels) + ) + else: + cuts_df = cuts_df.with_columns( + pli.format( + "({}, {}]", + pli.col(break_point_label).shift_and_fill(1, float("-inf")), + pli.col(break_point_label), + ).alias(category_label) + ) + + cuts_df = cuts_df.with_columns(pli.col(category_label).cast(Categorical)) + + result = ( + self.cast(Float64) + .sort() + .to_frame() + .join_asof( + cuts_df, + left_on=var_nm, + right_on=break_point_label, + strategy="forward", + ) + ) + return result + def value_counts(self, sort: bool = False) -> pli.DataFrame: """ Count the unique values in a Series. diff --git a/py-polars/tests/unit/test_functions.py b/py-polars/tests/unit/test_functions.py index e745923aa229..6816846e7d85 100644 --- a/py-polars/tests/unit/test_functions.py +++ b/py-polars/tests/unit/test_functions.py @@ -93,38 +93,10 @@ def test_all_any_horizontally() -> None: assert_frame_equal(result, expected) -def test_cut() -> None: - a = pl.Series("a", [v / 10 for v in range(-30, 30, 5)]) - out = pl.cut(a, bins=[-1, 1]) - - assert out.shape == (12, 3) - assert out.filter(pl.col("break_point") < 1e9).to_dict(False) == { - "a": [-3.0, -2.5, -2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0], - "break_point": [-1.0, -1.0, -1.0, -1.0, -1.0, 1.0, 1.0, 1.0, 1.0], - "category": [ - "(-inf, -1.0]", - "(-inf, -1.0]", - "(-inf, -1.0]", - "(-inf, -1.0]", - "(-inf, -1.0]", - "(-1.0, 1.0]", - "(-1.0, 1.0]", - "(-1.0, 1.0]", - "(-1.0, 1.0]", - ], - } - - # test cut on integers #4939 - inf = float("inf") - df = pl.DataFrame({"a": list(range(5))}) - ser = df.select("a").to_series() - assert pl.cut(ser, bins=[-1, 1]).rows() == [ - (0.0, 1.0, "(-1.0, 1.0]"), - (1.0, 1.0, "(-1.0, 1.0]"), - (2.0, inf, "(1.0, inf]"), - (3.0, inf, "(1.0, inf]"), - (4.0, inf, "(1.0, inf]"), - ] +def test_cut_deprecated() -> None: + with pytest.deprecated_call(): + a = pl.Series("a", [v / 10 for v in range(-30, 30, 5)]) + pl.cut(a, bins=[-1, 1]) def test_null_handling_correlation() -> None: diff --git a/py-polars/tests/unit/test_series.py b/py-polars/tests/unit/test_series.py index 5dca4c0db030..23e5e8ec6fee 100644 --- a/py-polars/tests/unit/test_series.py +++ b/py-polars/tests/unit/test_series.py @@ -2534,3 +2534,37 @@ def test_from_epoch_seq_input() -> None: expected = pl.Series([datetime(2006, 5, 17, 15, 34, 4)]) result = pl.from_epoch(seq_input) assert_series_equal(result, expected) + + +def test_cut() -> None: + a = pl.Series("a", [v / 10 for v in range(-30, 30, 5)]) + out = a.cut(bins=[-1, 1]) + + assert out.shape == (12, 3) + assert out.filter(pl.col("break_point") < 1e9).to_dict(False) == { + "a": [-3.0, -2.5, -2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0], + "break_point": [-1.0, -1.0, -1.0, -1.0, -1.0, 1.0, 1.0, 1.0, 1.0], + "category": [ + "(-inf, -1.0]", + "(-inf, -1.0]", + "(-inf, -1.0]", + "(-inf, -1.0]", + "(-inf, -1.0]", + "(-1.0, 1.0]", + "(-1.0, 1.0]", + "(-1.0, 1.0]", + "(-1.0, 1.0]", + ], + } + + # test cut on integers #4939 + inf = float("inf") + df = pl.DataFrame({"a": list(range(5))}) + ser = df.select("a").to_series() + assert ser.cut(bins=[-1, 1]).rows() == [ + (0.0, 1.0, "(-1.0, 1.0]"), + (1.0, 1.0, "(-1.0, 1.0]"), + (2.0, inf, "(1.0, inf]"), + (3.0, inf, "(1.0, inf]"), + (4.0, inf, "(1.0, inf]"), + ]