From c45274342dc00b022a4c03e35e5a330aaac98202 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Thu, 5 Jan 2023 09:51:59 +0100 Subject: [PATCH] fix(rust, python): block streaming on literal series/range --- .../src/physical_plan/streaming/convert.rs | 34 +++++++++++++++---- py-polars/tests/unit/test_streaming.py | 9 +++++ 2 files changed, 36 insertions(+), 7 deletions(-) diff --git a/polars/polars-lazy/src/physical_plan/streaming/convert.rs b/polars/polars-lazy/src/physical_plan/streaming/convert.rs index 516732e4a742..8da154dc0515 100644 --- a/polars/polars-lazy/src/physical_plan/streaming/convert.rs +++ b/polars/polars-lazy/src/physical_plan/streaming/convert.rs @@ -41,17 +41,37 @@ fn to_physical_piped_expr( } fn is_streamable(node: Node, expr_arena: &Arena) -> bool { - expr_arena.iter(node).all(|(_, ae)| match ae { + // check whether leaf column is Col or Lit + let mut seen_column = false; + let mut seen_lit_range = false; + let all = expr_arena.iter(node).all(|(_, ae)| match ae { AExpr::Function { options, .. } | AExpr::AnonymousFunction { options, .. } => { matches!(options.collect_groups, ApplyOptions::ApplyFlat) } - AExpr::Column(_) - | AExpr::Literal(_) - | AExpr::BinaryExpr { .. } - | AExpr::Alias(_, _) - | AExpr::Cast { .. } => true, + AExpr::Column(_) => { + seen_column = true; + true + } + AExpr::BinaryExpr { .. } | AExpr::Alias(_, _) | AExpr::Cast { .. } => true, + AExpr::Literal(lv) => match lv { + LiteralValue::Series(_) | LiteralValue::Range { .. } => { + seen_lit_range = true; + true + } + _ => true, + }, _ => false, - }) + }); + + if all { + // adding a range or literal series to chunks will fail because sizes don't match + // if column is a leaf column then it is ok + // - so we want to block `with_column(lit(Series))` + // - but we want to allow `with_column(col("foo").is_in(Series))` + // that means that IFF we have seen a lit_range, we only allow it if we have also seen a `column`. 
+ return if seen_lit_range { seen_column } else { true }; + } + false } fn all_streamable(exprs: &[Node], expr_arena: &Arena) -> bool { diff --git a/py-polars/tests/unit/test_streaming.py b/py-polars/tests/unit/test_streaming.py index 37d369ffa153..0a1e942dac77 100644 --- a/py-polars/tests/unit/test_streaming.py +++ b/py-polars/tests/unit/test_streaming.py @@ -188,3 +188,12 @@ def test_streaming_categoricals_5921() -> None: for out in [out_eager, out_lazy]: assert out.dtypes == [pl.Categorical, pl.Int64] assert out.to_dict(False) == {"X": ["a", "b"], "Y": [2, 1]} + + +def test_streaming_block_on_literals_6054() -> None: + df = pl.DataFrame({"col_1": [0] * 5 + [1] * 5}) + s = pl.Series("col_2", list(range(10))) + + assert df.lazy().with_column(s).groupby("col_1").agg(pl.all().first()).collect( + streaming=True + ).sort("col_1").to_dict(False) == {"col_1": [0, 1], "col_2": [0, 5]}