From 9ec6108feea7ea095ed7d57e53d0d31371437b98 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Tue, 31 Jan 2023 15:53:41 +0100 Subject: [PATCH] fix(rust, python): raise error on string numeric arithmetic --- .../logical_plan/optimizer/simplify_expr.rs | 11 +- .../optimizer/type_coercion/binary.rs | 225 ++++++++++++++++++ .../mod.rs} | 174 +------------- polars/polars-utils/src/lib.rs | 1 + polars/polars-utils/src/macros.rs | 7 + py-polars/tests/unit/test_errors.py | 8 + 6 files changed, 254 insertions(+), 172 deletions(-) create mode 100644 polars/polars-lazy/polars-plan/src/logical_plan/optimizer/type_coercion/binary.rs rename polars/polars-lazy/polars-plan/src/logical_plan/optimizer/{type_coercion.rs => type_coercion/mod.rs} (66%) create mode 100644 polars/polars-utils/src/macros.rs diff --git a/polars/polars-lazy/polars-plan/src/logical_plan/optimizer/simplify_expr.rs b/polars/polars-lazy/polars-plan/src/logical_plan/optimizer/simplify_expr.rs index 74b2448dfa31..45140068bb57 100644 --- a/polars/polars-lazy/polars-plan/src/logical_plan/optimizer/simplify_expr.rs +++ b/polars/polars-lazy/polars-plan/src/logical_plan/optimizer/simplify_expr.rs @@ -311,11 +311,18 @@ fn string_addition_to_linear_concat( let schema = lp_arena.get(input).schema(lp_arena); let get_type = |ae: &AExpr| ae.get_type(&schema, Context::Default, expr_arena).ok(); - let addition_type = get_type(left_aexpr) + let type_a = get_type(left_aexpr) .or_else(|| get_type(right_aexpr)) .unwrap(); + let type_b = get_type(right_aexpr) + .or_else(|| get_type(right_aexpr)) + .unwrap(); + + if type_a != type_b { + return None; + } - if addition_type == DataType::Utf8 { + if type_a == DataType::Utf8 { match (left_aexpr, right_aexpr) { // concat + concat ( diff --git a/polars/polars-lazy/polars-plan/src/logical_plan/optimizer/type_coercion/binary.rs b/polars/polars-lazy/polars-plan/src/logical_plan/optimizer/type_coercion/binary.rs new file mode 100644 index 000000000000..c03f805e6295 --- /dev/null +++ b/polars/polars-lazy/polars-plan/src/logical_plan/optimizer/type_coercion/binary.rs @@ -0,0 +1,225 @@ +use polars_utils::matches_any_order; + +use super::*; + +macro_rules! unpack { + ($packed:expr) => {{ + match $packed { + Some(payload) => payload, + None => return Ok(None), + } + }}; +} + +#[allow(unused_variables)] +fn compares_cat_to_string(type_left: &DataType, type_right: &DataType, op: Operator) -> bool { + #[cfg(feature = "dtype-categorical")] + { + op.is_comparison() + && matches_any_order!( + type_left, + type_right, + DataType::Utf8, + DataType::Categorical(_) + ) + } + #[cfg(not(feature = "dtype-categorical"))] + { + false + } +} + +#[allow(unused_variables)] +fn is_datetime_arithmetic(type_left: &DataType, type_right: &DataType, op: Operator) -> bool { + matches!(op, Operator::Minus | Operator::Plus) + && matches_any_order!( + &type_left, + &type_right, + DataType::Datetime(_, _) | DataType::Date, + DataType::Duration(_) + ) +} + +fn is_list_arithmetic(type_left: &DataType, type_right: &DataType, op: Operator) -> bool { + op.is_arithmetic() + && matches!( + (&type_left, &type_right), + (DataType::List(_), _) | (_, DataType::List(_)) + ) +} + +#[allow(unused_variables)] +fn is_cat_str_binary(type_left: &DataType, type_right: &DataType) -> bool { + #[cfg(feature = "dtype-categorical")] + { + matches_any_order!( + type_left, + type_right, + DataType::Utf8, + DataType::Categorical(_) + ) + } + #[cfg(not(feature = "dtype-categorical"))] + { + false + } +} + +fn str_numeric_arithmetic(type_left: &DataType, type_right: &DataType) -> PolarsResult<()> { + if type_left.is_numeric() && matches!(type_right, DataType::Utf8) + || type_right.is_numeric() && matches!(type_left, DataType::Utf8) + { + Err(PolarsError::ComputeError( + "Arithmetic on string and numeric not allowed. Try an explicit cast first.".into(), + )) + } else { + Ok(()) + } +} + +fn process_list_arithmetic( + type_left: DataType, + type_right: DataType, + node_left: Node, + node_right: Node, + op: Operator, + expr_arena: &mut Arena, +) -> PolarsResult> { + match (&type_left, &type_right) { + (DataType::List(inner), _) => { + if type_right != **inner { + let new_node_right = expr_arena.add(AExpr::Cast { + expr: node_right, + data_type: *inner.clone(), + strict: false, + }); + + Ok(Some(AExpr::BinaryExpr { + left: node_left, + op, + right: new_node_right, + })) + } else { + Ok(None) + } + } + (_, DataType::List(inner)) => { + if type_left != **inner { + let new_node_left = expr_arena.add(AExpr::Cast { + expr: node_left, + data_type: *inner.clone(), + strict: false, + }); + + Ok(Some(AExpr::BinaryExpr { + left: new_node_left, + op, + right: node_right, + })) + } else { + Ok(None) + } + } + _ => unreachable!(), + } +} + +pub(super) fn process_binary( + expr_arena: &mut Arena, + lp_arena: &Arena, + lp_node: Node, + node_left: Node, + op: Operator, + node_right: Node, +) -> PolarsResult> { + let input_schema = get_schema(lp_arena, lp_node); + let (left, type_left): (&AExpr, DataType) = + unpack!(get_aexpr_and_type(expr_arena, node_left, &input_schema)); + let (right, type_right): (&AExpr, DataType) = + unpack!(get_aexpr_and_type(expr_arena, node_right, &input_schema)); + unpack!(early_escape(&type_left, &type_right)); + + use DataType::*; + // don't coerce string with number comparisons. They must error + match (&type_left, &type_right, op) { + #[cfg(not(feature = "dtype-categorical"))] + (DataType::Utf8, dt, op) | (dt, DataType::Utf8, op) + if op.is_comparison() && dt.is_numeric() => + { + return Ok(None) + } + #[cfg(feature = "dtype-categorical")] + (Utf8 | Categorical(_), dt, op) | (dt, Utf8 | Categorical(_), op) + if op.is_comparison() && dt.is_numeric() => + { + return Ok(None) + } + #[cfg(feature = "dtype-date")] + (Date, Utf8, op) if op.is_comparison() => err_date_str_compare()?, + #[cfg(feature = "dtype-datetime")] + (Datetime(_, _), Utf8, op) if op.is_comparison() => err_date_str_compare()?, + #[cfg(feature = "dtype-time")] + (Time, Utf8, op) if op.is_comparison() => err_date_str_compare()?, + // structs can be arbitrarily nested, leave the complexity to the caller for now. + #[cfg(feature = "dtype-struct")] + (Struct(_), Struct(_), _op) => return Ok(None), + _ => {} + } + let compare_cat_to_string = compares_cat_to_string(&type_left, &type_right, op); + let datetime_arithmetic = is_datetime_arithmetic(&type_left, &type_right, op); + let list_arithmetic = is_list_arithmetic(&type_left, &type_right, op); + str_numeric_arithmetic(&type_left, &type_right)?; + + // Special path for list arithmetic + if list_arithmetic { + return process_list_arithmetic( + type_left, type_right, node_left, node_right, op, expr_arena, + ); + } + + // All early return paths + if compare_cat_to_string + || datetime_arithmetic + || early_escape(&type_left, &type_right).is_none() + { + Ok(None) + } else { + // Coerce types: + + let st = unpack!(get_supertype(&type_left, &type_right)); + let mut st = modify_supertype(st, left, right, &type_left, &type_right); + + if is_cat_str_binary(&type_left, &type_right) { + st = Utf8 + } + + // only cast if the type is not already the super type. + // this can prevent an expensive flattening and subsequent aggregation + // in a groupby context. To be able to cast the groups need to be + // flattened + let new_node_left = if type_left != st { + expr_arena.add(AExpr::Cast { + expr: node_left, + data_type: st.clone(), + strict: false, + }) + } else { + node_left + }; + let new_node_right = if type_right != st { + expr_arena.add(AExpr::Cast { + expr: node_right, + data_type: st, + strict: false, + }) + } else { + node_right + }; + + Ok(Some(AExpr::BinaryExpr { + left: new_node_left, + op, + right: new_node_right, + })) + } +} diff --git a/polars/polars-lazy/polars-plan/src/logical_plan/optimizer/type_coercion.rs b/polars/polars-lazy/polars-plan/src/logical_plan/optimizer/type_coercion/mod.rs similarity index 66% rename from polars/polars-lazy/polars-plan/src/logical_plan/optimizer/type_coercion.rs rename to polars/polars-lazy/polars-plan/src/logical_plan/optimizer/type_coercion/mod.rs index 29b1ffe79228..f48ba049226d 100644 --- a/polars/polars-lazy/polars-plan/src/logical_plan/optimizer/type_coercion.rs +++ b/polars/polars-lazy/polars-plan/src/logical_plan/optimizer/type_coercion/mod.rs @@ -1,3 +1,5 @@ +mod binary; + use std::borrow::Cow; use polars_core::prelude::*; @@ -5,6 +7,7 @@ use polars_core::utils::get_supertype; use super::*; use crate::dsl::function_expr::FunctionExpr; +use crate::logical_plan::optimizer::type_coercion::binary::process_binary; use crate::logical_plan::Context; use crate::utils::is_scan; @@ -201,176 +204,7 @@ impl OptimizationRule for TypeCoercionRule { left: node_left, op, right: node_right, - } => { - let input_schema = get_schema(lp_arena, lp_node); - let (left, type_left) = - unpack!(get_aexpr_and_type(expr_arena, node_left, &input_schema)); - - let (right, type_right) = - unpack!(get_aexpr_and_type(expr_arena, node_right, &input_schema)); - unpack!(early_escape(&type_left, &type_right)); - - // don't coerce string with number comparisons. They must error - match (&type_left, &type_right, op) { - #[cfg(not(feature = "dtype-categorical"))] - (DataType::Utf8, dt, op) | (dt, DataType::Utf8, op) - if op.is_comparison() && dt.is_numeric() => - { - return Ok(None) - } - #[cfg(feature = "dtype-categorical")] - (DataType::Utf8 | DataType::Categorical(_), dt, op) - | (dt, DataType::Utf8 | DataType::Categorical(_), op) - if op.is_comparison() && dt.is_numeric() => - { - return Ok(None) - } - #[cfg(feature = "dtype-date")] - (DataType::Date, DataType::Utf8, op) if op.is_comparison() => { - err_date_str_compare()? - } - #[cfg(feature = "dtype-datetime")] - (DataType::Datetime(_, _), DataType::Utf8, op) if op.is_comparison() => { - err_date_str_compare()? - } - #[cfg(feature = "dtype-time")] - (DataType::Time, DataType::Utf8, op) if op.is_comparison() => { - err_date_str_compare()? - } - // structs can be arbitrarily nested, leave the complexity to the caller for now. - #[cfg(feature = "dtype-struct")] - (DataType::Struct(_), DataType::Struct(_), _op) => return Ok(None), - _ => {} - } - - #[allow(unused_mut, unused_assignments)] - let mut compare_cat_to_string = false; - #[cfg(feature = "dtype-categorical")] - { - compare_cat_to_string = matches!( - op, - Operator::Eq - | Operator::NotEq - | Operator::Gt - | Operator::Lt - | Operator::GtEq - | Operator::LtEq - ) && (matches!(type_left, DataType::Categorical(_)) - && type_right == DataType::Utf8) - || (type_left == DataType::Utf8 - && matches!(type_right, DataType::Categorical(_))); - } - - let datetime_arithmetic = matches!(op, Operator::Minus | Operator::Plus) - && matches!( - (&type_left, &type_right), - (DataType::Datetime(_, _), DataType::Duration(_)) - | (DataType::Duration(_), DataType::Datetime(_, _)) - | (DataType::Date, DataType::Duration(_)) - | (DataType::Duration(_), DataType::Date) - ); - - let list_arithmetic = op.is_arithmetic() - && matches!( - (&type_left, &type_right), - (DataType::List(_), _) | (_, DataType::List(_)) - ); - - // Special path for list arithmetic - if list_arithmetic { - match (&type_left, &type_right) { - (DataType::List(inner), _) => { - return if type_right != **inner { - let new_node_right = expr_arena.add(AExpr::Cast { - expr: node_right, - data_type: *inner.clone(), - strict: false, - }); - - Ok(Some(AExpr::BinaryExpr { - left: node_left, - op, - right: new_node_right, - })) - } else { - Ok(None) - }; - } - (_, DataType::List(inner)) => { - return if type_left != **inner { - let new_node_left = expr_arena.add(AExpr::Cast { - expr: node_left, - data_type: *inner.clone(), - strict: false, - }); - - Ok(Some(AExpr::BinaryExpr { - left: new_node_left, - op, - right: node_right, - })) - } else { - Ok(None) - }; - } - _ => unreachable!(), - } - } - - if compare_cat_to_string - || datetime_arithmetic - || early_escape(&type_left, &type_right).is_none() - { - None - } else { - let st = unpack!(get_supertype(&type_left, &type_right)); - let mut st = modify_supertype(st, left, right, &type_left, &type_right); - - #[allow(unused_mut, unused_assignments)] - let mut cat_str_arithmetic = false; - - #[cfg(feature = "dtype-categorical")] - { - cat_str_arithmetic = (matches!(type_left, DataType::Categorical(_)) - && type_right == DataType::Utf8) - || (type_left == DataType::Utf8 - && matches!(type_right, DataType::Categorical(_))); - } - - if cat_str_arithmetic { - st = DataType::Utf8 - } - - // only cast if the type is not already the super type. - // this can prevent an expensive flattening and subsequent aggregation - // in a groupby context. To be able to cast the groups need to be - // flattened - let new_node_left = if type_left != st { - expr_arena.add(AExpr::Cast { - expr: node_left, - data_type: st.clone(), - strict: false, - }) - } else { - node_left - }; - let new_node_right = if type_right != st { - expr_arena.add(AExpr::Cast { - expr: node_right, - data_type: st, - strict: false, - }) - } else { - node_right - }; - - Some(AExpr::BinaryExpr { - left: new_node_left, - op, - right: new_node_right, - }) - } - } + } => return process_binary(expr_arena, lp_arena, lp_node, node_left, op, node_right), #[cfg(feature = "is_in")] AExpr::Function { function: FunctionExpr::IsIn, diff --git a/polars/polars-utils/src/lib.rs b/polars/polars-utils/src/lib.rs index a530b019af14..137e246eaa1b 100644 --- a/polars/polars-utils/src/lib.rs +++ b/polars/polars-utils/src/lib.rs @@ -24,5 +24,6 @@ pub type IdxSize = u32; #[cfg(feature = "bigidx")] pub type IdxSize = u64; +pub mod macros; #[cfg(target_family = "wasm")] pub mod wasm; diff --git a/polars/polars-utils/src/macros.rs b/polars/polars-utils/src/macros.rs new file mode 100644 index 000000000000..6c0d77917fa7 --- /dev/null +++ b/polars/polars-utils/src/macros.rs @@ -0,0 +1,7 @@ +#[macro_export] +macro_rules! matches_any_order { + ($expression1:expr, $expression2:expr, $( $pattern1:pat_param )|+, $( $pattern2:pat_param )|+) => { + (matches!($expression1, $( $pattern1 )|+) && matches!($expression2, $( $pattern2)|+)) || + matches!($expression2, $( $pattern1 ) |+) && matches!($expression1, $( $pattern2)|+) + } +} diff --git a/py-polars/tests/unit/test_errors.py b/py-polars/tests/unit/test_errors.py index a49fb5b08700..53248a14ecfe 100644 --- a/py-polars/tests/unit/test_errors.py +++ b/py-polars/tests/unit/test_errors.py @@ -438,3 +438,11 @@ def test_take_negative_index_is_oob() -> None: df = pl.DataFrame({"value": [1, 2, 3]}) with pytest.raises(pl.ComputeError, match=r"Out of bounds"): df["value"].take(-1) + + +def test_string_numeric_arithmetic_err() -> None: + df = pl.DataFrame({"s": ["x"]}) + with pytest.raises( + pl.ComputeError, match=r"Arithmetic on string and numeric not allowed" + ): + df.select(pl.col("s") + 1)