From be3bb253515b1a56dd82d6566739c7af02f0dc8b Mon Sep 17 00:00:00 2001
From: Ritchie Vink
Date: Mon, 12 Feb 2024 10:21:11 -0800
Subject: [PATCH] feat: apply negate in simplify expression pass (#14436)

---
 .../logical_plan/optimizer/simplify_expr.rs   | 37 +++++++++++++++++++++++++++++++++++++
 py-polars/tests/unit/test_arity.py            | 22 ++++++++++++++++++++++
 2 files changed, 59 insertions(+)

diff --git a/crates/polars-plan/src/logical_plan/optimizer/simplify_expr.rs b/crates/polars-plan/src/logical_plan/optimizer/simplify_expr.rs
index c7de180a84ce..4776bf9f5993 100644
--- a/crates/polars-plan/src/logical_plan/optimizer/simplify_expr.rs
+++ b/crates/polars-plan/src/logical_plan/optimizer/simplify_expr.rs
@@ -229,12 +229,49 @@ impl OptimizationRule for SimplifyBooleanRule {
             {
                 Some(AExpr::Literal(LiteralValue::Boolean(true)))
             },
+            // Constant-fold `-<literal>` into the negated literal when possible.
+            AExpr::Function {
+                input,
+                function: FunctionExpr::Negate,
+                ..
+            } if input.len() == 1 => {
+                let input = input[0];
+                let ae = expr_arena.get(input);
+                eval_negate(ae)
+            },
             _ => None,
         };
         Ok(out)
     }
 }
 
+/// Constant-fold the negation of a literal expression.
+///
+/// Returns the negated literal for the signed integer and float literal
+/// variants matched below, and `None` for every other expression. `None` is
+/// also returned when an integer literal holds the minimum value of its type:
+/// negating it overflows, so the expression is left unfolded rather than
+/// panicking (overflow checks on) or wrapping (overflow checks off).
+fn eval_negate(ae: &AExpr) -> Option<AExpr> {
+    let out = match ae {
+        AExpr::Literal(lv) => match lv {
+            // `checked_neg` is `None` on overflow; `?` then skips the fold.
+            #[cfg(feature = "dtype-i8")]
+            LiteralValue::Int8(v) => LiteralValue::Int8(v.checked_neg()?),
+            #[cfg(feature = "dtype-i16")]
+            LiteralValue::Int16(v) => LiteralValue::Int16(v.checked_neg()?),
+            LiteralValue::Int32(v) => LiteralValue::Int32(v.checked_neg()?),
+            LiteralValue::Int64(v) => LiteralValue::Int64(v.checked_neg()?),
+            // Float negation only flips the sign and cannot overflow.
+            LiteralValue::Float32(v) => LiteralValue::Float32(-*v),
+            LiteralValue::Float64(v) => LiteralValue::Float64(-*v),
+            _ => return None,
+        },
+        _ => return None,
+    };
+    Some(AExpr::Literal(out))
+}
+
 fn eval_bitwise<F>(left: &AExpr, right: &AExpr, operation: F) -> Option<AExpr>
 where
     F: Fn(bool, bool) -> bool,
diff --git a/py-polars/tests/unit/test_arity.py b/py-polars/tests/unit/test_arity.py
index 4be2a1910fe9..ea62e6583cae 100644
--- a/py-polars/tests/unit/test_arity.py
+++ b/py-polars/tests/unit/test_arity.py
@@ -79,3 +79,25 @@ def test_broadcast_string_ops_12632(
     assert df.select(needs_broadcast.str.strip_chars(pl.col("name"))).height == 3
     assert df.select(needs_broadcast.str.strip_chars_start(pl.col("name"))).height == 3
     assert df.select(needs_broadcast.str.strip_chars_end(pl.col("name"))).height == 3
+
+
+def test_negate_inlined_14278() -> None:
+    df = pl.DataFrame(
+        {"group": ["A", "A", "B", "B", "B", "C", "C"], "value": [1, 2, 3, 4, 5, 6, 7]}
+    )
+
+    agg_expr = [
+        pl.struct("group", "value").tail(2).alias("list"),
+        pl.col("value").sort().tail(2).count().alias("count"),
+    ]
+
+    q = df.lazy().group_by("group").agg(agg_expr)
+    assert q.collect().sort("group").to_dict(as_series=False) == {
+        "group": ["A", "B", "C"],
+        "list": [
+            [{"group": "A", "value": 1}, {"group": "A", "value": 2}],
+            [{"group": "B", "value": 4}, {"group": "B", "value": 5}],
+            [{"group": "C", "value": 6}, {"group": "C", "value": 7}],
+        ],
+        "count": [2, 2, 2],
+    }