Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(python): Add Series.cut, deprecate pl.cut #7058

Merged
merged 1 commit into from
Feb 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions py-polars/docs/source/reference/series/modify_select.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ Manipulation/selection
Series.clip_max
Series.clip_min
Series.clone
Series.cut
Series.drop_nans
Series.drop_nulls
Series.explode
Expand Down
47 changes: 10 additions & 37 deletions py-polars/polars/internals/functions.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from __future__ import annotations

import contextlib
import warnings
from datetime import date, datetime, timedelta
from typing import TYPE_CHECKING, Iterable, Sequence, overload

from polars import internals as pli
from polars.datatypes import Categorical, Date, Float64, PolarsDataType
from polars.datatypes import Date, PolarsDataType
from polars.utils import (
_datetime_to_pl_timestamp,
_timedelta_to_pl_duration,
Expand Down Expand Up @@ -505,6 +506,9 @@ def cut(
"""
Bin values into discrete values.

.. deprecated:: 0.16.8
`pl.cut(series, ...)` has been deprecated; use `series.cut(...)`

Parameters
----------
s
Expand Down Expand Up @@ -550,43 +554,12 @@ def cut(
└──────┴─────────────┴──────────────┘

"""
var_nm = s.name

cuts_df = pli.DataFrame(
[
pli.Series(
name=break_point_label, values=bins, dtype=Float64
).extend_constant(float("inf"), 1)
]
)

if labels:
if len(labels) != len(bins) + 1:
raise ValueError("expected more labels")
cuts_df = cuts_df.with_columns(pli.Series(name=category_label, values=labels))
else:
cuts_df = cuts_df.with_columns(
pli.format(
"({}, {}]",
pli.col(break_point_label).shift_and_fill(1, float("-inf")),
pli.col(break_point_label),
).alias(category_label)
)

cuts_df = cuts_df.with_columns(pli.col(category_label).cast(Categorical))

result = (
s.cast(Float64)
.sort()
.to_frame()
.join_asof(
cuts_df,
left_on=var_nm,
right_on=break_point_label,
strategy="forward",
)
warnings.warn(
"`pl.cut(series)` has been deprecated; use `series.cut()`",
category=DeprecationWarning,
stacklevel=2,
)
return result
return s.cut(bins, labels, break_point_label, category_label)


@overload
Expand Down
93 changes: 93 additions & 0 deletions py-polars/polars/internals/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1415,6 +1415,99 @@ def to_dummies(self, separator: str = "_") -> pli.DataFrame:
"""
return pli.wrap_df(self._s.to_dummies(separator))

def cut(
self,
bins: list[float],
labels: list[str] | None = None,
break_point_label: str = "break_point",
category_label: str = "category",
) -> pli.DataFrame:
"""
Bin values into discrete values.
Parameters
----------
bins
Bins to create.
labels
Labels to assign to the bins. If given the length of labels must be
len(bins) + 1.
break_point_label
Name given to the breakpoint column.
category_label
Name given to the category column.
Returns
-------
DataFrame
Warnings
--------
This functionality is experimental and may change without it being considered a
breaking change.
Examples
--------
>>> a = pl.Series("a", [v / 10 for v in range(-30, 30, 5)])
>>> a.cut(bins=[-1, 1])
shape: (12, 3)
┌──────┬─────────────┬──────────────┐
│ a ┆ break_point ┆ category │
│ --- ┆ --- ┆ --- │
│ f64 ┆ f64 ┆ cat │
╞══════╪═════════════╪══════════════╡
│ -3.0 ┆ -1.0 ┆ (-inf, -1.0] │
│ -2.5 ┆ -1.0 ┆ (-inf, -1.0] │
│ -2.0 ┆ -1.0 ┆ (-inf, -1.0] │
│ -1.5 ┆ -1.0 ┆ (-inf, -1.0] │
│ ... ┆ ... ┆ ... │
│ 1.0 ┆ 1.0 ┆ (-1.0, 1.0] │
│ 1.5 ┆ inf ┆ (1.0, inf] │
│ 2.0 ┆ inf ┆ (1.0, inf] │
│ 2.5 ┆ inf ┆ (1.0, inf] │
└──────┴─────────────┴──────────────┘
"""
var_nm = self.name

cuts_df = pli.DataFrame(
[
pli.Series(
name=break_point_label, values=bins, dtype=Float64
).extend_constant(float("inf"), 1)
]
)

if labels:
if len(labels) != len(bins) + 1:
raise ValueError("expected more labels")
cuts_df = cuts_df.with_columns(
pli.Series(name=category_label, values=labels)
)
else:
cuts_df = cuts_df.with_columns(
pli.format(
"({}, {}]",
pli.col(break_point_label).shift_and_fill(1, float("-inf")),
pli.col(break_point_label),
).alias(category_label)
)

cuts_df = cuts_df.with_columns(pli.col(category_label).cast(Categorical))

result = (
self.cast(Float64)
.sort()
.to_frame()
.join_asof(
cuts_df,
left_on=var_nm,
right_on=break_point_label,
strategy="forward",
)
)
return result

def value_counts(self, sort: bool = False) -> pli.DataFrame:
"""
Count the unique values in a Series.
Expand Down
36 changes: 4 additions & 32 deletions py-polars/tests/unit/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,38 +93,10 @@ def test_all_any_horizontally() -> None:
assert_frame_equal(result, expected)


def test_cut() -> None:
a = pl.Series("a", [v / 10 for v in range(-30, 30, 5)])
out = pl.cut(a, bins=[-1, 1])

assert out.shape == (12, 3)
assert out.filter(pl.col("break_point") < 1e9).to_dict(False) == {
"a": [-3.0, -2.5, -2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0],
"break_point": [-1.0, -1.0, -1.0, -1.0, -1.0, 1.0, 1.0, 1.0, 1.0],
"category": [
"(-inf, -1.0]",
"(-inf, -1.0]",
"(-inf, -1.0]",
"(-inf, -1.0]",
"(-inf, -1.0]",
"(-1.0, 1.0]",
"(-1.0, 1.0]",
"(-1.0, 1.0]",
"(-1.0, 1.0]",
],
}

# test cut on integers #4939
inf = float("inf")
df = pl.DataFrame({"a": list(range(5))})
ser = df.select("a").to_series()
assert pl.cut(ser, bins=[-1, 1]).rows() == [
(0.0, 1.0, "(-1.0, 1.0]"),
(1.0, 1.0, "(-1.0, 1.0]"),
(2.0, inf, "(1.0, inf]"),
(3.0, inf, "(1.0, inf]"),
(4.0, inf, "(1.0, inf]"),
]
def test_cut_deprecated() -> None:
with pytest.deprecated_call():
a = pl.Series("a", [v / 10 for v in range(-30, 30, 5)])
pl.cut(a, bins=[-1, 1])


def test_null_handling_correlation() -> None:
Expand Down
34 changes: 34 additions & 0 deletions py-polars/tests/unit/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2523,3 +2523,37 @@ def test_from_epoch_seq_input() -> None:
expected = pl.Series([datetime(2006, 5, 17, 15, 34, 4)])
result = pl.from_epoch(seq_input)
assert_series_equal(result, expected)


def test_cut() -> None:
a = pl.Series("a", [v / 10 for v in range(-30, 30, 5)])
out = a.cut(bins=[-1, 1])

assert out.shape == (12, 3)
assert out.filter(pl.col("break_point") < 1e9).to_dict(False) == {
"a": [-3.0, -2.5, -2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0],
"break_point": [-1.0, -1.0, -1.0, -1.0, -1.0, 1.0, 1.0, 1.0, 1.0],
"category": [
"(-inf, -1.0]",
"(-inf, -1.0]",
"(-inf, -1.0]",
"(-inf, -1.0]",
"(-inf, -1.0]",
"(-1.0, 1.0]",
"(-1.0, 1.0]",
"(-1.0, 1.0]",
"(-1.0, 1.0]",
],
}

# test cut on integers #4939
inf = float("inf")
df = pl.DataFrame({"a": list(range(5))})
ser = df.select("a").to_series()
assert ser.cut(bins=[-1, 1]).rows() == [
(0.0, 1.0, "(-1.0, 1.0]"),
(1.0, 1.0, "(-1.0, 1.0]"),
(2.0, inf, "(1.0, inf]"),
(3.0, inf, "(1.0, inf]"),
(4.0, inf, "(1.0, inf]"),
]