Skip to content

Commit

Permalink
fix(python): Raise ValueError on passing multiple expressions Numpy u…
Browse files Browse the repository at this point in the history
…func (#6821)
  • Loading branch information
zundertj authored Feb 12, 2023
1 parent 5436991 commit 23c8c9d
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 0 deletions.
6 changes: 6 additions & 0 deletions py-polars/polars/internals/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,12 @@ def __array_ufunc__(
) -> Expr:
"""Numpy universal functions."""

num_expr = sum(isinstance(inp, Expr) for inp in inputs)
if num_expr > 1:
raise ValueError(
f"Numpy ufunc can only be used with one expression, {num_expr} given. Use `pl.reduce` to call numpy functions over multiple expressions."
)

def function(s: pli.Series) -> pli.Series: # pragma: no cover
args = [inp if not isinstance(inp, Expr) else s for inp in inputs]
return ufunc(*args, **kwargs)
Expand Down
14 changes: 14 additions & 0 deletions py-polars/tests/unit/test_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -973,6 +973,20 @@ def test_ufunc_expr_not_first() -> None:
assert_frame_equal(out, expected)


def test_ufunc_multiple_expr() -> None:
df = pl.DataFrame(
[
pl.Series("a", [1, 2, 3], dtype=pl.Float64),
pl.Series("b", [4, 5, 6], dtype=pl.Float64),
]
)

with pytest.raises(
ValueError, match="Numpy ufunc can only be used with one expression, 2 given"
):
df.select(np.arctan2(pl.col("a"), pl.col("b"))) # type: ignore[call-overload]


def test_clip() -> None:
df = pl.DataFrame({"a": [1, 2, 3, 4, 5]})
assert df.select(pl.col("a").clip(2, 4))["a"].to_list() == [2, 2, 3, 4, 4]
Expand Down

0 comments on commit 23c8c9d

Please sign in to comment.