Skip to content

Commit

Permalink
feat(python): Support Numpy ufunc with more than one expression (#7924)
Browse files Browse the repository at this point in the history
  • Loading branch information
zundertj authored Apr 2, 2023
1 parent ddd88ef commit a7d3895
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 5 deletions.
14 changes: 9 additions & 5 deletions py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import random
import warnings
from datetime import timedelta
from functools import reduce
from functools import partial, reduce
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -47,7 +47,7 @@

with contextlib.suppress(ImportError): # Module not available when building docs
from polars.polars import arg_where as py_arg_where

from polars.polars import reduce as pyreduce
if TYPE_CHECKING:
import sys

Expand Down Expand Up @@ -282,9 +282,13 @@ def __array_ufunc__(
"""Numpy universal functions."""
num_expr = sum(isinstance(inp, Expr) for inp in inputs)
if num_expr > 1:
raise ValueError(
f"Numpy ufunc can only be used with one expression, {num_expr} given. Use `pl.reduce` to call numpy functions over multiple expressions."
)
if num_expr < len(inputs):
raise ValueError(
"Numpy ufunc with more than one expression can only be used if all non-expression inputs are provided as keyword arguments only"
)

exprs = selection_to_pyexpr_list(inputs)
return self._from_pyexpr(pyreduce(partial(ufunc, **kwargs), exprs))

def function(s: Series) -> Series: # pragma: no cover
args = [inp if not isinstance(inp, Expr) else s for inp in inputs]
Expand Down
35 changes: 35 additions & 0 deletions py-polars/tests/unit/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -3654,6 +3654,41 @@ def test_ufunc_expr_not_first() -> None:
assert_frame_equal(out, expected)


def test_ufunc_multiple_expressions() -> None:
    """Check a two-input NumPy ufunc called on two expressions.

    ``np.arctan2`` applied to the ``v`` and ``u`` column expressions inside
    ``select`` must produce the same series as ``np.arctan2`` applied
    directly to the two columns of the frame.
    """
    # example from https://github.com/pola-rs/polars/issues/6770
    v_values = [
        -4.293,
        -2.4659,
        -1.8378,
        -0.2821,
        -4.5649,
        -3.8128,
        -7.4274,
        3.3443,
        3.8604,
        -4.2200,
    ]
    u_values = [
        -11.2268,
        6.3478,
        7.1681,
        3.4986,
        2.7320,
        -1.0695,
        -10.1408,
        11.2327,
        6.6623,
        -8.1412,
    ]
    frame = pl.DataFrame({"v": v_values, "u": u_values})

    # Reference value: the ufunc evaluated on the materialised columns.
    expected = np.arctan2(frame.get_column("v"), frame.get_column("u"))

    # Value under test: the same ufunc evaluated on two expressions.
    ufunc_expr = np.arctan2(pl.col("v"), pl.col("u"))  # type: ignore[call-overload]
    result = frame.select(ufunc_expr)[:, 0]

    assert_series_equal(expected, result)  # type: ignore[arg-type]


def test_window_deadlock() -> None:
np.random.seed(12)

Expand Down

0 comments on commit a7d3895

Please sign in to comment.