diff --git a/py-polars/docs/source/reference/expressions/modify_select.rst b/py-polars/docs/source/reference/expressions/modify_select.rst index 86649d5695a7..a713f005d9bd 100644 --- a/py-polars/docs/source/reference/expressions/modify_select.rst +++ b/py-polars/docs/source/reference/expressions/modify_select.rst @@ -9,6 +9,7 @@ Manipulation/selection Expr.append Expr.arg_sort Expr.argsort + Expr.arg_true Expr.backward_fill Expr.cast Expr.ceil diff --git a/py-polars/polars/internals/expr/expr.py b/py-polars/polars/internals/expr/expr.py index b6f5a50df405..37adb43ca10c 100644 --- a/py-polars/polars/internals/expr/expr.py +++ b/py-polars/polars/internals/expr/expr.py @@ -1,5 +1,6 @@ from __future__ import annotations +import contextlib import math import os import random @@ -39,6 +40,9 @@ sphinx_accessor, ) +with contextlib.suppress(ImportError): # Module not available when building docs + from polars.polars import arg_where as py_arg_where + if TYPE_CHECKING: import sys @@ -418,6 +422,37 @@ def all(self) -> Self: """ return self._from_pyexpr(self._pyexpr.all()) + def arg_true(self) -> Self: + """ + Return indices where expression evaluates `True`. + + .. warning:: + Modifies number of rows returned, so will fail in combination with other + expressions. Use as only expression in `select` / `with_columns`. + + Examples + -------- + >>> df = pl.DataFrame({"a": [1, 1, 2, 1]}) + >>> df.select((pl.col("a") == 1).arg_true()) + shape: (3, 1) + ┌─────┐ + │ a │ + │ --- │ + │ u32 │ + ╞═════╡ + │ 0 │ + │ 1 │ + │ 3 │ + └─────┘ + + See Also + -------- + Series.arg_true : Return indices where Series is True + pl.arg_where + + """ + return self._from_pyexpr(py_arg_where(self._pyexpr)) + def sqrt(self) -> Self: """ Compute the square root of the elements. diff --git a/py-polars/polars/internals/lazy_functions.py b/py-polars/polars/internals/lazy_functions.py index e76cbb9f28d1..27eda85f65b6 100644 --- a/py-polars/polars/internals/lazy_functions.py +++ b/py-polars/polars/internals/lazy_functions.py @@ -2911,6 +2911,10 @@ def arg_where( 3 ] + See Also + -------- + Series.arg_true : Return indices where Series is True + """ if eager: if not isinstance(condition, pli.Series): diff --git a/py-polars/tests/unit/test_exprs.py b/py-polars/tests/unit/test_exprs.py index 13d18b2c1770..60c42b58465c 100644 --- a/py-polars/tests/unit/test_exprs.py +++ b/py-polars/tests/unit/test_exprs.py @@ -700,3 +700,10 @@ def test_exclude_invalid_input(input: tuple[Any, ...]) -> None: df = pl.DataFrame(schema=["a", "b", "c"]) with pytest.raises(TypeError): df.select(pl.all().exclude(*input)) + + +def test_arg_true() -> None: + df = pl.DataFrame({"a": [1, 1, 2, 1]}) + res = df.select((pl.col("a") == 1).arg_true()) + expected = pl.DataFrame([pl.Series("a", [0, 1, 3], dtype=pl.UInt32)]) + assert_frame_equal(res, expected)