Skip to content

Commit

Permalink
feat(python): Add Expr.arg_true (#7056)
Browse files Browse the repository at this point in the history
  • Loading branch information
zundertj authored Mar 1, 2023
1 parent 59536fc commit 00cd55b
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Manipulation/selection
Expr.append
Expr.arg_sort
Expr.argsort
Expr.arg_true
Expr.backward_fill
Expr.cast
Expr.ceil
Expand Down
35 changes: 35 additions & 0 deletions py-polars/polars/internals/expr/expr.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import contextlib
import math
import os
import random
Expand Down Expand Up @@ -39,6 +40,9 @@
sphinx_accessor,
)

with contextlib.suppress(ImportError): # Module not available when building docs
from polars.polars import arg_where as py_arg_where

if TYPE_CHECKING:
import sys

Expand Down Expand Up @@ -418,6 +422,37 @@ def all(self) -> Self:
"""
return self._from_pyexpr(self._pyexpr.all())

def arg_true(self) -> Self:
"""
Return indices where expression evaluates `True`.
.. warning::
Modifies number of rows returned, so will fail in combination with other
expressions. Use as only expression in `select` / `with_columns`.
Examples
--------
>>> df = pl.DataFrame({"a": [1, 1, 2, 1]})
>>> df.select((pl.col("a") == 1).arg_true())
shape: (3, 1)
┌─────┐
│ a │
│ --- │
│ u32 │
╞═════╡
│ 0 │
│ 1 │
│ 3 │
└─────┘
See Also
--------
Series.arg_true : Return indices where Series is True
pl.arg_where
"""
return self._from_pyexpr(py_arg_where(self._pyexpr))

def sqrt(self) -> Self:
"""
Compute the square root of the elements.
Expand Down
4 changes: 4 additions & 0 deletions py-polars/polars/internals/lazy_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2911,6 +2911,10 @@ def arg_where(
3
]
See Also
--------
Series.arg_true : Return indices where Series is True
"""
if eager:
if not isinstance(condition, pli.Series):
Expand Down
7 changes: 7 additions & 0 deletions py-polars/tests/unit/test_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,3 +700,10 @@ def test_exclude_invalid_input(input: tuple[Any, ...]) -> None:
df = pl.DataFrame(schema=["a", "b", "c"])
with pytest.raises(TypeError):
df.select(pl.all().exclude(*input))


def test_arg_true() -> None:
df = pl.DataFrame({"a": [1, 1, 2, 1]})
res = df.select((pl.col("a") == 1).arg_true())
expected = pl.DataFrame([pl.Series("a", [0, 1, 3], dtype=pl.UInt32)])
assert_frame_equal(res, expected)

0 comments on commit 00cd55b

Please sign in to comment.