Skip to content

Commit

Permalink
feat: apply negate in simplify expression pass (#14436)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Feb 12, 2024
1 parent f96773e commit be3bb25
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 0 deletions.
27 changes: 27 additions & 0 deletions crates/polars-plan/src/logical_plan/optimizer/simplify_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,12 +229,39 @@ impl OptimizationRule for SimplifyBooleanRule {
{
Some(AExpr::Literal(LiteralValue::Boolean(true)))
},
AExpr::Function {
input,
function: FunctionExpr::Negate,
..
} if input.len() == 1 => {
let input = input[0];
let ae = expr_arena.get(input);
eval_negate(ae)
},
_ => None,
};
Ok(out)
}
}

fn eval_negate(ae: &AExpr) -> Option<AExpr> {
let out = match ae {
AExpr::Literal(lv) => match lv {
#[cfg(feature = "dtype-i8")]
LiteralValue::Int8(v) => LiteralValue::Int8(-*v),
#[cfg(feature = "dtype-i16")]
LiteralValue::Int16(v) => LiteralValue::Int16(-*v),
LiteralValue::Int32(v) => LiteralValue::Int32(-*v),
LiteralValue::Int64(v) => LiteralValue::Int64(-*v),
LiteralValue::Float32(v) => LiteralValue::Float32(-*v),
LiteralValue::Float64(v) => LiteralValue::Float64(-*v),
_ => return None,
},
_ => return None,
};
Some(AExpr::Literal(out))
}

fn eval_bitwise<F>(left: &AExpr, right: &AExpr, operation: F) -> Option<AExpr>
where
F: Fn(bool, bool) -> bool,
Expand Down
22 changes: 22 additions & 0 deletions py-polars/tests/unit/test_arity.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,25 @@ def test_broadcast_string_ops_12632(
assert df.select(needs_broadcast.str.strip_chars(pl.col("name"))).height == 3
assert df.select(needs_broadcast.str.strip_chars_start(pl.col("name"))).height == 3
assert df.select(needs_broadcast.str.strip_chars_end(pl.col("name"))).height == 3


def test_negate_inlined_14278() -> None:
df = pl.DataFrame(
{"group": ["A", "A", "B", "B", "B", "C", "C"], "value": [1, 2, 3, 4, 5, 6, 7]}
)

agg_expr = [
pl.struct("group", "value").tail(2).alias("list"),
pl.col("value").sort().tail(2).count().alias("count"),
]

q = df.lazy().group_by("group").agg(agg_expr)
assert q.collect().sort("group").to_dict(as_series=False) == {
"group": ["A", "B", "C"],
"list": [
[{"group": "A", "value": 1}, {"group": "A", "value": 2}],
[{"group": "B", "value": 4}, {"group": "B", "value": 5}],
[{"group": "C", "value": 6}, {"group": "C", "value": 7}],
],
"count": [2, 2, 2],
}

0 comments on commit be3bb25

Please sign in to comment.