From c22fb5ce10a3459609eda8730f0d8f7b92d757b2 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Tue, 25 Oct 2022 12:10:29 +0200 Subject: [PATCH] let caller decide --- polars/polars-lazy/polars-plan/src/dsl/mod.rs | 9 +++- polars/polars-lazy/src/tests/queries.rs | 40 ------------------ polars/tests/it/lazy/queries.rs | 41 +++++++++++++++++++ py-polars/polars/internals/lazy_functions.py | 13 +++++- py-polars/src/lazy/apply.rs | 3 +- py-polars/src/lib.rs | 10 ++++- 6 files changed, 70 insertions(+), 46 deletions(-) diff --git a/polars/polars-lazy/polars-plan/src/dsl/mod.rs b/polars/polars-lazy/polars-plan/src/dsl/mod.rs index cc9b6dd6a206..9554389f4468 100644 --- a/polars/polars-lazy/polars-plan/src/dsl/mod.rs +++ b/polars/polars-lazy/polars-plan/src/dsl/mod.rs @@ -2373,7 +2373,12 @@ where /// /// * `[map_mul]` should be used for operations that are independent of groups, e.g. `multiply * 2`, or `raise to the power` /// * `[apply_mul]` should be used for operations that work on a group of data. e.g. `sum`, `count`, etc. -pub fn apply_multiple(function: F, expr: E, output_type: GetOutput) -> Expr +pub fn apply_multiple( + function: F, + expr: E, + output_type: GetOutput, + returns_scalar: bool, +) -> Expr where F: Fn(&mut [Series]) -> PolarsResult + 'static + Send + Sync, E: AsRef<[Expr]>, @@ -2388,7 +2393,7 @@ where collect_groups: ApplyOptions::ApplyGroups, // don't set this to true // this is for the caller to decide - auto_explode: false, + auto_explode: returns_scalar, fmt_str: "", ..Default::default() }, diff --git a/polars/polars-lazy/src/tests/queries.rs b/polars/polars-lazy/src/tests/queries.rs index 84763e63da29..887097866c42 100644 --- a/polars/polars-lazy/src/tests/queries.rs +++ b/polars/polars-lazy/src/tests/queries.rs @@ -1674,46 +1674,6 @@ fn test_groupby_rank() -> PolarsResult<()> { Ok(()) } -#[test] -fn test_apply_multiple_columns() -> PolarsResult<()> { - let df = fruits_cars(); - - let multiply = |s: &mut [Series]| Ok(&(&s[0] * &s[0]) * &s[1]); - - let out = df - .clone() - .lazy() - .select([map_multiple( - multiply, - [col("A"), col("B")], - GetOutput::from_type(DataType::Float64), - )]) - .collect()?; - let out = out.column("A")?; - let out = out.i32()?; - assert_eq!( - Vec::from(out), - &[Some(5), Some(16), Some(27), Some(32), Some(25)] - ); - - let out = df - .lazy() - .groupby_stable([col("cars")]) - .agg([apply_multiple( - multiply, - [col("A"), col("B")], - GetOutput::from_type(DataType::Float64), - )]) - .collect()?; - - let out = out.column("A")?; - let out = out.list()?.get(1).unwrap(); - let out = out.i32()?; - - assert_eq!(Vec::from(out), &[Some(16)]); - Ok(()) -} - #[test] pub fn test_select_by_dtypes() -> PolarsResult<()> { let df = df![ diff --git a/polars/tests/it/lazy/queries.rs b/polars/tests/it/lazy/queries.rs index c4570e924a50..e9b36a7ec866 100644 --- a/polars/tests/it/lazy/queries.rs +++ b/polars/tests/it/lazy/queries.rs @@ -186,3 +186,44 @@ fn test_unknown_supertype_ignore() -> PolarsResult<()> { assert_eq!(out.shape(), (4, 2)); Ok(()) } + +#[test] +fn test_apply_multiple_columns() -> PolarsResult<()> { + let df = fruits_cars(); + + let multiply = |s: &mut [Series]| Ok(&(&s[0] * &s[0]) * &s[1]); + + let out = df + .clone() + .lazy() + .select([map_multiple( + multiply, + [col("A"), col("B")], + GetOutput::from_type(DataType::Float64), + )]) + .collect()?; + let out = out.column("A")?; + let out = out.i32()?; + assert_eq!( + Vec::from(out), + &[Some(5), Some(16), Some(27), Some(32), Some(25)] + ); + + let out = df + .lazy() + .groupby_stable([col("cars")]) + .agg([apply_multiple( + multiply, + [col("A"), col("B")], + GetOutput::from_type(DataType::Float64), + true, + )]) + .collect()?; + + let out = out.column("A")?; + let out = out.list()?.get(1).unwrap(); + let out = out.i32()?; + + assert_eq!(Vec::from(out), &[Some(16)]); + Ok(()) +} diff --git a/py-polars/polars/internals/lazy_functions.py b/py-polars/polars/internals/lazy_functions.py index 031619423dc2..c41757e16c4d 100644 --- a/py-polars/polars/internals/lazy_functions.py +++ b/py-polars/polars/internals/lazy_functions.py @@ -968,13 +968,16 @@ def map( """ exprs = pli.selection_to_pyexpr_list(exprs) - return pli.wrap_expr(_map_mul(exprs, f, return_dtype, apply_groups=False)) + return pli.wrap_expr( + _map_mul(exprs, f, return_dtype, apply_groups=False, returns_scalar=False) + ) def apply( exprs: Sequence[str | pli.Expr], f: Callable[[Sequence[pli.Series]], pli.Series | Any], return_dtype: type[DataType] | None = None, + returns_scalar: bool = True, ) -> pli.Expr: """ Apply a custom/user-defined function (UDF) in a GroupBy context. @@ -995,6 +998,8 @@ def apply( Function to apply over the input return_dtype dtype of the output Series + returns_scalar + If the function returns a single scalar as output. Returns ------- @@ -1002,7 +1007,11 @@ def apply( """ exprs = pli.selection_to_pyexpr_list(exprs) - return pli.wrap_expr(_map_mul(exprs, f, return_dtype, apply_groups=True)) + return pli.wrap_expr( + _map_mul( + exprs, f, return_dtype, apply_groups=True, returns_scalar=returns_scalar + ) + ) def fold( diff --git a/py-polars/src/lazy/apply.rs b/py-polars/src/lazy/apply.rs index d65e21a7956d..db0d32b1cc53 100644 --- a/py-polars/src/lazy/apply.rs +++ b/py-polars/src/lazy/apply.rs @@ -193,6 +193,7 @@ pub fn map_mul( lambda: PyObject, output_type: &PyAny, apply_groups: bool, + returns_scalar: bool, ) -> PyExpr { let output_type = get_output_type(output_type); @@ -221,7 +222,7 @@ pub fn map_mul( None => fld.clone(), }); if apply_groups { - polars::lazy::dsl::apply_multiple(function, exprs, output_map).into() + polars::lazy::dsl::apply_multiple(function, exprs, output_map, returns_scalar).into() } else { polars::lazy::dsl::map_multiple(function, exprs, output_map).into() } diff --git a/py-polars/src/lib.rs b/py-polars/src/lib.rs index ed216fe7e6b5..6f3cdac6a5a1 100644 --- a/py-polars/src/lib.rs +++ b/py-polars/src/lib.rs @@ -442,8 +442,16 @@ pub fn map_mul( lambda: PyObject, output_type: &PyAny, apply_groups: bool, + returns_scalar: bool, ) -> PyExpr { - lazy::map_mul(&pyexpr, py, lambda, output_type, apply_groups) + lazy::map_mul( + &pyexpr, + py, + lambda, + output_type, + apply_groups, + returns_scalar, + ) } #[pyfunction]