From fd45f17976969670e0712a100f803d76054c5e50 Mon Sep 17 00:00:00 2001 From: Jeroen van Zundert Date: Mon, 20 Feb 2023 20:19:39 +0000 Subject: [PATCH] feat(python): Add Expr.arg_true We have `pl.arg_where` and `Series.arg_true`; this PR adds `Expr.arg_true`. I think `arg_true` makes more sense as a name than `arg_where` when operating on the expression. --- .../reference/expressions/modify_select.rst | 1 + py-polars/polars/internals/expr/expr.py | 35 +++++++++++++++++++ py-polars/polars/internals/lazy_functions.py | 4 +++ py-polars/tests/unit/test_exprs.py | 7 ++++ 4 files changed, 47 insertions(+) diff --git a/py-polars/docs/source/reference/expressions/modify_select.rst b/py-polars/docs/source/reference/expressions/modify_select.rst index 86649d5695a7..a713f005d9bd 100644 --- a/py-polars/docs/source/reference/expressions/modify_select.rst +++ b/py-polars/docs/source/reference/expressions/modify_select.rst @@ -9,6 +9,7 @@ Manipulation/selection Expr.append Expr.arg_sort Expr.argsort + Expr.arg_true Expr.backward_fill Expr.cast Expr.ceil diff --git a/py-polars/polars/internals/expr/expr.py b/py-polars/polars/internals/expr/expr.py index 7851f1356ac6..2f3742a9f503 100644 --- a/py-polars/polars/internals/expr/expr.py +++ b/py-polars/polars/internals/expr/expr.py @@ -1,5 +1,6 @@ from __future__ import annotations +import contextlib import math import os import random @@ -40,6 +41,9 @@ sphinx_accessor, ) +with contextlib.suppress(ImportError): # Module not available when building docs + from polars.polars import arg_where as py_arg_where + if TYPE_CHECKING: import sys @@ -418,6 +422,37 @@ def all(self) -> Self: """ return self._from_pyexpr(self._pyexpr.all()) + def arg_true(self) -> Self: + """ + Return indices where expression evaluates `True`. + + .. warning:: + Modifies number of rows returned, so will fail in combination with other + expressions. Use as only expression in `select` / `with_columns`. 
+ + Examples + -------- + >>> df = pl.DataFrame({"a": [1, 1, 2, 1]}) + >>> df.select((pl.col("a") == 1).arg_true()) + shape: (3, 1) + ┌─────┐ + │ a │ + │ --- │ + │ u32 │ + ╞═════╡ + │ 0 │ + │ 1 │ + │ 3 │ + └─────┘ + + See Also + -------- + Series.arg_true : Return indices where Series is True + pl.arg_where + + """ + return self._from_pyexpr(py_arg_where(self._pyexpr)) + def sqrt(self) -> Self: """ Compute the square root of the elements. diff --git a/py-polars/polars/internals/lazy_functions.py b/py-polars/polars/internals/lazy_functions.py index c1a225504be1..a694486936cf 100644 --- a/py-polars/polars/internals/lazy_functions.py +++ b/py-polars/polars/internals/lazy_functions.py @@ -2912,6 +2912,10 @@ def arg_where( 3 ] + See Also + -------- + Series.arg_true : Return indices where Series is True + """ if eager: if not isinstance(condition, pli.Series): diff --git a/py-polars/tests/unit/test_exprs.py b/py-polars/tests/unit/test_exprs.py index 46efc87fa9e9..749d55186e40 100644 --- a/py-polars/tests/unit/test_exprs.py +++ b/py-polars/tests/unit/test_exprs.py @@ -701,3 +701,10 @@ def test_exclude_invalid_input(input: tuple[Any, ...]) -> None: df = pl.DataFrame(schema=["a", "b", "c"]) with pytest.raises(TypeError): df.select(pl.all().exclude(*input)) + + +def test_arg_true() -> None: + df = pl.DataFrame({"a": [1, 1, 2, 1]}) + res = df.select((pl.col("a") == 1).arg_true()) + expected = pl.DataFrame([pl.Series("a", [0, 1, 3], dtype=pl.UInt32)]) + assert_frame_equal(res, expected)