diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index 03cac5e3df61..0e079f6302c9 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -4742,31 +4742,41 @@ def top_k( *, by: IntoExpr | Iterable[IntoExpr], descending: bool | Sequence[bool] = False, - nulls_last: bool = False, - maintain_order: bool = False, + nulls_last: bool | None = None, + maintain_order: bool | None = None, ) -> DataFrame: """ - Return the `k` largest elements. - - If `descending=True` the smallest elements will be given. + Return the `k` largest rows. Parameters ---------- k Number of rows to return. by - Column(s) included in sort order. Accepts expression input. - Strings are parsed as column names. + Column(s) used to determine the top rows. + Accepts expression input. Strings are parsed as column names. descending - Return the `k` smallest. Top-k by multiple columns can be specified - per column by passing a sequence of booleans. + Consider the `k` smallest elements of the `by` column(s) (instead of the `k` + largest). This can be specified per column by passing a sequence of + booleans. + nulls_last Place null values last. + + .. deprecated:: 0.20.31 + This parameter will be removed in the next breaking release. + Null values will be considered lowest priority and will only be + included if `k` is larger than the number of non-null elements. + maintain_order Whether the order should be maintained if elements are equal. Note that if `true` streaming is not possible and performance might be worse since this requires a stable search. + .. deprecated:: 0.20.31 + This parameter will be removed in the next breaking release. + There will be no guarantees about the order of the output. 
+ See Also -------- bottom_k @@ -4833,31 +4843,41 @@ def bottom_k( *, by: IntoExpr | Iterable[IntoExpr], descending: bool | Sequence[bool] = False, - nulls_last: bool = False, - maintain_order: bool = False, + nulls_last: bool | None = None, + maintain_order: bool | None = None, ) -> DataFrame: """ - Return the `k` smallest elements. - - If `descending=True` the largest elements will be given. + Return the `k` smallest rows. Parameters ---------- k Number of rows to return. by - Column(s) included in sort order. Accepts expression input. - Strings are parsed as column names. + Column(s) used to determine the bottom rows. + Accepts expression input. Strings are parsed as column names. descending - Return the `k` largest. Bottom-k by multiple columns can be specified - per column by passing a sequence of booleans. + Consider the `k` largest elements of the `by` column(s) (instead of the `k` + smallest). This can be specified per column by passing a sequence of + booleans. + nulls_last Place null values last. + + .. deprecated:: 0.20.31 + This parameter will be removed in the next breaking release. + Null values will be considered lowest priority and will only be + included if `k` is larger than the number of non-null elements. + maintain_order Whether the order should be maintained if elements are equal. Note that if `true` streaming is not possible and performance might be worse since this requires a stable search. + .. deprecated:: 0.20.31 + This parameter will be removed in the next breaking release. + There will be no guarantees about the order of the output. 
+ See Also -------- top_k diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index d0058821fd7a..fa5a43266450 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -2038,28 +2038,44 @@ def top_k( self, k: int | IntoExprColumn = 5, *, - nulls_last: bool = False, - maintain_order: bool = False, - multithreaded: bool = True, + nulls_last: bool | None = None, + maintain_order: bool | None = None, + multithreaded: bool | None = None, ) -> Self: r""" Return the `k` largest elements. This has time complexity: - .. math:: O(n + k \log{n} - \frac{k}{2}) + .. math:: O(n + k \log{n}) Parameters ---------- k Number of elements to return. + nulls_last Place null values last. + + .. deprecated:: 0.20.31 + This parameter will be removed in the next breaking release. + Null values will be considered lowest priority and will only be + included if `k` is larger than the number of non-null elements. + maintain_order Whether the order should be maintained if elements are equal. + + .. deprecated:: 0.20.31 + This parameter will be removed in the next breaking release. + There will be no guarantees about the order of the output. + multithreaded Sort using multiple threads. + .. deprecated:: 0.20.31 + This parameter will be removed in the next breaking release. + Polars itself will determine whether to use multithreading or not. + See Also -------- top_k_by @@ -2076,10 +2092,8 @@ def top_k( ... } ... ) >>> df.select( - ... [ - ... pl.col("value").top_k().alias("top_k"), - ... pl.col("value").bottom_k().alias("bottom_k"), - ... ] + ... pl.col("value").top_k().alias("top_k"), + ... pl.col("value").bottom_k().alias("bottom_k"), ... ) shape: (5, 2) ┌───────┬──────────┐ @@ -2094,6 +2108,37 @@ def top_k( │ 2 ┆ 98 │ └───────┴──────────┘ """ + if nulls_last is not None: + issue_deprecation_warning( + "The `nulls_last` parameter for `top_k` is deprecated." + " It will be removed in the next breaking release." 
+ " Null values will be considered lowest priority and will only be" + " included if `k` is larger than the number of non-null elements.", + version="0.20.31", + ) + else: + nulls_last = False + + if maintain_order is not None: + issue_deprecation_warning( + "The `maintain_order` parameter for `top_k` is deprecated." + " It will be removed in the next breaking release." + " There will be no guarantees about the order of the output.", + version="0.20.31", + ) + else: + maintain_order = False + + if multithreaded is not None: + issue_deprecation_warning( + "The `multithreaded` parameter for `top_k` is deprecated." + " It will be removed in the next breaking release." + " Polars itself will determine whether to use multithreading or not.", + version="0.20.31", + ) + else: + multithreaded = True + k = parse_as_expression(k) return self._from_pyexpr( self._pyexpr.top_k( @@ -2110,35 +2155,51 @@ def top_k_by( k: int | IntoExprColumn = 5, *, descending: bool | Sequence[bool] = False, - nulls_last: bool = False, - maintain_order: bool = False, - multithreaded: bool = True, + nulls_last: bool | None = None, + maintain_order: bool | None = None, + multithreaded: bool | None = None, ) -> Self: r""" - Return elements corresponding to the `k` largest elements of the `by` column(s). + Return the elements corresponding to the `k` largest elements of the `by` column(s). This has time complexity: - .. math:: O(n + k \log{n} - \frac{k}{2}) + .. math:: O(n + k \log{n}) Parameters ---------- by - Column(s) included in sort order. Accepts expression input. - Strings are parsed as column names. + Column(s) used to determine the largest elements. + Accepts expression input. Strings are parsed as column names. k Number of elements to return. descending - If `True`, consider the k smallest (instead of the k largest). Top-k by - multiple columns can be specified per column by passing a sequence of + Consider the `k` smallest elements of the `by` column(s) (instead of the `k` + largest). 
This can be specified per column by passing a sequence of booleans. + nulls_last Place null values last. + + .. deprecated:: 0.20.31 + This parameter will be removed in the next breaking release. + Null values will be considered lowest priority and will only be + included if `k` is larger than the number of non-null elements. + maintain_order Whether the order should be maintained if elements are equal. + + .. deprecated:: 0.20.31 + This parameter will be removed in the next breaking release. + There will be no guarantees about the order of the output. + multithreaded Sort using multiple threads. + .. deprecated:: 0.20.31 + This parameter will be removed in the next breaking release. + Polars itself will determine whether to use multithreading or not. + See Also -------- top_k @@ -2224,7 +2285,38 @@ def top_k_by( │ Banana ┆ 6 ┆ 1 │ │ Banana ┆ 5 ┆ 2 │ └────────┴─────┴─────┘ - """ + """ # noqa: W505 + if nulls_last is not None: + issue_deprecation_warning( + "The `nulls_last` parameter for `top_k_by` is deprecated." + " It will be removed in the next breaking release." + " Null values will be considered lowest priority and will only be" + " included if `k` is larger than the number of non-null elements.", + version="0.20.31", + ) + else: + nulls_last = False + + if maintain_order is not None: + issue_deprecation_warning( + "The `maintain_order` parameter for `top_k_by` is deprecated." + " It will be removed in the next breaking release." + " There will be no guarantees about the order of the output.", + version="0.20.31", + ) + else: + maintain_order = False + + if multithreaded is not None: + issue_deprecation_warning( + "The `multithreaded` parameter for `top_k_by` is deprecated." + " It will be removed in the next breaking release." 
+ " Polars itself will determine whether to use multithreading or not.", + version="0.20.31", + ) + else: + multithreaded = True + k = parse_as_expression(k) by = parse_as_list_of_expressions(by) if isinstance(descending, bool): @@ -2247,28 +2339,44 @@ def bottom_k( self, k: int | IntoExprColumn = 5, *, - nulls_last: bool = False, - maintain_order: bool = False, - multithreaded: bool = True, + nulls_last: bool | None = None, + maintain_order: bool | None = None, + multithreaded: bool | None = None, ) -> Self: r""" Return the `k` smallest elements. This has time complexity: - .. math:: O(n + k \log{n} - \frac{k}{2}) + .. math:: O(n + k \log{n}) Parameters ---------- k Number of elements to return. + nulls_last Place null values last. + + .. deprecated:: 0.20.31 + This parameter will be removed in the next breaking release. + Null values will be considered lowest priority and will only be + included if `k` is larger than the number of non-null elements. + maintain_order Whether the order should be maintained if elements are equal. + + .. deprecated:: 0.20.31 + This parameter will be removed in the next breaking release. + There will be no guarantees about the order of the output. + multithreaded Sort using multiple threads. + .. deprecated:: 0.20.31 + This parameter will be removed in the next breaking release. + Polars itself will determine whether to use multithreading or not. + See Also -------- top_k @@ -2283,10 +2391,8 @@ def bottom_k( ... } ... ) >>> df.select( - ... [ - ... pl.col("value").top_k().alias("top_k"), - ... pl.col("value").bottom_k().alias("bottom_k"), - ... ] + ... pl.col("value").top_k().alias("top_k"), + ... pl.col("value").bottom_k().alias("bottom_k"), ... ) shape: (5, 2) ┌───────┬──────────┐ @@ -2301,6 +2407,37 @@ def bottom_k( │ 2 ┆ 98 │ └───────┴──────────┘ """ + if nulls_last is not None: + issue_deprecation_warning( + "The `nulls_last` parameter for `bottom_k` is deprecated." + " It will be removed in the next breaking release." 
+ " Null values will be considered lowest priority and will only be" + " included if `k` is larger than the number of non-null elements.", + version="0.20.31", + ) + else: + nulls_last = False + + if maintain_order is not None: + issue_deprecation_warning( + "The `maintain_order` parameter for `bottom_k` is deprecated." + " It will be removed in the next breaking release." + " There will be no guarantees about the order of the output.", + version="0.20.31", + ) + else: + maintain_order = False + + if multithreaded is not None: + issue_deprecation_warning( + "The `multithreaded` parameter for `bottom_k` is deprecated." + " It will be removed in the next breaking release." + " Polars itself will determine whether to use multithreading or not.", + version="0.20.31", + ) + else: + multithreaded = True + k = parse_as_expression(k) return self._from_pyexpr( self._pyexpr.bottom_k( @@ -2317,35 +2454,51 @@ def bottom_k_by( k: int | IntoExprColumn = 5, *, descending: bool | Sequence[bool] = False, - nulls_last: bool = False, - maintain_order: bool = False, - multithreaded: bool = True, + nulls_last: bool | None = None, + maintain_order: bool | None = None, + multithreaded: bool | None = None, ) -> Self: r""" - Return elements corresponding to the `k` smallest elements of `by` column(s). + Return the elements corresponding to the `k` smallest elements of the `by` column(s). This has time complexity: - .. math:: O(n + k \log{n} - \frac{k}{2}) + .. math:: O(n + k \log{n}) Parameters ---------- by - Column(s) included in sort order. + Column(s) used to determine the smallest elements. Accepts expression input. Strings are parsed as column names. k Number of elements to return. descending - If `True`, consider the k largest (instead of the k smallest). Bottom-k by - multiple columns can be specified per column by passing a sequence of + Consider the `k` largest elements of the `by` column(s) (instead of the `k` + smallest). 
This can be specified per column by passing a sequence of booleans. + nulls_last Place null values last. + + .. deprecated:: 0.20.31 + This parameter will be removed in the next breaking release. + Null values will be considered lowest priority and will only be + included if `k` is larger than the number of non-null elements. + maintain_order Whether the order should be maintained if elements are equal. + + .. deprecated:: 0.20.31 + This parameter will be removed in the next breaking release. + There will be no guarantees about the order of the output. + multithreaded Sort using multiple threads. + .. deprecated:: 0.20.31 + This parameter will be removed in the next breaking release. + Polars itself will determine whether to use multithreading or not. + See Also -------- top_k @@ -2431,7 +2584,38 @@ def bottom_k_by( │ Banana ┆ 5 ┆ 2 │ │ Banana ┆ 6 ┆ 1 │ └────────┴─────┴─────┘ - """ + """ # noqa: W505 + if nulls_last is not None: + issue_deprecation_warning( + "The `nulls_last` parameter for `bottom_k_by` is deprecated." + " It will be removed in the next breaking release." + " Null values will be considered lowest priority and will only be" + " included if `k` is larger than the number of non-null elements.", + version="0.20.31", + ) + else: + nulls_last = False + + if maintain_order is not None: + issue_deprecation_warning( + "The `maintain_order` parameter for `bottom_k_by` is deprecated." + " It will be removed in the next breaking release." + " There will be no guarantees about the order of the output.", + version="0.20.31", + ) + else: + maintain_order = False + + if multithreaded is not None: + issue_deprecation_warning( + "The `multithreaded` parameter for `bottom_k_by` is deprecated." + " It will be removed in the next breaking release." 
+ " Polars itself will determine whether to use multithreading or not.", + version="0.20.31", + ) + else: + multithreaded = True + k = parse_as_expression(k) by = parse_as_list_of_expressions(by) if isinstance(descending, bool): diff --git a/py-polars/polars/lazyframe/frame.py b/py-polars/polars/lazyframe/frame.py index 131491dbcc52..87d39c0d8634 100644 --- a/py-polars/polars/lazyframe/frame.py +++ b/py-polars/polars/lazyframe/frame.py @@ -1361,34 +1361,49 @@ def top_k( *, by: IntoExpr | Iterable[IntoExpr], descending: bool | Sequence[bool] = False, - nulls_last: bool = False, - maintain_order: bool = False, - multithreaded: bool = True, + nulls_last: bool | None = None, + maintain_order: bool | None = None, + multithreaded: bool | None = None, ) -> Self: """ - Return the `k` largest elements. - - If `descending=True` the smallest elements will be given. + Return the `k` largest rows. Parameters ---------- k Number of rows to return. by - Column(s) included in sort order. Accepts expression input. - Strings are parsed as column names. + Column(s) used to determine the top rows. + Accepts expression input. Strings are parsed as column names. descending - Return the `k` smallest. Top-k by multiple columns can be specified - per column by passing a sequence of booleans. + Consider the `k` smallest elements of the `by` column(s) (instead of the `k` + largest). This can be specified per column by passing a sequence of + booleans. + nulls_last Place null values last. + + .. deprecated:: 0.20.31 + This parameter will be removed in the next breaking release. + Null values will be considered lowest priority and will only be + included if `k` is larger than the number of non-null elements. + maintain_order Whether the order should be maintained if elements are equal. Note that if `true` streaming is not possible and performance might be worse since this requires a stable search. + + .. deprecated:: 0.20.31 + This parameter will be removed in the next breaking release. 
+ There will be no guarantees about the order of the output. + multithreaded Sort using multiple threads. + .. deprecated:: 0.20.31 + This parameter will be removed in the next breaking release. + Polars itself will determine whether to use multithreading or not. + See Also -------- bottom_k @@ -1432,6 +1447,37 @@ def top_k( │ c ┆ 1 │ └─────┴─────┘ """ + if nulls_last is not None: + issue_deprecation_warning( + "The `nulls_last` parameter for `top_k` is deprecated." + " It will be removed in the next breaking release." + " Null values will be considered lowest priority and will only be" + " included if `k` is larger than the number of non-null elements.", + version="0.20.31", + ) + else: + nulls_last = False + + if maintain_order is not None: + issue_deprecation_warning( + "The `maintain_order` parameter for `top_k` is deprecated." + " It will be removed in the next breaking release." + " There will be no guarantees about the order of the output.", + version="0.20.31", + ) + else: + maintain_order = False + + if multithreaded is not None: + issue_deprecation_warning( + "The `multithreaded` parameter for `top_k` is deprecated." + " It will be removed in the next breaking release." + " Polars itself will determine whether to use multithreading or not.", + version="0.20.31", + ) + else: + multithreaded = True + by = parse_as_list_of_expressions(by) if isinstance(descending, bool): descending = [descending] @@ -1450,34 +1496,49 @@ def bottom_k( *, by: IntoExpr | Iterable[IntoExpr], descending: bool | Sequence[bool] = False, - nulls_last: bool = False, - maintain_order: bool = False, - multithreaded: bool = True, + nulls_last: bool | None = None, + maintain_order: bool | None = None, + multithreaded: bool | None = None, ) -> Self: """ - Return the `k` smallest elements. - - If `descending=True` the largest elements will be given. + Return the `k` smallest rows. Parameters ---------- k Number of rows to return. by - Column(s) included in sort order. 
Accepts expression input. - Strings are parsed as column names. + Column(s) used to determine the bottom rows. + Accepts expression input. Strings are parsed as column names. descending - Return the `k` largest. Bottom-k by multiple columns can be specified - per column by passing a sequence of booleans. + Consider the `k` largest elements of the `by` column(s) (instead of the `k` + smallest). This can be specified per column by passing a sequence of + booleans. + nulls_last Place null values last. + + .. deprecated:: 0.20.31 + This parameter will be removed in the next breaking release. + Null values will be considered lowest priority and will only be + included if `k` is larger than the number of non-null elements. + maintain_order Whether the order should be maintained if elements are equal. Note that if `true` streaming is not possible and performance might be worse since this requires a stable search. + + .. deprecated:: 0.20.31 + This parameter will be removed in the next breaking release. + There will be no guarantees about the order of the output. + multithreaded Sort using multiple threads. + .. deprecated:: 0.20.31 + This parameter will be removed in the next breaking release. + Polars itself will determine whether to use multithreading or not. + See Also -------- top_k @@ -1521,6 +1582,37 @@ def bottom_k( │ b ┆ 2 │ └─────┴─────┘ """ + if nulls_last is not None: + issue_deprecation_warning( + "The `nulls_last` parameter for `bottom_k` is deprecated." + " It will be removed in the next breaking release." + " Null values will be considered lowest priority and will only be" + " included if `k` is larger than the number of non-null elements.", + version="0.20.31", + ) + else: + nulls_last = False + + if maintain_order is not None: + issue_deprecation_warning( + "The `maintain_order` parameter for `bottom_k` is deprecated." + " It will be removed in the next breaking release." 
+ " There will be no guarantees about the order of the output.", + version="0.20.31", + ) + else: + maintain_order = False + + if multithreaded is not None: + issue_deprecation_warning( + "The `multithreaded` parameter for `bottom_k` is deprecated." + " It will be removed in the next breaking release." + " Polars itself will determine whether to use multithreading or not.", + version="0.20.31", + ) + else: + multithreaded = True + by = parse_as_list_of_expressions(by) if isinstance(descending, bool): descending = [descending] diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index 102fbf7f1191..518d0666031e 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -3400,13 +3400,13 @@ def sort( self._s.sort(descending, nulls_last, multithreaded) ) - def top_k(self, k: int | IntoExprColumn = 5) -> Series: + def top_k(self, k: int = 5) -> Series: r""" Return the `k` largest elements. This has time complexity: - .. math:: O(n + k \log{n} - \frac{k}{2}) + .. math:: O(n + k \log{n}) Parameters ---------- @@ -3430,13 +3430,13 @@ def top_k(self, k: int | IntoExprColumn = 5) -> Series: ] """ - def bottom_k(self, k: int | IntoExprColumn = 5) -> Series: + def bottom_k(self, k: int = 5) -> Series: r""" Return the `k` smallest elements. This has time complexity: - .. math:: O(n + k \log{n} - \frac{k}{2}) + .. 
math:: O(n + k \log{n}) Parameters ---------- diff --git a/py-polars/tests/unit/operations/test_sort.py b/py-polars/tests/unit/operations/test_sort.py index c6ccdf0a70ed..68dc0b7fc7f1 100644 --- a/py-polars/tests/unit/operations/test_sort.py +++ b/py-polars/tests/unit/operations/test_sort.py @@ -6,7 +6,6 @@ import pytest import polars as pl -from polars.exceptions import ComputeError from polars.testing import assert_frame_equal, assert_series_equal @@ -276,11 +275,6 @@ def test_sorted_flag() -> None: assert q.collect()["timestamp"].flags["SORTED_ASC"] - # top-k/bottom-k - df = pl.DataFrame({"foo": [56, 2, 3]}) - assert df.top_k(2, by="foo")["foo"].flags["SORTED_DESC"] - assert df.bottom_k(2, by="foo")["foo"].flags["SORTED_ASC"] - # ensure we don't panic for these types # struct pl.Series([{"a": 1}]).set_sorted(descending=True) @@ -321,317 +315,6 @@ def test_arg_sort_rank_nans() -> None: ).to_dict(as_series=False) == {"rank": [1.0, 2.0], "arg_sort": [0, 1]} -def test_top_k() -> None: - # expression - s = pl.Series("a", [3, 8, 1, 5, 2]) - - assert_series_equal(s.top_k(3), pl.Series("a", [8, 5, 3])) - assert_series_equal(s.bottom_k(4), pl.Series("a", [1, 2, 3, 5])) - - assert_series_equal(s.top_k(pl.Series([3])), pl.Series("a", [8, 5, 3])) - assert_series_equal(s.bottom_k(pl.Series([4])), pl.Series("a", [1, 2, 3, 5])) - - # 5886 - df = pl.DataFrame( - { - "test": [2, 4, 1, 3], - "val": [2, 4, 9, 3], - "bool_val": [False, True, True, False], - "str_value": ["d", "b", "a", "c"], - } - ) - assert_frame_equal( - df.select(pl.col("test").top_k(10)), - pl.DataFrame({"test": [4, 3, 2, 1]}), - ) - - assert_frame_equal( - df.select( - top_k=pl.col("test").top_k(pl.col("val").min()), - bottom_k=pl.col("test").bottom_k(pl.col("val").min()), - ), - pl.DataFrame({"top_k": [4, 3], "bottom_k": [1, 2]}), - ) - - assert_frame_equal( - df.select( - pl.col("bool_val").top_k(2).alias("top_k"), - pl.col("bool_val").bottom_k(2).alias("bottom_k"), - ), - pl.DataFrame({"top_k": [True, 
True], "bottom_k": [False, False]}), - ) - - assert_frame_equal( - df.select( - pl.col("str_value").top_k(2).alias("top_k"), - pl.col("str_value").bottom_k(2).alias("bottom_k"), - ), - pl.DataFrame({"top_k": ["d", "c"], "bottom_k": ["a", "b"]}), - ) - - with pytest.raises(ComputeError, match="`k` must be set for `top_k`"): - df.select( - pl.col("bool_val").top_k(pl.lit(None)), - ) - - with pytest.raises(ComputeError, match="`k` must be a single value for `top_k`."): - df.select(pl.col("test").top_k(pl.lit(pl.Series("s", [1, 2])))) - - # dataframe - df = pl.DataFrame( - { - "a": [1, 2, 3, 4, 2, 2], - "b": [3, 2, 1, 4, 3, 2], - } - ) - - assert_frame_equal( - df.top_k(3, by=["a", "b"]), - pl.DataFrame({"a": [4, 3, 2], "b": [4, 1, 3]}), - ) - - assert_frame_equal( - df.top_k(3, by=["a", "b"], descending=True), - pl.DataFrame({"a": [1, 2, 2], "b": [3, 2, 2]}), - ) - assert_frame_equal( - df.bottom_k(4, by=["a", "b"], descending=True), - pl.DataFrame({"a": [4, 3, 2, 2], "b": [4, 1, 3, 2]}), - ) - - df2 = pl.DataFrame( - { - "a": [1, 2, 3, 4, 5, 6], - "b": [12, 11, 10, 9, 8, 7], - "c": ["Apple", "Orange", "Apple", "Apple", "Banana", "Banana"], - } - ) - - assert_frame_equal( - df2.select( - pl.col("a", "b").top_k_by("a", 2).name.suffix("_top_by_a"), - pl.col("a", "b").top_k_by("b", 2).name.suffix("_top_by_b"), - ), - pl.DataFrame( - { - "a_top_by_a": [6, 5], - "b_top_by_a": [7, 8], - "a_top_by_b": [1, 2], - "b_top_by_b": [12, 11], - } - ), - ) - - assert_frame_equal( - df2.select( - pl.col("a", "b").top_k_by("a", 2, descending=True).name.suffix("_top_by_a"), - pl.col("a", "b").top_k_by("b", 2, descending=True).name.suffix("_top_by_b"), - ), - pl.DataFrame( - { - "a_top_by_a": [1, 2], - "b_top_by_a": [12, 11], - "a_top_by_b": [6, 5], - "b_top_by_b": [7, 8], - } - ), - ) - - assert_frame_equal( - df2.select( - pl.col("a", "b").bottom_k_by("a", 2).name.suffix("_bottom_by_a"), - pl.col("a", "b").bottom_k_by("b", 2).name.suffix("_bottom_by_b"), - ), - pl.DataFrame( - { - 
"a_bottom_by_a": [1, 2], - "b_bottom_by_a": [12, 11], - "a_bottom_by_b": [6, 5], - "b_bottom_by_b": [7, 8], - } - ), - ) - - assert_frame_equal( - df2.select( - pl.col("a", "b") - .bottom_k_by("a", 2, descending=True) - .name.suffix("_bottom_by_a"), - pl.col("a", "b") - .bottom_k_by("b", 2, descending=True) - .name.suffix("_bottom_by_b"), - ), - pl.DataFrame( - { - "a_bottom_by_a": [6, 5], - "b_bottom_by_a": [7, 8], - "a_bottom_by_b": [1, 2], - "b_bottom_by_b": [12, 11], - } - ), - ) - - assert_frame_equal( - df2.group_by("c", maintain_order=True) - .agg(pl.all().top_k_by("a", 2)) - .explode(pl.all().exclude("c")), - pl.DataFrame( - { - "c": ["Apple", "Apple", "Orange", "Banana", "Banana"], - "a": [4, 3, 2, 6, 5], - "b": [9, 10, 11, 7, 8], - } - ), - ) - - assert_frame_equal( - df2.group_by("c", maintain_order=True) - .agg(pl.all().bottom_k_by("a", 2)) - .explode(pl.all().exclude("c")), - pl.DataFrame( - { - "c": ["Apple", "Apple", "Orange", "Banana", "Banana"], - "a": [1, 3, 2, 5, 6], - "b": [12, 10, 11, 8, 7], - } - ), - ) - - assert_frame_equal( - df2.select( - pl.col("a", "b", "c").top_k_by(["c", "a"], 2).name.suffix("_top_by_ca"), - pl.col("a", "b", "c").top_k_by(["c", "b"], 2).name.suffix("_top_by_cb"), - ), - pl.DataFrame( - { - "a_top_by_ca": [2, 6], - "b_top_by_ca": [11, 7], - "c_top_by_ca": ["Orange", "Banana"], - "a_top_by_cb": [2, 5], - "b_top_by_cb": [11, 8], - "c_top_by_cb": ["Orange", "Banana"], - } - ), - ) - - assert_frame_equal( - df2.select( - pl.col("a", "b", "c") - .bottom_k_by(["c", "a"], 2) - .name.suffix("_bottom_by_ca"), - pl.col("a", "b", "c") - .bottom_k_by(["c", "b"], 2) - .name.suffix("_bottom_by_cb"), - ), - pl.DataFrame( - { - "a_bottom_by_ca": [1, 3], - "b_bottom_by_ca": [12, 10], - "c_bottom_by_ca": ["Apple", "Apple"], - "a_bottom_by_cb": [4, 3], - "b_bottom_by_cb": [9, 10], - "c_bottom_by_cb": ["Apple", "Apple"], - } - ), - ) - - assert_frame_equal( - df2.select( - pl.col("a", "b", "c") - .top_k_by(["c", "a"], 2, descending=[True, 
False]) - .name.suffix("_top_by_ca"), - pl.col("a", "b", "c") - .top_k_by(["c", "b"], 2, descending=[True, False]) - .name.suffix("_top_by_cb"), - ), - pl.DataFrame( - { - "a_top_by_ca": [4, 3], - "b_top_by_ca": [9, 10], - "c_top_by_ca": ["Apple", "Apple"], - "a_top_by_cb": [1, 3], - "b_top_by_cb": [12, 10], - "c_top_by_cb": ["Apple", "Apple"], - } - ), - ) - - assert_frame_equal( - df2.select( - pl.col("a", "b", "c") - .bottom_k_by(["c", "a"], 2, descending=[True, False]) - .name.suffix("_bottom_by_ca"), - pl.col("a", "b", "c") - .bottom_k_by(["c", "b"], 2, descending=[True, False]) - .name.suffix("_bottom_by_cb"), - ), - pl.DataFrame( - { - "a_bottom_by_ca": [2, 5], - "b_bottom_by_ca": [11, 8], - "c_bottom_by_ca": ["Orange", "Banana"], - "a_bottom_by_cb": [2, 6], - "b_bottom_by_cb": [11, 7], - "c_bottom_by_cb": ["Orange", "Banana"], - } - ), - ) - - assert_frame_equal( - df2.select( - pl.col("a", "b", "c") - .top_k_by(["c", "a"], 2, descending=[False, True]) - .name.suffix("_top_by_ca"), - pl.col("a", "b", "c") - .top_k_by(["c", "b"], 2, descending=[False, True]) - .name.suffix("_top_by_cb"), - ), - pl.DataFrame( - { - "a_top_by_ca": [2, 5], - "b_top_by_ca": [11, 8], - "c_top_by_ca": ["Orange", "Banana"], - "a_top_by_cb": [2, 6], - "b_top_by_cb": [11, 7], - "c_top_by_cb": ["Orange", "Banana"], - } - ), - ) - - assert_frame_equal( - df2.select( - pl.col("a", "b", "c") - .top_k_by(["c", "a"], 2, descending=[False, True]) - .name.suffix("_bottom_by_ca"), - pl.col("a", "b", "c") - .top_k_by(["c", "b"], 2, descending=[False, True]) - .name.suffix("_bottom_by_cb"), - ), - pl.DataFrame( - { - "a_bottom_by_ca": [2, 5], - "b_bottom_by_ca": [11, 8], - "c_bottom_by_ca": ["Orange", "Banana"], - "a_bottom_by_cb": [2, 6], - "b_bottom_by_cb": [11, 7], - "c_bottom_by_cb": ["Orange", "Banana"], - } - ), - ) - - with pytest.raises( - ValueError, - match=r"the length of `descending` \(2\) does not match the length of `by` \(1\)", - ): - df2.select(pl.all().top_k_by("a", 2, 
descending=[True, False])) - - with pytest.raises( - ValueError, - match=r"the length of `descending` \(2\) does not match the length of `by` \(1\)", - ): - df2.select(pl.all().bottom_k_by("a", 2, descending=[True, False])) - - def test_sorted_flag_unset_by_arithmetic_4937() -> None: df = pl.DataFrame( { @@ -908,20 +591,6 @@ def test_sort_descending() -> None: df.sort(["a", "b"], descending=[True]) -def test_top_k_descending() -> None: - df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) - result = df.top_k(1, by=["a", "b"], descending=True) - expected = pl.DataFrame({"a": [1], "b": [4]}) - assert_frame_equal(result, expected) - result = df.top_k(1, by=["a", "b"], descending=[True, True]) - assert_frame_equal(result, expected) - with pytest.raises( - ValueError, - match=r"the length of `descending` \(1\) does not match the length of `by` \(2\)", - ): - df.top_k(1, by=["a", "b"], descending=[True]) - - def test_sort_by_descending() -> None: df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) result = df.select(pl.col("a").sort_by(["a", "b"], descending=True)) @@ -987,12 +656,6 @@ def test_sort_top_k_fast_path() -> None: } -def test_top_k_9385() -> None: - assert pl.LazyFrame({"b": [True, False]}).sort(["b"]).slice(0, 1).collect()[ - "b" - ].to_list() == [False] - - def test_sorted_flag_partition_by() -> None: assert ( pl.DataFrame({"one": [1, 2, 3], "two": ["a", "a", "b"]}) diff --git a/py-polars/tests/unit/operations/test_top_k.py b/py-polars/tests/unit/operations/test_top_k.py new file mode 100644 index 000000000000..5c6416cd2b0e --- /dev/null +++ b/py-polars/tests/unit/operations/test_top_k.py @@ -0,0 +1,362 @@ +import pytest + +import polars as pl +from polars.testing import assert_frame_equal, assert_series_equal + + +def test_top_k() -> None: + # expression + s = pl.Series("a", [3, 8, 1, 5, 2]) + + assert_series_equal(s.top_k(3), pl.Series("a", [8, 5, 3])) + assert_series_equal(s.bottom_k(4), pl.Series("a", [1, 2, 3, 5])) + + # 5886 + df = pl.DataFrame( + { + 
"test": [2, 4, 1, 3], + "val": [2, 4, 9, 3], + "bool_val": [False, True, True, False], + "str_value": ["d", "b", "a", "c"], + } + ) + assert_frame_equal( + df.select(pl.col("test").top_k(10)), + pl.DataFrame({"test": [4, 3, 2, 1]}), + ) + + assert_frame_equal( + df.select( + top_k=pl.col("test").top_k(pl.col("val").min()), + bottom_k=pl.col("test").bottom_k(pl.col("val").min()), + ), + pl.DataFrame({"top_k": [4, 3], "bottom_k": [1, 2]}), + ) + + assert_frame_equal( + df.select( + pl.col("bool_val").top_k(2).alias("top_k"), + pl.col("bool_val").bottom_k(2).alias("bottom_k"), + ), + pl.DataFrame({"top_k": [True, True], "bottom_k": [False, False]}), + ) + + assert_frame_equal( + df.select( + pl.col("str_value").top_k(2).alias("top_k"), + pl.col("str_value").bottom_k(2).alias("bottom_k"), + ), + pl.DataFrame({"top_k": ["d", "c"], "bottom_k": ["a", "b"]}), + ) + + with pytest.raises(pl.ComputeError, match="`k` must be set for `top_k`"): + df.select( + pl.col("bool_val").top_k(pl.lit(None)), + ) + + with pytest.raises( + pl.ComputeError, match="`k` must be a single value for `top_k`." 
+ ): + df.select(pl.col("test").top_k(pl.lit(pl.Series("s", [1, 2])))) + + # dataframe + df = pl.DataFrame( + { + "a": [1, 2, 3, 4, 2, 2], + "b": [3, 2, 1, 4, 3, 2], + } + ) + + assert_frame_equal( + df.top_k(3, by=["a", "b"]), + pl.DataFrame({"a": [4, 3, 2], "b": [4, 1, 3]}), + ) + + assert_frame_equal( + df.top_k(3, by=["a", "b"], descending=True), + pl.DataFrame({"a": [1, 2, 2], "b": [3, 2, 2]}), + ) + assert_frame_equal( + df.bottom_k(4, by=["a", "b"], descending=True), + pl.DataFrame({"a": [4, 3, 2, 2], "b": [4, 1, 3, 2]}), + ) + + df2 = pl.DataFrame( + { + "a": [1, 2, 3, 4, 5, 6], + "b": [12, 11, 10, 9, 8, 7], + "c": ["Apple", "Orange", "Apple", "Apple", "Banana", "Banana"], + } + ) + + assert_frame_equal( + df2.select( + pl.col("a", "b").top_k_by("a", 2).name.suffix("_top_by_a"), + pl.col("a", "b").top_k_by("b", 2).name.suffix("_top_by_b"), + ), + pl.DataFrame( + { + "a_top_by_a": [6, 5], + "b_top_by_a": [7, 8], + "a_top_by_b": [1, 2], + "b_top_by_b": [12, 11], + } + ), + ) + + assert_frame_equal( + df2.select( + pl.col("a", "b").top_k_by("a", 2, descending=True).name.suffix("_top_by_a"), + pl.col("a", "b").top_k_by("b", 2, descending=True).name.suffix("_top_by_b"), + ), + pl.DataFrame( + { + "a_top_by_a": [1, 2], + "b_top_by_a": [12, 11], + "a_top_by_b": [6, 5], + "b_top_by_b": [7, 8], + } + ), + ) + + assert_frame_equal( + df2.select( + pl.col("a", "b").bottom_k_by("a", 2).name.suffix("_bottom_by_a"), + pl.col("a", "b").bottom_k_by("b", 2).name.suffix("_bottom_by_b"), + ), + pl.DataFrame( + { + "a_bottom_by_a": [1, 2], + "b_bottom_by_a": [12, 11], + "a_bottom_by_b": [6, 5], + "b_bottom_by_b": [7, 8], + } + ), + ) + + assert_frame_equal( + df2.select( + pl.col("a", "b") + .bottom_k_by("a", 2, descending=True) + .name.suffix("_bottom_by_a"), + pl.col("a", "b") + .bottom_k_by("b", 2, descending=True) + .name.suffix("_bottom_by_b"), + ), + pl.DataFrame( + { + "a_bottom_by_a": [6, 5], + "b_bottom_by_a": [7, 8], + "a_bottom_by_b": [1, 2], + "b_bottom_by_b": [12, 
11], + } + ), + ) + + assert_frame_equal( + df2.group_by("c", maintain_order=True) + .agg(pl.all().top_k_by("a", 2)) + .explode(pl.all().exclude("c")), + pl.DataFrame( + { + "c": ["Apple", "Apple", "Orange", "Banana", "Banana"], + "a": [4, 3, 2, 6, 5], + "b": [9, 10, 11, 7, 8], + } + ), + ) + + assert_frame_equal( + df2.group_by("c", maintain_order=True) + .agg(pl.all().bottom_k_by("a", 2)) + .explode(pl.all().exclude("c")), + pl.DataFrame( + { + "c": ["Apple", "Apple", "Orange", "Banana", "Banana"], + "a": [1, 3, 2, 5, 6], + "b": [12, 10, 11, 8, 7], + } + ), + ) + + assert_frame_equal( + df2.select( + pl.col("a", "b", "c").top_k_by(["c", "a"], 2).name.suffix("_top_by_ca"), + pl.col("a", "b", "c").top_k_by(["c", "b"], 2).name.suffix("_top_by_cb"), + ), + pl.DataFrame( + { + "a_top_by_ca": [2, 6], + "b_top_by_ca": [11, 7], + "c_top_by_ca": ["Orange", "Banana"], + "a_top_by_cb": [2, 5], + "b_top_by_cb": [11, 8], + "c_top_by_cb": ["Orange", "Banana"], + } + ), + ) + + assert_frame_equal( + df2.select( + pl.col("a", "b", "c") + .bottom_k_by(["c", "a"], 2) + .name.suffix("_bottom_by_ca"), + pl.col("a", "b", "c") + .bottom_k_by(["c", "b"], 2) + .name.suffix("_bottom_by_cb"), + ), + pl.DataFrame( + { + "a_bottom_by_ca": [1, 3], + "b_bottom_by_ca": [12, 10], + "c_bottom_by_ca": ["Apple", "Apple"], + "a_bottom_by_cb": [4, 3], + "b_bottom_by_cb": [9, 10], + "c_bottom_by_cb": ["Apple", "Apple"], + } + ), + ) + + assert_frame_equal( + df2.select( + pl.col("a", "b", "c") + .top_k_by(["c", "a"], 2, descending=[True, False]) + .name.suffix("_top_by_ca"), + pl.col("a", "b", "c") + .top_k_by(["c", "b"], 2, descending=[True, False]) + .name.suffix("_top_by_cb"), + ), + pl.DataFrame( + { + "a_top_by_ca": [4, 3], + "b_top_by_ca": [9, 10], + "c_top_by_ca": ["Apple", "Apple"], + "a_top_by_cb": [1, 3], + "b_top_by_cb": [12, 10], + "c_top_by_cb": ["Apple", "Apple"], + } + ), + ) + + assert_frame_equal( + df2.select( + pl.col("a", "b", "c") + .bottom_k_by(["c", "a"], 2, descending=[True, 
False]) + .name.suffix("_bottom_by_ca"), + pl.col("a", "b", "c") + .bottom_k_by(["c", "b"], 2, descending=[True, False]) + .name.suffix("_bottom_by_cb"), + ), + pl.DataFrame( + { + "a_bottom_by_ca": [2, 5], + "b_bottom_by_ca": [11, 8], + "c_bottom_by_ca": ["Orange", "Banana"], + "a_bottom_by_cb": [2, 6], + "b_bottom_by_cb": [11, 7], + "c_bottom_by_cb": ["Orange", "Banana"], + } + ), + ) + + assert_frame_equal( + df2.select( + pl.col("a", "b", "c") + .top_k_by(["c", "a"], 2, descending=[False, True]) + .name.suffix("_top_by_ca"), + pl.col("a", "b", "c") + .top_k_by(["c", "b"], 2, descending=[False, True]) + .name.suffix("_top_by_cb"), + ), + pl.DataFrame( + { + "a_top_by_ca": [2, 5], + "b_top_by_ca": [11, 8], + "c_top_by_ca": ["Orange", "Banana"], + "a_top_by_cb": [2, 6], + "b_top_by_cb": [11, 7], + "c_top_by_cb": ["Orange", "Banana"], + } + ), + ) + + assert_frame_equal( + df2.select( + pl.col("a", "b", "c") + .top_k_by(["c", "a"], 2, descending=[False, True]) + .name.suffix("_bottom_by_ca"), + pl.col("a", "b", "c") + .top_k_by(["c", "b"], 2, descending=[False, True]) + .name.suffix("_bottom_by_cb"), + ), + pl.DataFrame( + { + "a_bottom_by_ca": [2, 5], + "b_bottom_by_ca": [11, 8], + "c_bottom_by_ca": ["Orange", "Banana"], + "a_bottom_by_cb": [2, 6], + "b_bottom_by_cb": [11, 7], + "c_bottom_by_cb": ["Orange", "Banana"], + } + ), + ) + + with pytest.raises( + ValueError, + match=r"the length of `descending` \(2\) does not match the length of `by` \(1\)", + ): + df2.select(pl.all().top_k_by("a", 2, descending=[True, False])) + + with pytest.raises( + ValueError, + match=r"the length of `descending` \(2\) does not match the length of `by` \(1\)", + ): + df2.select(pl.all().bottom_k_by("a", 2, descending=[True, False])) + + +def test_top_k_descending() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + result = df.top_k(1, by=["a", "b"], descending=True) + expected = pl.DataFrame({"a": [1], "b": [4]}) + assert_frame_equal(result, expected) + result = 
df.top_k(1, by=["a", "b"], descending=[True, True]) + assert_frame_equal(result, expected) + with pytest.raises( + ValueError, + match=r"the length of `descending` \(1\) does not match the length of `by` \(2\)", + ): + df.top_k(1, by=["a", "b"], descending=[True]) + + +def test_top_k_9385() -> None: + lf = pl.LazyFrame({"b": [True, False]}) + result = lf.sort(["b"]).slice(0, 1) + assert result.collect()["b"].to_list() == [False] + + +def test_top_k_sorted_flag() -> None: + # top-k/bottom-k + df = pl.DataFrame({"foo": [56, 2, 3]}) + assert df.top_k(2, by="foo")["foo"].flags["SORTED_DESC"] + assert df.bottom_k(2, by="foo")["foo"].flags["SORTED_ASC"] + + +def test_top_k_empty() -> None: + df = pl.DataFrame({"test": []}) + + assert_frame_equal(df.select([pl.col("test").top_k(2)]), df) + + +def test_top_k_nulls_last_deprecated() -> None: + with pytest.deprecated_call(): + pl.col("a").top_k(5, nulls_last=True) + + +def test_top_k_maintain_order_deprecated() -> None: + with pytest.deprecated_call(): + pl.col("a").top_k(5, maintain_order=True) + + +def test_top_k_multithreaded_deprecated() -> None: + with pytest.deprecated_call(): + pl.col("a").top_k(5, multithreaded=True) diff --git a/py-polars/tests/unit/test_empty.py b/py-polars/tests/unit/test_empty.py index f5895a8a7c55..acf650185c64 100644 --- a/py-polars/tests/unit/test_empty.py +++ b/py-polars/tests/unit/test_empty.py @@ -13,12 +13,6 @@ def test_empty_str_concat_lit() -> None: } -def test_top_k_empty() -> None: - df = pl.DataFrame({"test": []}) - - assert_frame_equal(df.select([pl.col("test").top_k(2)]), df) - - def test_empty_cross_join() -> None: a = pl.LazyFrame(schema={"a": pl.Int32}) b = pl.LazyFrame(schema={"b": pl.Int32})