Skip to content

Commit

Permalink
refactor(python): Add Series.cut, deprecate pl.cut (#7058)
Browse files Browse the repository at this point in the history
  • Loading branch information
zundertj authored Feb 25, 2023
1 parent 4ac9163 commit d039b1a
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 69 deletions.
1 change: 1 addition & 0 deletions py-polars/docs/source/reference/series/modify_select.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ Manipulation/selection
Series.clip_max
Series.clip_min
Series.clone
Series.cut
Series.drop_nans
Series.drop_nulls
Series.explode
Expand Down
47 changes: 10 additions & 37 deletions py-polars/polars/internals/functions.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from __future__ import annotations

import contextlib
import warnings
from datetime import date, datetime, timedelta
from typing import TYPE_CHECKING, Iterable, Sequence, overload

from polars import internals as pli
from polars.datatypes import Categorical, Date, Float64, PolarsDataType
from polars.datatypes import Date, PolarsDataType
from polars.utils import (
_datetime_to_pl_timestamp,
_timedelta_to_pl_duration,
Expand Down Expand Up @@ -505,6 +506,9 @@ def cut(
"""
Bin values into discrete values.
.. deprecated:: 0.16.8
`pl.cut(series, ...)` has been deprecated; use `series.cut(...)`
Parameters
----------
s
Expand Down Expand Up @@ -550,43 +554,12 @@ def cut(
└──────┴─────────────┴──────────────┘
"""
var_nm = s.name

cuts_df = pli.DataFrame(
[
pli.Series(
name=break_point_label, values=bins, dtype=Float64
).extend_constant(float("inf"), 1)
]
)

if labels:
if len(labels) != len(bins) + 1:
raise ValueError("expected more labels")
cuts_df = cuts_df.with_columns(pli.Series(name=category_label, values=labels))
else:
cuts_df = cuts_df.with_columns(
pli.format(
"({}, {}]",
pli.col(break_point_label).shift_and_fill(1, float("-inf")),
pli.col(break_point_label),
).alias(category_label)
)

cuts_df = cuts_df.with_columns(pli.col(category_label).cast(Categorical))

result = (
s.cast(Float64)
.sort()
.to_frame()
.join_asof(
cuts_df,
left_on=var_nm,
right_on=break_point_label,
strategy="forward",
)
warnings.warn(
"`pl.cut(series)` has been deprecated; use `series.cut()`",
category=DeprecationWarning,
stacklevel=2,
)
return result
return s.cut(bins, labels, break_point_label, category_label)


@overload
Expand Down
93 changes: 93 additions & 0 deletions py-polars/polars/internals/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1416,6 +1416,99 @@ def to_dummies(self, separator: str = "_") -> pli.DataFrame:
"""
return pli.wrap_df(self._s.to_dummies(separator))

def cut(
    self,
    bins: list[float],
    labels: list[str] | None = None,
    break_point_label: str = "break_point",
    category_label: str = "category",
) -> pli.DataFrame:
    """
    Bin values into discrete values.

    Parameters
    ----------
    bins
        Bins to create.
    labels
        Labels to assign to the bins. If given the length of labels must be
        len(bins) + 1. An empty list is treated the same as ``None``: labels
        of the form ``"(lower, upper]"`` are generated from the break points.
    break_point_label
        Name given to the breakpoint column.
    category_label
        Name given to the category column.

    Returns
    -------
    DataFrame
        A frame with three columns: the values of this Series (cast to Float64
        and sorted), the break point each value was matched to, and the
        category label as a Categorical column.

    Raises
    ------
    ValueError
        If ``labels`` is given (and non-empty) but does not contain exactly
        ``len(bins) + 1`` entries.

    Warnings
    --------
    This functionality is experimental and may change without it being considered a
    breaking change.

    Examples
    --------
    >>> a = pl.Series("a", [v / 10 for v in range(-30, 30, 5)])
    >>> a.cut(bins=[-1, 1])
    shape: (12, 3)
    ┌──────┬─────────────┬──────────────┐
    │ a    ┆ break_point ┆ category     │
    │ ---  ┆ ---         ┆ ---          │
    │ f64  ┆ f64         ┆ cat          │
    ╞══════╪═════════════╪══════════════╡
    │ -3.0 ┆ -1.0        ┆ (-inf, -1.0] │
    │ -2.5 ┆ -1.0        ┆ (-inf, -1.0] │
    │ -2.0 ┆ -1.0        ┆ (-inf, -1.0] │
    │ -1.5 ┆ -1.0        ┆ (-inf, -1.0] │
    │ ...  ┆ ...         ┆ ...          │
    │ 1.0  ┆ 1.0         ┆ (-1.0, 1.0]  │
    │ 1.5  ┆ inf         ┆ (1.0, inf]   │
    │ 2.0  ┆ inf         ┆ (1.0, inf]   │
    │ 2.5  ┆ inf         ┆ (1.0, inf]   │
    └──────┴─────────────┴──────────────┘

    """
    # Validate up front so a bad `labels` argument fails before any frames are
    # built, and report both counts: the mismatch can be too few OR too many.
    if labels and len(labels) != len(bins) + 1:
        raise ValueError(
            f"expected {len(bins) + 1} labels (len(bins) + 1), got {len(labels)}"
        )

    var_nm = self.name

    # One row per bin: the user-supplied break points plus a trailing +inf row
    # so that values above the last break point still find a match.
    cuts_df = pli.DataFrame(
        [
            pli.Series(
                name=break_point_label, values=bins, dtype=Float64
            ).extend_constant(float("inf"), 1)
        ]
    )

    if labels:
        cuts_df = cuts_df.with_columns(pli.Series(name=category_label, values=labels))
    else:
        # Build "(previous_break, break]" labels; the first bin starts at -inf.
        cuts_df = cuts_df.with_columns(
            pli.format(
                "({}, {}]",
                pli.col(break_point_label).shift_and_fill(1, float("-inf")),
                pli.col(break_point_label),
            ).alias(category_label)
        )

    cuts_df = cuts_df.with_columns(pli.col(category_label).cast(Categorical))

    # The values are sorted before the as-of join; with strategy="forward" each
    # value is paired with the first break point that is >= the value.
    result = (
        self.cast(Float64)
        .sort()
        .to_frame()
        .join_asof(
            cuts_df,
            left_on=var_nm,
            right_on=break_point_label,
            strategy="forward",
        )
    )
    return result

def value_counts(self, sort: bool = False) -> pli.DataFrame:
"""
Count the unique values in a Series.
Expand Down
36 changes: 4 additions & 32 deletions py-polars/tests/unit/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,38 +93,10 @@ def test_all_any_horizontally() -> None:
assert_frame_equal(result, expected)


def test_cut() -> None:
    """Check ``pl.cut`` on a float Series and on an integer Series (#4939)."""
    float_series = pl.Series("a", [v / 10 for v in range(-30, 30, 5)])
    binned = pl.cut(float_series, bins=[-1, 1])
    assert binned.shape == (12, 3)

    # Only compare rows with a finite break point; the top bin's is +inf.
    finite_rows = binned.filter(pl.col("break_point") < 1e9).to_dict(False)
    expected = {
        "a": [-3.0, -2.5, -2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0],
        "break_point": [-1.0] * 5 + [1.0] * 4,
        "category": ["(-inf, -1.0]"] * 5 + ["(-1.0, 1.0]"] * 4,
    }
    assert finite_rows == expected

    # test cut on integers #4939
    inf = float("inf")
    int_series = pl.DataFrame({"a": list(range(5))}).select("a").to_series()
    int_rows = pl.cut(int_series, bins=[-1, 1]).rows()
    assert int_rows == [
        (0.0, 1.0, "(-1.0, 1.0]"),
        (1.0, 1.0, "(-1.0, 1.0]"),
        (2.0, inf, "(1.0, inf]"),
        (3.0, inf, "(1.0, inf]"),
        (4.0, inf, "(1.0, inf]"),
    ]
def test_cut_deprecated() -> None:
    """Check that calling the module-level ``pl.cut`` emits a deprecation warning."""
    # Build the input outside the context manager so that only a warning
    # raised by `pl.cut` itself can satisfy the check.
    a = pl.Series("a", [v / 10 for v in range(-30, 30, 5)])
    with pytest.deprecated_call():
        pl.cut(a, bins=[-1, 1])


def test_null_handling_correlation() -> None:
Expand Down
34 changes: 34 additions & 0 deletions py-polars/tests/unit/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2534,3 +2534,37 @@ def test_from_epoch_seq_input() -> None:
expected = pl.Series([datetime(2006, 5, 17, 15, 34, 4)])
result = pl.from_epoch(seq_input)
assert_series_equal(result, expected)


def test_cut() -> None:
    """Check ``Series.cut`` on a float Series and on an integer Series (#4939)."""
    values = pl.Series("a", [v / 10 for v in range(-30, 30, 5)])
    binned = values.cut(bins=[-1, 1])
    assert binned.shape == (12, 3)

    # Drop the open-ended top bin (its break point is +inf) before comparing.
    finite = binned.filter(pl.col("break_point") < 1e9)
    lowest_bin, middle_bin = "(-inf, -1.0]", "(-1.0, 1.0]"
    assert finite.to_dict(False) == {
        "a": [-3.0, -2.5, -2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0],
        "break_point": [-1.0] * 5 + [1.0] * 4,
        "category": [lowest_bin] * 5 + [middle_bin] * 4,
    }

    # test cut on integers #4939
    inf = float("inf")
    frame = pl.DataFrame({"a": list(range(5))})
    int_series = frame.select("a").to_series()
    expected_rows = [
        (0.0, 1.0, "(-1.0, 1.0]"),
        (1.0, 1.0, "(-1.0, 1.0]"),
        (2.0, inf, "(1.0, inf]"),
        (3.0, inf, "(1.0, inf]"),
        (4.0, inf, "(1.0, inf]"),
    ]
    assert int_series.cut(bins=[-1, 1]).rows() == expected_rows

0 comments on commit d039b1a

Please sign in to comment.