diff --git a/crates/polars-sql/src/functions.rs b/crates/polars-sql/src/functions.rs index 6fc429d6aab7..b14e5add5131 100644 --- a/crates/polars-sql/src/functions.rs +++ b/crates/polars-sql/src/functions.rs @@ -1,4 +1,4 @@ -use polars_core::prelude::{polars_bail, polars_err, PolarsResult}; +use polars_core::prelude::{polars_bail, polars_err, DataType, PolarsResult}; use polars_lazy::dsl::Expr; #[cfg(feature = "list_eval")] use polars_lazy::dsl::ListNameSpaceExtension; @@ -768,15 +768,15 @@ impl SQLFunctionVisitor<'_> { 1 => self.visit_unary(|e| e.round(0)), 2 => self.try_visit_binary(|e, decimals| { Ok(e.round(match decimals { - Expr::Literal(LiteralValue::Int64(n)) => n as u32, - _ => { - polars_bail!(InvalidOperation: "Invalid 'decimals' for Round: {}", function.args[1]); - } + Expr::Literal(LiteralValue::Int64(n)) => { + if n >= 0 { n as u32 } else { + polars_bail!(InvalidOperation: "Round does not (yet) support negative 'decimals': {}", function.args[1]) + } + }, + _ => polars_bail!(InvalidOperation: "Invalid 'decimals' for Round: {}", function.args[1]), })) }), - _ => { - polars_bail!(InvalidOperation:"Invalid number of arguments for Round: {}", function.args.len()); - }, + _ => polars_bail!(InvalidOperation:"Invalid number of arguments for Round: {}", function.args.len()), }, Sign => self.visit_unary(Expr::sign), Sqrt => self.visit_unary(Expr::sqrt), @@ -860,13 +860,21 @@ impl SQLFunctionVisitor<'_> { #[cfg(feature = "nightly")] InitCap => self.visit_unary(|e| e.str().to_titlecase()), Left => self.try_visit_binary(|e, length| { - Ok(e.str().slice(lit(0), match length { - Expr::Literal(LiteralValue::Int64(n)) => lit(n as u64), + Ok(match length { + Expr::Literal(Null) => lit(Null), + Expr::Literal(LiteralValue::Int64(0)) => lit(""), + Expr::Literal(LiteralValue::Int64(n)) => { + let len = if n > 0 { lit(n) } else { (e.clone().str().len_chars() + lit(n)).clip_min(lit(0)) }; + e.str().slice(lit(0), len) + }, + Expr::Literal(_) => polars_bail!(InvalidOperation: "Invalid 'n_chars' for Left: {}", function.args[1]), _ => { - polars_bail!(InvalidOperation: "Invalid 'length' for Left: {}", function.args[1]); + when(length.clone().gt_eq(lit(0))) + .then(e.clone().str().slice(lit(0), length.clone().abs())) + .otherwise(e.clone().str().slice(lit(0), (e.clone().str().len_chars() + length.clone()).clip_min(lit(0)))) } - })) - }), + } + )}), Length => self.visit_unary(|e| e.str().len_chars()), Lower => self.visit_unary(|e| e.str().to_lowercase()), LTrim => match function.args.len() { @@ -902,51 +910,63 @@ impl SQLFunctionVisitor<'_> { 3 => self.try_visit_ternary(|e, old, new| { Ok(e.str().replace_all(old, new, true)) }), - _ => polars_bail!(InvalidOperation: - "Invalid number of arguments for Replace: {}", - function.args.len() - ), + _ => polars_bail!(InvalidOperation: "Invalid number of arguments for Replace: {}", function.args.len()), }, Reverse => self.visit_unary(|e| e.str().reverse()), Right => self.try_visit_binary(|e, length| { - Ok(e.str().slice( match length { - Expr::Literal(LiteralValue::Int64(n)) => lit(-n), + Ok(match length { + Expr::Literal(Null) => lit(Null), + Expr::Literal(LiteralValue::Int64(0)) => lit(""), + Expr::Literal(LiteralValue::Int64(n)) => { + let offset = if n < 0 { lit(n.abs()) } else { e.clone().str().len_chars().cast(DataType::Int32) - lit(n) }; + e.str().slice(offset, lit(Null)) + }, + Expr::Literal(_) => polars_bail!(InvalidOperation: "Invalid 'n_chars' for Right: {}", function.args[1]), _ => { - polars_bail!(InvalidOperation: "Invalid 'length' for Right: {}", function.args[1]); + when(length.clone().lt(lit(0))) + .then(e.clone().str().slice(length.clone().abs(), lit(Null))) + .otherwise(e.clone().str().slice(e.clone().str().len_chars().cast(DataType::Int32) - length.clone(), lit(Null))) } - }, lit(Null))) - }), + } + )}), RTrim => match function.args.len() { 1 => self.visit_unary(|e| e.str().strip_chars_end(lit(Null))), 2 => self.visit_binary(|e, s| e.str().strip_chars_end(s)), - _ => polars_bail!(InvalidOperation: - "Invalid number of arguments for RTrim: {}", - function.args.len() - ), + _ => polars_bail!(InvalidOperation: "Invalid number of arguments for RTrim: {}", function.args.len()), }, StartsWith => self.visit_binary(|e, s| e.str().starts_with(s)), Substring => match function.args.len() { - // note that SQL is 1-indexed, not 0-indexed + // note that SQL is 1-indexed, not 0-indexed, hence the need for adjustments 2 => self.try_visit_binary(|e, start| { - Ok(e.str().slice( - match start { - Expr::Literal(LiteralValue::Int64(n)) => lit(n - 1) , - _ => polars_bail!(InvalidOperation: "Invalid 'start' for Substring: {}", function.args[1]), - }, lit(Null))) + Ok(match start { + Expr::Literal(Null) => lit(Null), + Expr::Literal(LiteralValue::Int64(n)) if n <= 0 => e, + Expr::Literal(LiteralValue::Int64(n)) => e.str().slice(lit(n - 1), lit(Null)), + Expr::Literal(_) => polars_bail!(InvalidOperation: "Invalid 'start' for Substring: {}", function.args[1]), + _ => start.clone() + lit(1), + }) }), - 3 => self.try_visit_ternary(|e, start, length| { - Ok(e.str().slice( - match start { - Expr::Literal(LiteralValue::Int64(n)) => lit(n - 1), - _ => { - polars_bail!(InvalidOperation: "Invalid 'start' for Substring: {}", function.args[1]); - } - }, match length { - Expr::Literal(LiteralValue::Int64(n)) => lit(n as u64), - _ => { - polars_bail!(InvalidOperation: "Invalid 'length' for Substring: {}", function.args[2]); - } - })) + 3 => self.try_visit_ternary(|e: Expr, start: Expr, length: Expr| { + Ok(match (start.clone(), length.clone()) { + (Expr::Literal(Null), _) | (_, Expr::Literal(Null)) => lit(Null), + (_, Expr::Literal(LiteralValue::Int64(n))) if n < 0 => { + polars_bail!(InvalidOperation: "Substring does not support negative length: {}", function.args[2]) + }, + (Expr::Literal(LiteralValue::Int64(n)), _) if n > 0 => e.str().slice(lit(n - 1), length.clone()), + (Expr::Literal(LiteralValue::Int64(n)), _) => { + e.str().slice(lit(0), (length.clone() + lit(n - 1)).clip_min(lit(0))) + }, + (Expr::Literal(_), _) => polars_bail!(InvalidOperation: "Invalid 'start' for Substring: {}", function.args[1]), + (_, Expr::Literal(LiteralValue::Float64(_))) => { + polars_bail!(InvalidOperation: "Invalid 'length' for Substring: {}", function.args[1]) + }, + _ => { + let adjusted_start = start.clone() - lit(1); + when(adjusted_start.clone().lt(lit(0))) + .then(e.clone().str().slice(lit(0), (length.clone() + adjusted_start.clone()).clip_min(lit(0)))) + .otherwise(e.clone().str().slice(adjusted_start.clone(), length.clone())) + } + }) }), _ => polars_bail!(InvalidOperation: "Invalid number of arguments for Substring: {}", function.args.len()), } diff --git a/crates/polars-sql/src/sql_expr.rs b/crates/polars-sql/src/sql_expr.rs index 9cf5c3575961..9a4b8f003d4f 100644 --- a/crates/polars-sql/src/sql_expr.rs +++ b/crates/polars-sql/src/sql_expr.rs @@ -378,10 +378,16 @@ impl SQLExprVisitor<'_> { /// e.g. +column or -column fn visit_unary_op(&mut self, op: &UnaryOperator, expr: &SQLExpr) -> PolarsResult { let expr = self.visit_expr(expr)?; - Ok(match op { - UnaryOperator::Plus => lit(0) + expr, - UnaryOperator::Minus => lit(0) - expr, - UnaryOperator::Not => expr.not(), + Ok(match (op, expr.clone()) { + // simplify the parse tree by special-casing common unary +/- ops + (UnaryOperator::Plus, Expr::Literal(LiteralValue::Int64(n))) => lit(n), + (UnaryOperator::Plus, Expr::Literal(LiteralValue::Float64(n))) => lit(n), + (UnaryOperator::Minus, Expr::Literal(LiteralValue::Int64(n))) => lit(-n), + (UnaryOperator::Minus, Expr::Literal(LiteralValue::Float64(n))) => lit(-n), + // general case + (UnaryOperator::Plus, _) => lit(0) + expr, + (UnaryOperator::Minus, _) => lit(0) - expr, + (UnaryOperator::Not, _) => expr.not(), other => polars_bail!(InvalidOperation: "Unary operator {:?} is not supported", other), }) } @@ -609,27 +615,20 @@ impl SQLExprVisitor<'_> { /// Visit a SQL `ARRAY_AGG` expression. fn visit_arr_agg(&mut self, expr: &ArrayAgg) -> PolarsResult { let mut base = self.visit_expr(&expr.expr)?; - if let Some(order_by) = expr.order_by.as_ref() { let (order_by, descending) = self.visit_order_by(order_by)?; base = base.sort_by(order_by, descending); } - if let Some(limit) = &expr.limit { let limit = match self.visit_expr(limit)? { - Expr::Literal(LiteralValue::UInt32(n)) => n as usize, - Expr::Literal(LiteralValue::UInt64(n)) => n as usize, - Expr::Literal(LiteralValue::Int32(n)) => n as usize, Expr::Literal(LiteralValue::Int64(n)) => n as usize, _ => polars_bail!(ComputeError: "limit in ARRAY_AGG must be a positive integer"), }; base = base.head(Some(limit)); } - if expr.distinct { base = base.unique_stable(); } - polars_ensure!( !expr.within_group, ComputeError: "ARRAY_AGG WITHIN GROUP is not yet supported" diff --git a/py-polars/tests/unit/sql/test_numeric.py b/py-polars/tests/unit/sql/test_numeric.py index 7d7938465ca5..09cbcc237739 100644 --- a/py-polars/tests/unit/sql/test_numeric.py +++ b/py-polars/tests/unit/sql/test_numeric.py @@ -107,10 +107,15 @@ def test_round_ndigits(decimals: int, expected: list[float]) -> None: def test_round_ndigits_errors() -> None: df = pl.DataFrame({"n": [99.999]}) - with pl.SQLContext(df=df, eager_execution=True) as ctx, pytest.raises( - InvalidOperationError, match="Invalid 'decimals' for Round: -1" - ): - ctx.execute("SELECT ROUND(n,-1) AS n FROM df") + with pl.SQLContext(df=df, eager_execution=True) as ctx: + with pytest.raises( + InvalidOperationError, match="Invalid 'decimals' for Round: ??" + ): + ctx.execute("SELECT ROUND(n,'??') AS n FROM df") + with pytest.raises( + InvalidOperationError, match="Round .* negative 'decimals': -1" + ): + ctx.execute("SELECT ROUND(n,-1) AS n FROM df") def test_stddev_variance() -> None: diff --git a/py-polars/tests/unit/sql/test_strings.py b/py-polars/tests/unit/sql/test_strings.py index 72a74742c6c4..e3b60b98cb0f 100644 --- a/py-polars/tests/unit/sql/test_strings.py +++ b/py-polars/tests/unit/sql/test_strings.py @@ -98,14 +98,82 @@ def test_string_left_right_reverse() -> None: "r": ["de", "bc", "a", None], "rev": ["edcba", "cba", "a", None], } - for func, invalid in (("LEFT", "'xyz'"), ("RIGHT", "-1")): + for func, invalid in (("LEFT", "'xyz'"), ("RIGHT", "6.66")): with pytest.raises( InvalidOperationError, - match=f"Invalid 'length' for {func.capitalize()}: {invalid}", + match=f"Invalid 'n_chars' for {func.capitalize()}: {invalid}", ): ctx.execute(f"""SELECT {func}(txt,{invalid}) FROM df""").collect() +def test_string_left_negative_expr() -> None: + # negative values and expressions + df = pl.DataFrame({"s": ["alphabet", "alphabet"], "n": [-6, 6]}) + with pl.SQLContext(df=df, eager_execution=True) as sql: + res = sql.execute( + """ + SELECT + LEFT("s",-50) AS l0, -- empty string + LEFT("s",-3) AS l1, -- all but last three chars + LEFT("s",SIGN(-1)) AS l2, -- all but last char (expr => -1) + LEFT("s",0) AS l3, -- empty string + LEFT("s",NULL) AS l4, -- null + LEFT("s",1) AS l5, -- first char + LEFT("s",SIGN(1)) AS l6, -- first char (expr => 1) + LEFT("s",3) AS l7, -- first three chars + LEFT("s",50) AS l8, -- entire string + LEFT("s","n") AS l9, -- from other col + FROM df + """ + ) + assert res.to_dict(as_series=False) == { + "l0": ["", ""], + "l1": ["alpha", "alpha"], + "l2": ["alphabe", "alphabe"], + "l3": ["", ""], + "l4": [None, None], + "l5": ["a", "a"], + "l6": ["a", "a"], + "l7": ["alp", "alp"], + "l8": ["alphabet", "alphabet"], + "l9": ["al", "alphab"], + } + + +def test_string_right_negative_expr() -> None: + # negative values and expressions + df = pl.DataFrame({"s": ["alphabet", "alphabet"], "n": [-6, 6]}) + with pl.SQLContext(df=df, eager_execution=True) as sql: + res = sql.execute( + """ + SELECT + RIGHT("s",-50) AS l0, -- empty string + RIGHT("s",-3) AS l1, -- all but first three chars + RIGHT("s",SIGN(-1)) AS l2, -- all but first char (expr => -1) + RIGHT("s",0) AS l3, -- empty string + RIGHT("s",NULL) AS l4, -- null + RIGHT("s",1) AS l5, -- last char + RIGHT("s",SIGN(1)) AS l6, -- last char (expr => 1) + RIGHT("s",3) AS l7, -- last three chars + RIGHT("s",50) AS l8, -- entire string + RIGHT("s","n") AS l9, -- from other col + FROM df + """ + ) + assert res.to_dict(as_series=False) == { + "l0": ["", ""], + "l1": ["habet", "habet"], + "l2": ["lphabet", "lphabet"], + "l3": ["", ""], + "l4": [None, None], + "l5": ["t", "t"], + "l6": ["t", "t"], + "l7": ["bet", "bet"], + "l8": ["alphabet", "alphabet"], + "l9": ["et", "phabet"], + } + + def test_string_lengths() -> None: df = pl.DataFrame({"words": ["Café", None, "東京", ""]}) @@ -254,22 +322,37 @@ def test_string_replace() -> None: def test_string_substr() -> None: - df = pl.DataFrame({"scol": ["abcdefg", "abcde", "abc", None]}) + df = pl.DataFrame( + {"scol": ["abcdefg", "abcde", "abc", None], "n": [-2, 3, 2, None]} + ) with pl.SQLContext(df=df) as ctx: res = ctx.execute( """ SELECT -- note: sql is 1-indexed - SUBSTR(scol,1) AS s1, - SUBSTR(scol,2) AS s2, - SUBSTR(scol,3) AS s3, - SUBSTR(scol,1,5) AS s1_5, - SUBSTR(scol,2,2) AS s2_2, - SUBSTR(scol,3,1) AS s3_1, + SUBSTR(scol,1) AS s1, + SUBSTR(scol,2) AS s2, + SUBSTR(scol,3) AS s3, + SUBSTR(scol,1,5) AS s1_5, + SUBSTR(scol,2,2) AS s2_2, + SUBSTR(scol,3,1) AS s3_1, + SUBSTR(scol,-3) AS "s-3", + SUBSTR(scol,-3,3) AS "s-3_3", + SUBSTR(scol,-3,4) AS "s-3_4", + SUBSTR(scol,-3,5) AS "s-3_5", + SUBSTR(scol,-10,13) AS "s-10_13", + SUBSTR(scol,"n",2) AS "s-n2", + SUBSTR(scol,2,"n"+3) AS "s-2n3" FROM df """ ).collect() + with pytest.raises( + InvalidOperationError, + match="Substring does not support negative length: -99", + ): + ctx.execute("SELECT SUBSTR(scol,2,-99) FROM df") + assert res.to_dict(as_series=False) == { "s1": ["abcdefg", "abcde", "abc", None], "s2": ["bcdefg", "bcde", "bc", None], @@ -277,15 +360,15 @@ def test_string_substr() -> None: "s1_5": ["abcde", "abcde", "abc", None], "s2_2": ["bc", "bc", "bc", None], "s3_1": ["c", "c", "c", None], + "s-3": ["abcdefg", "abcde", "abc", None], + "s-3_3": ["", "", "", None], + "s-3_4": ["", "", "", None], + "s-3_5": ["a", "a", "a", None], + "s-10_13": ["ab", "ab", "ab", None], + "s-n2": ["", "cd", "bc", None], + "s-2n3": ["b", "bcde", "bc", None], } - # negative indexes are expected to be invalid - with pytest.raises( - InvalidOperationError, - match="Invalid 'start' for Substring: -1", - ), pl.SQLContext(df=df) as ctx: - ctx.execute("SELECT SUBSTR(scol,-1) FROM df") - def test_string_trim(foods_ipc_path: Path) -> None: lf = pl.scan_ipc(foods_ipc_path)