Skip to content

Commit

Permalink
feat(rust,python,cli): support negative indexing and expressions for …
Browse files Browse the repository at this point in the history
…`LEFT`, `RIGHT` and `SUBSTR` SQL string funcs (pola-rs#13888)
  • Loading branch information
alexander-beedie authored and r-brink committed Jan 24, 2024
1 parent af2aff3 commit b600da4
Show file tree
Hide file tree
Showing 4 changed files with 183 additions and 76 deletions.
110 changes: 65 additions & 45 deletions crates/polars-sql/src/functions.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use polars_core::prelude::{polars_bail, polars_err, PolarsResult};
use polars_core::prelude::{polars_bail, polars_err, DataType, PolarsResult};
use polars_lazy::dsl::Expr;
#[cfg(feature = "list_eval")]
use polars_lazy::dsl::ListNameSpaceExtension;
Expand Down Expand Up @@ -768,15 +768,15 @@ impl SQLFunctionVisitor<'_> {
1 => self.visit_unary(|e| e.round(0)),
2 => self.try_visit_binary(|e, decimals| {
Ok(e.round(match decimals {
Expr::Literal(LiteralValue::Int64(n)) => n as u32,
_ => {
polars_bail!(InvalidOperation: "Invalid 'decimals' for Round: {}", function.args[1]);
}
Expr::Literal(LiteralValue::Int64(n)) => {
if n >= 0 { n as u32 } else {
polars_bail!(InvalidOperation: "Round does not (yet) support negative 'decimals': {}", function.args[1])
}
},
_ => polars_bail!(InvalidOperation: "Invalid 'decimals' for Round: {}", function.args[1]),
}))
}),
_ => {
polars_bail!(InvalidOperation:"Invalid number of arguments for Round: {}", function.args.len());
},
_ => polars_bail!(InvalidOperation:"Invalid number of arguments for Round: {}", function.args.len()),
},
Sign => self.visit_unary(Expr::sign),
Sqrt => self.visit_unary(Expr::sqrt),
Expand Down Expand Up @@ -860,13 +860,21 @@ impl SQLFunctionVisitor<'_> {
#[cfg(feature = "nightly")]
InitCap => self.visit_unary(|e| e.str().to_titlecase()),
Left => self.try_visit_binary(|e, length| {
Ok(e.str().slice(lit(0), match length {
Expr::Literal(LiteralValue::Int64(n)) => lit(n as u64),
Ok(match length {
Expr::Literal(Null) => lit(Null),
Expr::Literal(LiteralValue::Int64(0)) => lit(""),
Expr::Literal(LiteralValue::Int64(n)) => {
let len = if n > 0 { lit(n) } else { (e.clone().str().len_chars() + lit(n)).clip_min(lit(0)) };
e.str().slice(lit(0), len)
},
Expr::Literal(_) => polars_bail!(InvalidOperation: "Invalid 'n_chars' for Left: {}", function.args[1]),
_ => {
polars_bail!(InvalidOperation: "Invalid 'length' for Left: {}", function.args[1]);
when(length.clone().gt_eq(lit(0)))
.then(e.clone().str().slice(lit(0), length.clone().abs()))
.otherwise(e.clone().str().slice(lit(0), (e.clone().str().len_chars() + length.clone()).clip_min(lit(0))))
}
}))
}),
}
)}),
Length => self.visit_unary(|e| e.str().len_chars()),
Lower => self.visit_unary(|e| e.str().to_lowercase()),
LTrim => match function.args.len() {
Expand Down Expand Up @@ -902,51 +910,63 @@ impl SQLFunctionVisitor<'_> {
3 => self.try_visit_ternary(|e, old, new| {
Ok(e.str().replace_all(old, new, true))
}),
_ => polars_bail!(InvalidOperation:
"Invalid number of arguments for Replace: {}",
function.args.len()
),
_ => polars_bail!(InvalidOperation: "Invalid number of arguments for Replace: {}", function.args.len()),
},
Reverse => self.visit_unary(|e| e.str().reverse()),
Right => self.try_visit_binary(|e, length| {
Ok(e.str().slice( match length {
Expr::Literal(LiteralValue::Int64(n)) => lit(-n),
Ok(match length {
Expr::Literal(Null) => lit(Null),
Expr::Literal(LiteralValue::Int64(0)) => lit(""),
Expr::Literal(LiteralValue::Int64(n)) => {
let offset = if n < 0 { lit(n.abs()) } else { e.clone().str().len_chars().cast(DataType::Int32) - lit(n) };
e.str().slice(offset, lit(Null))
},
Expr::Literal(_) => polars_bail!(InvalidOperation: "Invalid 'n_chars' for Right: {}", function.args[1]),
_ => {
polars_bail!(InvalidOperation: "Invalid 'length' for Right: {}", function.args[1]);
when(length.clone().lt(lit(0)))
.then(e.clone().str().slice(length.clone().abs(), lit(Null)))
.otherwise(e.clone().str().slice(e.clone().str().len_chars().cast(DataType::Int32) - length.clone(), lit(Null)))
}
}, lit(Null)))
}),
}
)}),
RTrim => match function.args.len() {
1 => self.visit_unary(|e| e.str().strip_chars_end(lit(Null))),
2 => self.visit_binary(|e, s| e.str().strip_chars_end(s)),
_ => polars_bail!(InvalidOperation:
"Invalid number of arguments for RTrim: {}",
function.args.len()
),
_ => polars_bail!(InvalidOperation: "Invalid number of arguments for RTrim: {}", function.args.len()),
},
StartsWith => self.visit_binary(|e, s| e.str().starts_with(s)),
Substring => match function.args.len() {
// note that SQL is 1-indexed, not 0-indexed
// note that SQL is 1-indexed, not 0-indexed, hence the need for adjustments
2 => self.try_visit_binary(|e, start| {
Ok(e.str().slice(
match start {
Expr::Literal(LiteralValue::Int64(n)) => lit(n - 1) ,
_ => polars_bail!(InvalidOperation: "Invalid 'start' for Substring: {}", function.args[1]),
}, lit(Null)))
Ok(match start {
    // a NULL start yields NULL
    Expr::Literal(Null) => lit(Null),
    // SQL is 1-indexed: a literal start at or before position 1 selects the whole string
    Expr::Literal(LiteralValue::Int64(n)) if n <= 0 => e,
    Expr::Literal(LiteralValue::Int64(n)) => e.str().slice(lit(n - 1), lit(Null)),
    Expr::Literal(_) => polars_bail!(InvalidOperation: "Invalid 'start' for Substring: {}", function.args[1]),
    // non-literal start (column or expression): slice the string itself, applying the
    // same 1-indexed adjustment per row, instead of returning the offset expression
    _ => {
        let adjusted_start = start.clone() - lit(1);
        when(adjusted_start.clone().lt(lit(0)))
            .then(e.clone())
            .otherwise(e.str().slice(adjusted_start, lit(Null)))
    },
})
}),
3 => self.try_visit_ternary(|e, start, length| {
Ok(e.str().slice(
match start {
Expr::Literal(LiteralValue::Int64(n)) => lit(n - 1),
_ => {
polars_bail!(InvalidOperation: "Invalid 'start' for Substring: {}", function.args[1]);
}
}, match length {
Expr::Literal(LiteralValue::Int64(n)) => lit(n as u64),
_ => {
polars_bail!(InvalidOperation: "Invalid 'length' for Substring: {}", function.args[2]);
}
}))
3 => self.try_visit_ternary(|e: Expr, start: Expr, length: Expr| {
    Ok(match (start.clone(), length.clone()) {
        // NULL in either position yields NULL
        (Expr::Literal(Null), _) | (_, Expr::Literal(Null)) => lit(Null),
        // validate literal lengths before looking at `start`, so that they are
        // rejected regardless of whether `start` is a literal or an expression
        (_, Expr::Literal(LiteralValue::Int64(n))) if n < 0 => {
            polars_bail!(InvalidOperation: "Substring does not support negative length: {}", function.args[2])
        },
        (_, Expr::Literal(LiteralValue::Float64(_))) => {
            // 'length' is the third argument, so report args[2] (not the start at args[1])
            polars_bail!(InvalidOperation: "Invalid 'length' for Substring: {}", function.args[2])
        },
        // SQL is 1-indexed: shift a positive literal start down by one
        (Expr::Literal(LiteralValue::Int64(n)), _) if n > 0 => e.str().slice(lit(n - 1), length.clone()),
        // start at or before position 0: the characters "before" the string consume
        // part of the requested length, which is clipped so it never goes negative
        (Expr::Literal(LiteralValue::Int64(n)), _) => {
            e.str().slice(lit(0), (length.clone() + lit(n - 1)).clip_min(lit(0)))
        },
        (Expr::Literal(_), _) => polars_bail!(InvalidOperation: "Invalid 'start' for Substring: {}", function.args[1]),
        // non-literal start: apply the same two cases per row
        _ => {
            let adjusted_start = start.clone() - lit(1);
            when(adjusted_start.clone().lt(lit(0)))
                .then(e.clone().str().slice(lit(0), (length.clone() + adjusted_start.clone()).clip_min(lit(0))))
                .otherwise(e.clone().str().slice(adjusted_start.clone(), length.clone()))
        }
    })
}),
_ => polars_bail!(InvalidOperation: "Invalid number of arguments for Substring: {}", function.args.len()),
}
Expand Down
21 changes: 10 additions & 11 deletions crates/polars-sql/src/sql_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -378,10 +378,16 @@ impl SQLExprVisitor<'_> {
/// e.g. +column or -column
fn visit_unary_op(&mut self, op: &UnaryOperator, expr: &SQLExpr) -> PolarsResult<Expr> {
let expr = self.visit_expr(expr)?;
Ok(match op {
UnaryOperator::Plus => lit(0) + expr,
UnaryOperator::Minus => lit(0) - expr,
UnaryOperator::Not => expr.not(),
Ok(match (op, expr.clone()) {
    // simplify the parse tree by special-casing common unary +/- ops
    (UnaryOperator::Plus, Expr::Literal(LiteralValue::Int64(n))) => lit(n),
    (UnaryOperator::Plus, Expr::Literal(LiteralValue::Float64(n))) => lit(n),
    (UnaryOperator::Minus, Expr::Literal(LiteralValue::Int64(n))) => lit(-n),
    (UnaryOperator::Minus, Expr::Literal(LiteralValue::Float64(n))) => lit(-n),
    // general case
    (UnaryOperator::Plus, _) => lit(0) + expr,
    (UnaryOperator::Minus, _) => lit(0) - expr,
    (UnaryOperator::Not, _) => expr.not(),
    // bind only the operator: binding the whole scrutinee would format the
    // (operator, operand) tuple into the error message
    (other, _) => polars_bail!(InvalidOperation: "Unary operator {:?} is not supported", other),
})
}
Expand Down Expand Up @@ -609,27 +615,20 @@ impl SQLExprVisitor<'_> {
/// Visit a SQL `ARRAY_AGG` expression.
fn visit_arr_agg(&mut self, expr: &ArrayAgg) -> PolarsResult<Expr> {
let mut base = self.visit_expr(&expr.expr)?;

if let Some(order_by) = expr.order_by.as_ref() {
let (order_by, descending) = self.visit_order_by(order_by)?;
base = base.sort_by(order_by, descending);
}

if let Some(limit) = &expr.limit {
    let limit = match self.visit_expr(limit)? {
        Expr::Literal(LiteralValue::UInt32(n)) => n as usize,
        Expr::Literal(LiteralValue::UInt64(n)) => n as usize,
        // signed literals are only accepted when non-negative: an `as usize` cast of a
        // negative value wraps to a huge limit instead of reporting the invalid input
        Expr::Literal(LiteralValue::Int32(n)) if n >= 0 => n as usize,
        Expr::Literal(LiteralValue::Int64(n)) if n >= 0 => n as usize,
        _ => polars_bail!(ComputeError: "limit in ARRAY_AGG must be a positive integer"),
    };
    base = base.head(Some(limit));
}

if expr.distinct {
base = base.unique_stable();
}

polars_ensure!(
!expr.within_group,
ComputeError: "ARRAY_AGG WITHIN GROUP is not yet supported"
Expand Down
13 changes: 9 additions & 4 deletions py-polars/tests/unit/sql/test_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,15 @@ def test_round_ndigits(decimals: int, expected: list[float]) -> None:

def test_round_ndigits_errors() -> None:
df = pl.DataFrame({"n": [99.999]})
with pl.SQLContext(df=df, eager_execution=True) as ctx, pytest.raises(
InvalidOperationError, match="Invalid 'decimals' for Round: -1"
):
ctx.execute("SELECT ROUND(n,-1) AS n FROM df")
with pl.SQLContext(df=df, eager_execution=True) as ctx:
# `match` is a regular expression: an unescaped `??` is a lazy quantifier on the
# preceding space and would not check the argument at all, so escape it and
# match the quoted literal exactly as it appears in the SQL
with pytest.raises(
    InvalidOperationError, match=r"Invalid 'decimals' for Round: '\?\?'"
):
    ctx.execute("SELECT ROUND(n,'??') AS n FROM df")
with pytest.raises(
InvalidOperationError, match="Round .* negative 'decimals': -1"
):
ctx.execute("SELECT ROUND(n,-1) AS n FROM df")


def test_stddev_variance() -> None:
Expand Down
115 changes: 99 additions & 16 deletions py-polars/tests/unit/sql/test_strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,82 @@ def test_string_left_right_reverse() -> None:
"r": ["de", "bc", "a", None],
"rev": ["edcba", "cba", "a", None],
}
for func, invalid in (("LEFT", "'xyz'"), ("RIGHT", "-1")):
for func, invalid in (("LEFT", "'xyz'"), ("RIGHT", "6.66")):
with pytest.raises(
InvalidOperationError,
match=f"Invalid 'length' for {func.capitalize()}: {invalid}",
match=f"Invalid 'n_chars' for {func.capitalize()}: {invalid}",
):
ctx.execute(f"""SELECT {func}(txt,{invalid}) FROM df""").collect()


# Exercises SQL LEFT with literal, NULL, expression and column-supplied character counts.
def test_string_left_negative_expr() -> None:
# negative values and expressions
# column "n" supplies a different count per row (-6, then 6); it is used by the l9 case
df = pl.DataFrame({"s": ["alphabet", "alphabet"], "n": [-6, 6]})
with pl.SQLContext(df=df, eager_execution=True) as sql:
res = sql.execute(
"""
SELECT
LEFT("s",-50) AS l0, -- empty string
LEFT("s",-3) AS l1, -- all but last three chars
LEFT("s",SIGN(-1)) AS l2, -- all but last char (expr => -1)
LEFT("s",0) AS l3, -- empty string
LEFT("s",NULL) AS l4, -- null
LEFT("s",1) AS l5, -- first char
LEFT("s",SIGN(1)) AS l6, -- first char (expr => 1)
LEFT("s",3) AS l7, -- first three chars
LEFT("s",50) AS l8, -- entire string
LEFT("s","n") AS l9, -- from other col
FROM df
"""
)
# expected: a negative count keeps everything except that many trailing characters,
# and counts beyond the string length saturate (empty string / entire string)
assert res.to_dict(as_series=False) == {
"l0": ["", ""],
"l1": ["alpha", "alpha"],
"l2": ["alphabe", "alphabe"],
"l3": ["", ""],
"l4": [None, None],
"l5": ["a", "a"],
"l6": ["a", "a"],
"l7": ["alp", "alp"],
"l8": ["alphabet", "alphabet"],
"l9": ["al", "alphab"],
}


# Exercises SQL RIGHT with literal, NULL, expression and column-supplied character counts.
def test_string_right_negative_expr() -> None:
# negative values and expressions
# column "n" supplies a different count per row (-6, then 6); it is used by the l9 case
df = pl.DataFrame({"s": ["alphabet", "alphabet"], "n": [-6, 6]})
with pl.SQLContext(df=df, eager_execution=True) as sql:
res = sql.execute(
"""
SELECT
RIGHT("s",-50) AS l0, -- empty string
RIGHT("s",-3) AS l1, -- all but first three chars
RIGHT("s",SIGN(-1)) AS l2, -- all but first char (expr => -1)
RIGHT("s",0) AS l3, -- empty string
RIGHT("s",NULL) AS l4, -- null
RIGHT("s",1) AS l5, -- last char
RIGHT("s",SIGN(1)) AS l6, -- last char (expr => 1)
RIGHT("s",3) AS l7, -- last three chars
RIGHT("s",50) AS l8, -- entire string
RIGHT("s","n") AS l9, -- from other col
FROM df
"""
)
# expected: a negative count keeps everything except that many leading characters,
# and counts beyond the string length saturate (empty string / entire string)
assert res.to_dict(as_series=False) == {
"l0": ["", ""],
"l1": ["habet", "habet"],
"l2": ["lphabet", "lphabet"],
"l3": ["", ""],
"l4": [None, None],
"l5": ["t", "t"],
"l6": ["t", "t"],
"l7": ["bet", "bet"],
"l8": ["alphabet", "alphabet"],
"l9": ["et", "phabet"],
}


def test_string_lengths() -> None:
df = pl.DataFrame({"words": ["Café", None, "東京", ""]})

Expand Down Expand Up @@ -254,38 +322,53 @@ def test_string_replace() -> None:


def test_string_substr() -> None:
df = pl.DataFrame({"scol": ["abcdefg", "abcde", "abc", None]})
df = pl.DataFrame(
{"scol": ["abcdefg", "abcde", "abc", None], "n": [-2, 3, 2, None]}
)
with pl.SQLContext(df=df) as ctx:
res = ctx.execute(
"""
SELECT
-- note: sql is 1-indexed
SUBSTR(scol,1) AS s1,
SUBSTR(scol,2) AS s2,
SUBSTR(scol,3) AS s3,
SUBSTR(scol,1,5) AS s1_5,
SUBSTR(scol,2,2) AS s2_2,
SUBSTR(scol,3,1) AS s3_1,
SUBSTR(scol,1) AS s1,
SUBSTR(scol,2) AS s2,
SUBSTR(scol,3) AS s3,
SUBSTR(scol,1,5) AS s1_5,
SUBSTR(scol,2,2) AS s2_2,
SUBSTR(scol,3,1) AS s3_1,
SUBSTR(scol,-3) AS "s-3",
SUBSTR(scol,-3,3) AS "s-3_3",
SUBSTR(scol,-3,4) AS "s-3_4",
SUBSTR(scol,-3,5) AS "s-3_5",
SUBSTR(scol,-10,13) AS "s-10_13",
SUBSTR(scol,"n",2) AS "s-n2",
SUBSTR(scol,2,"n"+3) AS "s-2n3"
FROM df
"""
).collect()

with pytest.raises(
InvalidOperationError,
match="Substring does not support negative length: -99",
):
ctx.execute("SELECT SUBSTR(scol,2,-99) FROM df")

assert res.to_dict(as_series=False) == {
"s1": ["abcdefg", "abcde", "abc", None],
"s2": ["bcdefg", "bcde", "bc", None],
"s3": ["cdefg", "cde", "c", None],
"s1_5": ["abcde", "abcde", "abc", None],
"s2_2": ["bc", "bc", "bc", None],
"s3_1": ["c", "c", "c", None],
"s-3": ["abcdefg", "abcde", "abc", None],
"s-3_3": ["", "", "", None],
"s-3_4": ["", "", "", None],
"s-3_5": ["a", "a", "a", None],
"s-10_13": ["ab", "ab", "ab", None],
"s-n2": ["", "cd", "bc", None],
"s-2n3": ["b", "bcde", "bc", None],
}

# negative indexes are expected to be invalid
with pytest.raises(
InvalidOperationError,
match="Invalid 'start' for Substring: -1",
), pl.SQLContext(df=df) as ctx:
ctx.execute("SELECT SUBSTR(scol,-1) FROM df")


def test_string_trim(foods_ipc_path: Path) -> None:
lf = pl.scan_ipc(foods_ipc_path)
Expand Down

0 comments on commit b600da4

Please sign in to comment.