Skip to content

Commit

Permalink
fix: fix cse bug when window function is nested (#14070)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Jan 29, 2024
1 parent 306be3c commit b61d20c
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 1 deletion.
3 changes: 3 additions & 0 deletions crates/polars-plan/src/logical_plan/optimizer/cse_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,9 @@ impl Visitor for ExprIdentifierVisitor<'_> {

fn pre_visit(&mut self, node: &Self::Node) -> PolarsResult<VisitRecursion> {
if skip_pre_visit(node.to_aexpr(), self.is_group_by) {
// Still add to the stack so that a parent becomes invalidated.
self.visit_stack
.push(VisitRecord::SubExprId(Identifier::new(), false));
return Ok(VisitRecursion::Skip);
}

Expand Down
52 changes: 51 additions & 1 deletion py-polars/tests/unit/test_cse.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import re
from datetime import date, datetime
import typing
from datetime import date, datetime, timedelta
from tempfile import NamedTemporaryFile
from typing import Any

Expand Down Expand Up @@ -594,3 +595,52 @@ def test_cse_11958() -> None:
"diff3": [None, None, None, 30, 30],
"diff4": [None, None, None, None, 40],
}


@typing.no_type_check
def test_cse_14047() -> None:
ldf = pl.LazyFrame(
{
"timestamp": pl.datetime_range(
datetime(2024, 1, 12),
datetime(2024, 1, 12, 0, 0, 0, 150_000),
"10ms",
eager=True,
closed="left",
),
"price": list(range(15)),
}
)

def count_diff(
price: pl.Expr, upper_bound: float = 0.1, lower_bound: float = 0.001
):
span_end_to_curr = (
price.count()
.cast(int)
.rolling("timestamp", period=timedelta(seconds=lower_bound))
)
span_start_to_curr = (
price.count()
.cast(int)
.rolling("timestamp", period=timedelta(seconds=upper_bound))
)
return (span_start_to_curr - span_end_to_curr).alias(
f"count_diff_{upper_bound}_{lower_bound}"
)

def s_per_count(count_diff, span) -> pl.Expr:
return (span[1] * 1000 - span[0] * 1000) / count_diff

spans = [(0.001, 0.1), (1, 10)]
count_diff_exprs = [count_diff(pl.col("price"), span[0], span[1]) for span in spans]
s_per_count_exprs = [
s_per_count(count_diff, span).alias(f"zz_{span}")
for count_diff, span in zip(count_diff_exprs, spans)
]

exprs = count_diff_exprs + s_per_count_exprs
ldf = ldf.with_columns(*exprs)
assert_frame_equal(
ldf.collect(comm_subexpr_elim=True), ldf.collect(comm_subexpr_elim=False)
)

0 comments on commit b61d20c

Please sign in to comment.