From 25c5cdc2f18fd55dafccf121d836eb74fdc91f9a Mon Sep 17 00:00:00 2001 From: J van Zundert Date: Sun, 5 Feb 2023 07:14:32 +0000 Subject: [PATCH] fix(python): Support numpy ufunc when expression not first arg (#6675) --- py-polars/polars/internals/expr/expr.py | 4 ++-- py-polars/tests/unit/test_lazy.py | 20 ++++++++++++++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/py-polars/polars/internals/expr/expr.py b/py-polars/polars/internals/expr/expr.py index 0d77e664d072..a64d08708c08 100644 --- a/py-polars/polars/internals/expr/expr.py +++ b/py-polars/polars/internals/expr/expr.py @@ -288,10 +288,10 @@ def __array_ufunc__( self, ufunc: Callable[..., Any], method: str, *inputs: Any, **kwargs: Any ) -> Expr: """Numpy universal functions.""" - args = [inp for inp in inputs if not isinstance(inp, Expr)] def function(s: pli.Series) -> pli.Series: # pragma: no cover - return ufunc(s, *args, **kwargs) + args = [inp if not isinstance(inp, Expr) else s for inp in inputs] + return ufunc(*args, **kwargs) return self.map(function) diff --git a/py-polars/tests/unit/test_lazy.py b/py-polars/tests/unit/test_lazy.py index 8ab32aced8bb..d64f88e95ff8 100644 --- a/py-polars/tests/unit/test_lazy.py +++ b/py-polars/tests/unit/test_lazy.py @@ -965,6 +965,26 @@ def test_ufunc() -> None: assert out.dtypes == expected.dtypes +def test_ufunc_expr_not_first() -> None: + """Check numpy ufunc expressions also work if expression not the first argument.""" + df = pl.DataFrame([pl.Series("a", [1, 2, 3], dtype=pl.Float64)]) + out = df.select( + [ + np.power(2.0, cast(Any, pl.col("a"))).alias("power"), + (2.0 / cast(Any, pl.col("a"))).alias("divide_scalar"), + (np.array([2, 2, 2]) / cast(Any, pl.col("a"))).alias("divide_array"), + ] + ) + expected = pl.DataFrame( + [ + pl.Series("power", [2**1, 2**2, 2**3], dtype=pl.Float64), + pl.Series("divide_scalar", [2 / 1, 2 / 2, 2 / 3], dtype=pl.Float64), + pl.Series("divide_array", [2 / 1, 2 / 2, 2 / 3], dtype=pl.Float64), + ] + ) + assert_frame_equal(out, expected) + + def test_clip() -> None: df = pl.DataFrame({"a": [1, 2, 3, 4, 5]}) assert df.select(pl.col("a").clip(2, 4))["a"].to_list() == [2, 2, 3, 4, 4]