diff --git a/crates/polars-core/src/chunked_array/ops/zip.rs b/crates/polars-core/src/chunked_array/ops/zip.rs index cd592219f684..8ef8954848aa 100644 --- a/crates/polars-core/src/chunked_array/ops/zip.rs +++ b/crates/polars-core/src/chunked_array/ops/zip.rs @@ -38,24 +38,13 @@ macro_rules! impl_ternary_broadcast { .collect_trusted(); val.rename($self.name()); Ok(val) - } - (_, 1, 1) => { - let right = $other.get(0); - let mask = $mask.get(0).unwrap_or(false); - let mut val: ChunkedArray<$ty> = $self - .into_iter() - .map(|left| ternary_apply(mask, left, right)) - .collect_trusted(); - val.rename($self.name()); - Ok(val) - } - (1, _, 1) => { - let left = $self.get(0); - let mask = $mask.get(0).unwrap_or(false); - let mut val: ChunkedArray<$ty> = $other - .into_iter() - .map(|right| ternary_apply(mask, left, right)) - .collect_trusted(); + }, + (_, _, 1) => { + let mut val: ChunkedArray<$ty> = if let Some(true) = $mask.get(0) { + $self.clone() + } else { + $other.clone() + }; val.rename($self.name()); Ok(val) }, @@ -79,16 +68,6 @@ macro_rules! impl_ternary_broadcast { val.rename($self.name()); Ok(val) }, - (l_len, r_len, 1) if l_len == r_len => { - let mask = $mask.get(0).unwrap_or(false); - let mut val: ChunkedArray<$ty> = $self - .into_iter() - .zip($other) - .map(|(left, right)| ternary_apply(mask, left, right)) - .collect_trusted(); - val.rename($self.name()); - Ok(val) - }, (_, _, _) => Err(polars_err!( ShapeMismatch: "shapes of `self`, `mask` and `other` are not suitable for `zip_with` operation" )), diff --git a/crates/polars-lazy/src/physical_plan/expressions/ternary.rs b/crates/polars-lazy/src/physical_plan/expressions/ternary.rs index 5fc62c946b1c..6faf6eb85bf1 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/ternary.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/ternary.rs @@ -47,12 +47,22 @@ fn expand_lengths(truthy: &mut Series, falsy: &mut Series, mask: &mut BooleanChu // Mask length 1 will broadcast to the matching branch. let len = match mask.get(0) { Some(true) => { - *falsy = truthy.clone(); - truthy.len() + let len = truthy.len(); + *falsy = if falsy.len() < len { + Series::full_null(falsy.name(), len, falsy.dtype()) + } else { + falsy.slice(0, len) + }; + len }, _ => { - *truthy = falsy.clone(); - falsy.len() + let len = falsy.len(); + *truthy = if truthy.len() < len { + Series::full_null(truthy.name(), len, truthy.dtype()) + } else { + truthy.slice(0, len) + }; + len }, }; @@ -141,7 +151,7 @@ impl PhysicalExpr for TernaryExpr { let mut truthy = truthy?; let mut falsy = falsy?; - expand_lengths(&mut truthy, &mut falsy, &mut mask); + // expand_lengths(&mut truthy, &mut falsy, &mut mask); truthy.zip_with(&mask, &falsy) } @@ -168,165 +178,171 @@ impl PhysicalExpr for TernaryExpr { }; let ac_mask = ac_mask?; - let mut ac_truthy = ac_truthy?; - let mut ac_falsy = ac_falsy?; - - let mask_s = ac_mask.flat_naive(); - - // BIG TODO: find which branches are never hit and remove them. - use AggState::*; - match (ac_truthy.agg_state(), ac_falsy.agg_state()) { - // All branches are aggregated-flat or literal - // mask -> aggregated-flat - // truthy -> aggregated-flat | literal - // falsy -> aggregated-flat | literal - // simply align lengths and zip - ( - Literal(truthy) | AggregatedScalar(truthy), - AggregatedScalar(falsy) | Literal(falsy), - ) - | (AggregatedList(truthy), AggregatedList(falsy)) - if matches!(ac_mask.agg_state(), AggState::AggregatedScalar(_)) => - { - let mut truthy = truthy.clone(); - let mut falsy = falsy.clone(); - let mut mask = ac_mask.series().bool()?.clone(); - expand_lengths(&mut truthy, &mut falsy, &mut mask); - let out = truthy.zip_with(&mask, &falsy).unwrap(); - ac_truthy.with_series(out.with_name(truthy.name()), true, Some(&self.expr))?; - Ok(ac_truthy) - }, - - // We cannot flatten a list because that changes the order, so we apply over groups. - (AggregatedList(_), NotAggregated(_)) | (NotAggregated(_), AggregatedList(_)) => { - finish_as_iters(ac_truthy, ac_falsy, ac_mask) - }, - - // Then: - // col().shift() - // Otherwise: - // None - (AggregatedList(_), Literal(_)) | (Literal(_), AggregatedList(_)) => { - if !aggregation_predicate { - return finish_as_iters(ac_truthy, ac_falsy, ac_mask); - } - let mask = mask_s.bool()?; - let check_length = |ca: &ListChunked, mask: &BooleanChunked| { - polars_ensure!( - ca.len() == mask.len(), expr = self.expr, ComputeError: - "predicates length: {} does not match groups length: {}", - mask.len(), ca.len() - ); - Ok(()) - }; - - if ac_falsy.is_literal() && self.falsy.as_expression().map(has_null) == Some(true) { - let s = ac_truthy.aggregated(); - let ca = s.list().unwrap(); - check_length(ca, mask)?; - let out = ca - .into_iter() - .zip(mask) - .map(|(truthy, take)| if take? { truthy } else { None }) - .collect_trusted::() - .with_name(ac_truthy.series().name()); - ac_truthy.with_series(out.into_series(), true, Some(&self.expr))?; - Ok(ac_truthy) - } else if ac_truthy.is_literal() - && self.truthy.as_expression().map(has_null) == Some(true) - { - let s = ac_falsy.aggregated(); - let ca = s.list().unwrap(); - check_length(ca, mask)?; - let out = ca - .into_iter() - .zip(mask) - .map(|(falsy, take)| if take? { None } else { falsy }) - .collect_trusted::() - .with_name(ac_truthy.series().name()); - ac_truthy.with_series(out.into_series(), true, Some(&self.expr))?; - Ok(ac_truthy) - } - // Then: - // col().shift() - // Otherwise: - // lit(list) - else if ac_truthy.is_literal() { - let literal = ac_truthy.series(); - let s = ac_falsy.aggregated(); - let ca = s.list().unwrap(); - check_length(ca, mask)?; - let out = ca - .into_iter() - .zip(mask) - .map(|(falsy, take)| if take? { Some(literal.clone()) } else { falsy }) - .collect_trusted::() - .with_name(ac_truthy.series().name()); - ac_truthy.with_series(out.into_series(), true, Some(&self.expr))?; - Ok(ac_truthy) - } else { - let literal = ac_falsy.series(); - let s = ac_truthy.aggregated(); - let ca = s.list().unwrap(); - check_length(ca, mask)?; - let out = ca - .into_iter() - .zip(mask) - .map(|(truthy, take)| if take? { truthy } else { Some(literal.clone()) }) - .collect_trusted::() - .with_name(ac_truthy.series().name()); - ac_truthy.with_series(out.into_series(), true, Some(&self.expr))?; - Ok(ac_truthy) - } - }, - // Both are or a flat series or aggregated into a list - // so we can flatten the Series an apply the operators. - _ => { - // Inspect the predicate and if it is consisting - // of arity/binary and some aggregation we apply as iters as - // it gets complicated quickly. - // For instance: - // when(col(..) > min(..)).then(..).otherwise(..) - if let Some(expr) = self.predicate.as_expression() { - let mut has_arity = false; - let mut has_agg = false; - for e in expr.into_iter() { - match e { - Expr::BinaryExpr { .. } | Expr::Ternary { .. } => has_arity = true, - Expr::Agg(_) => has_agg = true, - Expr::Function { options, .. } - | Expr::AnonymousFunction { options, .. } - if options.is_groups_sensitive() => - { - has_agg = true - }, - _ => {}, - } - } - if has_arity && has_agg { - return finish_as_iters(ac_truthy, ac_falsy, ac_mask); - } - } - - if !aggregation_predicate { - return finish_as_iters(ac_truthy, ac_falsy, ac_mask); - } - let mut mask = mask_s.bool()?.clone(); - let mut truthy = ac_truthy.flat_naive().into_owned(); - let mut falsy = ac_falsy.flat_naive().into_owned(); - expand_lengths(&mut truthy, &mut falsy, &mut mask); - let out = truthy.zip_with(&mask, &falsy)?; - - // Because of the flattening we don't have to do that anymore. - if matches!(ac_truthy.update_groups, UpdateGroups::WithSeriesLen) { - ac_truthy.with_update_groups(UpdateGroups::No); - } - - ac_truthy.with_series(out, false, None)?; - - Ok(ac_truthy) - }, - } + let ac_truthy = ac_truthy?; + let ac_falsy = ac_falsy?; + + return finish_as_iters(ac_truthy, ac_falsy, ac_mask); + + // let ac_mask = ac_mask?; + // let mut ac_truthy = ac_truthy?; + // let mut ac_falsy = ac_falsy?; + + // let mask_s = ac_mask.flat_naive(); + + // // BIG TODO: find which branches are never hit and remove them. + // use AggState::*; + // match (ac_truthy.agg_state(), ac_falsy.agg_state()) { + // // All branches are aggregated-flat or literal + // // mask -> aggregated-flat + // // truthy -> aggregated-flat | literal + // // falsy -> aggregated-flat | literal + // // simply align lengths and zip + // ( + // Literal(truthy) | AggregatedScalar(truthy), + // AggregatedScalar(falsy) | Literal(falsy), + // ) + // | (AggregatedList(truthy), AggregatedList(falsy)) + // if matches!(ac_mask.agg_state(), AggState::AggregatedScalar(_)) => + // { + // let mut truthy = truthy.clone(); + // let mut falsy = falsy.clone(); + // let mut mask = ac_mask.series().bool()?.clone(); + // expand_lengths(&mut truthy, &mut falsy, &mut mask); + // let out = truthy.zip_with(&mask, &falsy).unwrap(); + // ac_truthy.with_series(out.with_name(truthy.name()), true, Some(&self.expr))?; + // Ok(ac_truthy) + // }, + + // // We cannot flatten a list because that changes the order, so we apply over groups. + // (AggregatedList(_), NotAggregated(_)) | (NotAggregated(_), AggregatedList(_)) => { + // finish_as_iters(ac_truthy, ac_falsy, ac_mask) + // }, + + // // Then: + // // col().shift() + // // Otherwise: + // // None + // (AggregatedList(_), Literal(_)) | (Literal(_), AggregatedList(_)) => { + // if !aggregation_predicate { + // return finish_as_iters(ac_truthy, ac_falsy, ac_mask); + // } + // let mask = mask_s.bool()?; + // let check_length = |ca: &ListChunked, mask: &BooleanChunked| { + // polars_ensure!( + // ca.len() == mask.len(), expr = self.expr, ComputeError: + // "predicates length: {} does not match groups length: {}", + // mask.len(), ca.len() + // ); + // Ok(()) + // }; + + // if ac_falsy.is_literal() && self.falsy.as_expression().map(has_null) == Some(true) { + // let s = ac_truthy.aggregated(); + // let ca = s.list().unwrap(); + // check_length(ca, mask)?; + // let out = ca + // .into_iter() + // .zip(mask) + // .map(|(truthy, take)| if take? { truthy } else { None }) + // .collect_trusted::() + // .with_name(ac_truthy.series().name()); + // ac_truthy.with_series(out.into_series(), true, Some(&self.expr))?; + // Ok(ac_truthy) + // } else if ac_truthy.is_literal() + // && self.truthy.as_expression().map(has_null) == Some(true) + // { + // let s = ac_falsy.aggregated(); + // let ca = s.list().unwrap(); + // check_length(ca, mask)?; + // let out = ca + // .into_iter() + // .zip(mask) + // .map(|(falsy, take)| if take? { None } else { falsy }) + // .collect_trusted::() + // .with_name(ac_truthy.series().name()); + // ac_truthy.with_series(out.into_series(), true, Some(&self.expr))?; + // Ok(ac_truthy) + // } + // // Then: + // // col().shift() + // // Otherwise: + // // lit(list) + // else if ac_truthy.is_literal() { + // let literal = ac_truthy.series(); + // let s = ac_falsy.aggregated(); + // let ca = s.list().unwrap(); + // check_length(ca, mask)?; + // let out = ca + // .into_iter() + // .zip(mask) + // .map(|(falsy, take)| if take? { Some(literal.clone()) } else { falsy }) + // .collect_trusted::() + // .with_name(ac_truthy.series().name()); + // ac_truthy.with_series(out.into_series(), true, Some(&self.expr))?; + // Ok(ac_truthy) + // } else { + // let literal = ac_falsy.series(); + // let s = ac_truthy.aggregated(); + // let ca = s.list().unwrap(); + // check_length(ca, mask)?; + // let out = ca + // .into_iter() + // .zip(mask) + // .map(|(truthy, take)| if take? { truthy } else { Some(literal.clone()) }) + // .collect_trusted::() + // .with_name(ac_truthy.series().name()); + // ac_truthy.with_series(out.into_series(), true, Some(&self.expr))?; + // Ok(ac_truthy) + // } + // }, + // // Both are or a flat series or aggregated into a list + // // so we can flatten the Series an apply the operators. + // _ => { + // // Inspect the predicate and if it is consisting + // // of arity/binary and some aggregation we apply as iters as + // // it gets complicated quickly. + // // For instance: + // // when(col(..) > min(..)).then(..).otherwise(..) + // if let Some(expr) = self.predicate.as_expression() { + // let mut has_arity = false; + // let mut has_agg = false; + // for e in expr.into_iter() { + // match e { + // Expr::BinaryExpr { .. } | Expr::Ternary { .. } => has_arity = true, + // Expr::Agg(_) => has_agg = true, + // Expr::Function { options, .. } + // | Expr::AnonymousFunction { options, .. } + // if options.is_groups_sensitive() => + // { + // has_agg = true + // }, + // _ => {}, + // } + // } + // if has_arity && has_agg { + // return finish_as_iters(ac_truthy, ac_falsy, ac_mask); + // } + // } + + // if !aggregation_predicate { + // return finish_as_iters(ac_truthy, ac_falsy, ac_mask); + // } + // let mut mask = mask_s.bool()?.clone(); + // let mut truthy = ac_truthy.flat_naive().into_owned(); + // let mut falsy = ac_falsy.flat_naive().into_owned(); + // expand_lengths(&mut truthy, &mut falsy, &mut mask); + // let out = truthy.zip_with(&mask, &falsy)?; + + // // Because of the flattening we don't have to do that anymore. + // if matches!(ac_truthy.update_groups, UpdateGroups::WithSeriesLen) { + // ac_truthy.with_update_groups(UpdateGroups::No); + // } + + // ac_truthy.with_series(out, false, None)?; + + // Ok(ac_truthy) + // }, + // } } fn as_partitioned_aggregator(&self) -> Option<&dyn PartitionedAggregation> { Some(self) @@ -353,7 +369,7 @@ impl PartitionedAggregation for TernaryExpr { let mask = mask.evaluate_partitioned(df, groups, state)?; let mut mask = mask.bool()?.clone(); - expand_lengths(&mut truthy, &mut falsy, &mut mask); + // expand_lengths(&mut truthy, &mut falsy, &mut mask); truthy.zip_with(&mask, &falsy) } diff --git a/py-polars/tests/unit/functions/test_whenthen.py b/py-polars/tests/unit/functions/test_whenthen.py index af6e4c35e2df..b023638b20f2 100644 --- a/py-polars/tests/unit/functions/test_whenthen.py +++ b/py-polars/tests/unit/functions/test_whenthen.py @@ -280,6 +280,9 @@ def test_broadcast_zero_len_12354() -> None: for out_true, out_false in ( (pl.col("x").head(0), pl.col("x")), (pl.col("x"), pl.col("x").head(0)), + # If the len of the non-matching branch >= the matching branch, + # a `slice` operation is used instead of allocating a new series. + (pl.col("x"), pl.col("x")), ): for predicate, expected in ( # true @@ -295,3 +298,22 @@ def test_broadcast_zero_len_12354() -> None: # should not panic on NULL 1-length predicate assert pl.select(pl.when(pl.lit(None, dtype=pl.Boolean)).then(1)).item() is None + + +def test_when_then_output_name_12380() -> None: + df = pl.DataFrame( + {"x": range(5), "y": range(5, 10)}, schema={"x": pl.Int8, "y": pl.Int64} + ).with_columns(true=True, false=False) + + expect = df.select(pl.col("x").cast(pl.Int64)) + for true_expr in (pl.first("true"), pl.col("true"), pl.lit(True)): + assert_frame_equal( + expect, + df.select(pl.when(true_expr).then(pl.col("x")).otherwise(pl.col("y"))), + ) + expect = df.select(pl.col("y").alias("x")) + for false_expr in (pl.first("false"), pl.col("false"), pl.lit(False)): + assert_frame_equal( + expect, + df.select(pl.when(false_expr).then(pl.col("x")).otherwise(pl.col("y"))), + )