Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(rust,python,cli): support negative indexing and expressions for LEFT, RIGHT and SUBSTR SQL string funcs #13888

Merged
merged 1 commit into from
Jan 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 65 additions & 45 deletions crates/polars-sql/src/functions.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use polars_core::prelude::{polars_bail, polars_err, PolarsResult};
use polars_core::prelude::{polars_bail, polars_err, DataType, PolarsResult};
use polars_lazy::dsl::Expr;
#[cfg(feature = "list_eval")]
use polars_lazy::dsl::ListNameSpaceExtension;
Expand Down Expand Up @@ -768,15 +768,15 @@ impl SQLFunctionVisitor<'_> {
1 => self.visit_unary(|e| e.round(0)),
2 => self.try_visit_binary(|e, decimals| {
Ok(e.round(match decimals {
Expr::Literal(LiteralValue::Int64(n)) => n as u32,
_ => {
polars_bail!(InvalidOperation: "Invalid 'decimals' for Round: {}", function.args[1]);
}
Expr::Literal(LiteralValue::Int64(n)) => {
if n >= 0 { n as u32 } else {
polars_bail!(InvalidOperation: "Round does not (yet) support negative 'decimals': {}", function.args[1])
}
},
_ => polars_bail!(InvalidOperation: "Invalid 'decimals' for Round: {}", function.args[1]),
}))
}),
_ => {
polars_bail!(InvalidOperation:"Invalid number of arguments for Round: {}", function.args.len());
},
_ => polars_bail!(InvalidOperation:"Invalid number of arguments for Round: {}", function.args.len()),
},
Sign => self.visit_unary(Expr::sign),
Sqrt => self.visit_unary(Expr::sqrt),
Expand Down Expand Up @@ -860,13 +860,21 @@ impl SQLFunctionVisitor<'_> {
#[cfg(feature = "nightly")]
InitCap => self.visit_unary(|e| e.str().to_titlecase()),
Left => self.try_visit_binary(|e, length| {
Ok(e.str().slice(lit(0), match length {
Expr::Literal(LiteralValue::Int64(n)) => lit(n as u64),
Ok(match length {
Expr::Literal(Null) => lit(Null),
Expr::Literal(LiteralValue::Int64(0)) => lit(""),
Expr::Literal(LiteralValue::Int64(n)) => {
let len = if n > 0 { lit(n) } else { (e.clone().str().len_chars() + lit(n)).clip_min(lit(0)) };
e.str().slice(lit(0), len)
},
Expr::Literal(_) => polars_bail!(InvalidOperation: "Invalid 'n_chars' for Left: {}", function.args[1]),
_ => {
polars_bail!(InvalidOperation: "Invalid 'length' for Left: {}", function.args[1]);
when(length.clone().gt_eq(lit(0)))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good thing we have CSE (common subexpression elimination)!

Copy link
Collaborator Author

@alexander-beedie alexander-beedie Jan 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Definitely! Mapping back to SQL behaviour was awkward enough without having to worry about eating a performance hit too 😅

.then(e.clone().str().slice(lit(0), length.clone().abs()))
.otherwise(e.clone().str().slice(lit(0), (e.clone().str().len_chars() + length.clone()).clip_min(lit(0))))
}
}))
}),
}
)}),
Length => self.visit_unary(|e| e.str().len_chars()),
Lower => self.visit_unary(|e| e.str().to_lowercase()),
LTrim => match function.args.len() {
Expand Down Expand Up @@ -902,51 +910,63 @@ impl SQLFunctionVisitor<'_> {
3 => self.try_visit_ternary(|e, old, new| {
Ok(e.str().replace_all(old, new, true))
}),
_ => polars_bail!(InvalidOperation:
"Invalid number of arguments for Replace: {}",
function.args.len()
),
_ => polars_bail!(InvalidOperation: "Invalid number of arguments for Replace: {}", function.args.len()),
},
Reverse => self.visit_unary(|e| e.str().reverse()),
Right => self.try_visit_binary(|e, length| {
Ok(e.str().slice( match length {
Expr::Literal(LiteralValue::Int64(n)) => lit(-n),
Ok(match length {
Expr::Literal(Null) => lit(Null),
Expr::Literal(LiteralValue::Int64(0)) => lit(""),
Expr::Literal(LiteralValue::Int64(n)) => {
let offset = if n < 0 { lit(n.abs()) } else { e.clone().str().len_chars().cast(DataType::Int32) - lit(n) };
e.str().slice(offset, lit(Null))
},
Expr::Literal(_) => polars_bail!(InvalidOperation: "Invalid 'n_chars' for Right: {}", function.args[1]),
_ => {
polars_bail!(InvalidOperation: "Invalid 'length' for Right: {}", function.args[1]);
when(length.clone().lt(lit(0)))
.then(e.clone().str().slice(length.clone().abs(), lit(Null)))
.otherwise(e.clone().str().slice(e.clone().str().len_chars().cast(DataType::Int32) - length.clone(), lit(Null)))
}
}, lit(Null)))
}),
}
)}),
RTrim => match function.args.len() {
    // RTRIM(s): passes a NULL character set to `strip_chars_end`
    1 => self.visit_unary(|expr| expr.str().strip_chars_end(lit(Null))),
    // RTRIM(s, chars): passes the caller-supplied character set
    2 => self.visit_binary(|expr, chars| expr.str().strip_chars_end(chars)),
    n_args => polars_bail!(InvalidOperation: "Invalid number of arguments for RTrim: {}", n_args),
},
StartsWith => self.visit_binary(|e, s| e.str().starts_with(s)),
Substring => match function.args.len() {
// note that SQL is 1-indexed, not 0-indexed
// note that SQL is 1-indexed, not 0-indexed, hence the need for adjustments
2 => self.try_visit_binary(|e, start| {
Ok(e.str().slice(
match start {
Expr::Literal(LiteralValue::Int64(n)) => lit(n - 1) ,
_ => polars_bail!(InvalidOperation: "Invalid 'start' for Substring: {}", function.args[1]),
}, lit(Null)))
Ok(match start {
Expr::Literal(Null) => lit(Null),
Expr::Literal(LiteralValue::Int64(n)) if n <= 0 => e,
Expr::Literal(LiteralValue::Int64(n)) => e.str().slice(lit(n - 1), lit(Null)),
Expr::Literal(_) => polars_bail!(InvalidOperation: "Invalid 'start' for Substring: {}", function.args[1]),
_ => start.clone() + lit(1),
})
}),
3 => self.try_visit_ternary(|e, start, length| {
Ok(e.str().slice(
match start {
Expr::Literal(LiteralValue::Int64(n)) => lit(n - 1),
_ => {
polars_bail!(InvalidOperation: "Invalid 'start' for Substring: {}", function.args[1]);
}
}, match length {
Expr::Literal(LiteralValue::Int64(n)) => lit(n as u64),
_ => {
polars_bail!(InvalidOperation: "Invalid 'length' for Substring: {}", function.args[2]);
}
}))
3 => self.try_visit_ternary(|e: Expr, start: Expr, length: Expr| {
Ok(match (start.clone(), length.clone()) {
(Expr::Literal(Null), _) | (_, Expr::Literal(Null)) => lit(Null),
(_, Expr::Literal(LiteralValue::Int64(n))) if n < 0 => {
polars_bail!(InvalidOperation: "Substring does not support negative length: {}", function.args[2])
},
(Expr::Literal(LiteralValue::Int64(n)), _) if n > 0 => e.str().slice(lit(n - 1), length.clone()),
(Expr::Literal(LiteralValue::Int64(n)), _) => {
e.str().slice(lit(0), (length.clone() + lit(n - 1)).clip_min(lit(0)))
},
(Expr::Literal(_), _) => polars_bail!(InvalidOperation: "Invalid 'start' for Substring: {}", function.args[1]),
(_, Expr::Literal(LiteralValue::Float64(_))) => {
polars_bail!(InvalidOperation: "Invalid 'length' for Substring: {}", function.args[1])
},
_ => {
let adjusted_start = start.clone() - lit(1);
when(adjusted_start.clone().lt(lit(0)))
.then(e.clone().str().slice(lit(0), (length.clone() + adjusted_start.clone()).clip_min(lit(0))))
.otherwise(e.clone().str().slice(adjusted_start.clone(), length.clone()))
}
})
}),
_ => polars_bail!(InvalidOperation: "Invalid number of arguments for Substring: {}", function.args.len()),
}
Expand Down
21 changes: 10 additions & 11 deletions crates/polars-sql/src/sql_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -378,10 +378,16 @@ impl SQLExprVisitor<'_> {
/// e.g. +column or -column
fn visit_unary_op(&mut self, op: &UnaryOperator, expr: &SQLExpr) -> PolarsResult<Expr> {
let expr = self.visit_expr(expr)?;
Ok(match op {
UnaryOperator::Plus => lit(0) + expr,
UnaryOperator::Minus => lit(0) - expr,
UnaryOperator::Not => expr.not(),
Ok(match (op, expr.clone()) {
// simplify the parse tree by special-casing common unary +/- ops
(UnaryOperator::Plus, Expr::Literal(LiteralValue::Int64(n))) => lit(n),
(UnaryOperator::Plus, Expr::Literal(LiteralValue::Float64(n))) => lit(n),
(UnaryOperator::Minus, Expr::Literal(LiteralValue::Int64(n))) => lit(-n),
(UnaryOperator::Minus, Expr::Literal(LiteralValue::Float64(n))) => lit(-n),
// general case
(UnaryOperator::Plus, _) => lit(0) + expr,
(UnaryOperator::Minus, _) => lit(0) - expr,
(UnaryOperator::Not, _) => expr.not(),
other => polars_bail!(InvalidOperation: "Unary operator {:?} is not supported", other),
})
}
Expand Down Expand Up @@ -609,27 +615,20 @@ impl SQLExprVisitor<'_> {
/// Visit a SQL `ARRAY_AGG` expression.
fn visit_arr_agg(&mut self, expr: &ArrayAgg) -> PolarsResult<Expr> {
let mut base = self.visit_expr(&expr.expr)?;

if let Some(order_by) = expr.order_by.as_ref() {
let (order_by, descending) = self.visit_order_by(order_by)?;
base = base.sort_by(order_by, descending);
}

if let Some(limit) = &expr.limit {
let limit = match self.visit_expr(limit)? {
Expr::Literal(LiteralValue::UInt32(n)) => n as usize,
Expr::Literal(LiteralValue::UInt64(n)) => n as usize,
Expr::Literal(LiteralValue::Int32(n)) => n as usize,
Expr::Literal(LiteralValue::Int64(n)) => n as usize,
_ => polars_bail!(ComputeError: "limit in ARRAY_AGG must be a positive integer"),
};
base = base.head(Some(limit));
}

if expr.distinct {
base = base.unique_stable();
}

polars_ensure!(
!expr.within_group,
ComputeError: "ARRAY_AGG WITHIN GROUP is not yet supported"
Expand Down
13 changes: 9 additions & 4 deletions py-polars/tests/unit/sql/test_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,15 @@ def test_round_ndigits(decimals: int, expected: list[float]) -> None:

def test_round_ndigits_errors() -> None:
    """ROUND must reject non-integer and negative 'decimals' arguments."""
    df = pl.DataFrame({"n": [99.999]})
    with pl.SQLContext(df=df, eager_execution=True) as ctx:
        # `match` is a regular expression: an unescaped " ??" is only a lazy
        # optional space, so the pattern would pass on the message prefix alone.
        # Escape the question marks so the offending argument is really checked.
        with pytest.raises(
            InvalidOperationError, match=r"Invalid 'decimals' for Round: .*\?\?"
        ):
            ctx.execute("SELECT ROUND(n,'??') AS n FROM df")
        with pytest.raises(
            InvalidOperationError, match="Round .* negative 'decimals': -1"
        ):
            ctx.execute("SELECT ROUND(n,-1) AS n FROM df")


def test_stddev_variance() -> None:
Expand Down
115 changes: 99 additions & 16 deletions py-polars/tests/unit/sql/test_strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,82 @@ def test_string_left_right_reverse() -> None:
"r": ["de", "bc", "a", None],
"rev": ["edcba", "cba", "a", None],
}
for func, invalid in (("LEFT", "'xyz'"), ("RIGHT", "-1")):
for func, invalid in (("LEFT", "'xyz'"), ("RIGHT", "6.66")):
with pytest.raises(
InvalidOperationError,
match=f"Invalid 'length' for {func.capitalize()}: {invalid}",
match=f"Invalid 'n_chars' for {func.capitalize()}: {invalid}",
):
ctx.execute(f"""SELECT {func}(txt,{invalid}) FROM df""").collect()


def test_string_left_negative_expr() -> None:
    """LEFT with negative, zero, NULL, expression and column-valued lengths."""
    frame = pl.DataFrame({"s": ["alphabet", "alphabet"], "n": [-6, 6]})
    query = """
        SELECT
          LEFT("s",-50) AS l0, -- empty string
          LEFT("s",-3) AS l1, -- all but last three chars
          LEFT("s",SIGN(-1)) AS l2, -- all but last char (expr => -1)
          LEFT("s",0) AS l3, -- empty string
          LEFT("s",NULL) AS l4, -- null
          LEFT("s",1) AS l5, -- first char
          LEFT("s",SIGN(1)) AS l6, -- first char (expr => 1)
          LEFT("s",3) AS l7, -- first three chars
          LEFT("s",50) AS l8, -- entire string
          LEFT("s","n") AS l9, -- from other col
        FROM df
        """
    expected = {
        "l0": ["", ""],
        "l1": ["alpha", "alpha"],
        "l2": ["alphabe", "alphabe"],
        "l3": ["", ""],
        "l4": [None, None],
        "l5": ["a", "a"],
        "l6": ["a", "a"],
        "l7": ["alp", "alp"],
        "l8": ["alphabet", "alphabet"],
        "l9": ["al", "alphab"],
    }
    with pl.SQLContext(df=frame, eager_execution=True) as ctx:
        result = ctx.execute(query)
    assert result.to_dict(as_series=False) == expected


def test_string_right_negative_expr() -> None:
    """RIGHT with negative, zero, NULL, expression and column-valued lengths."""
    frame = pl.DataFrame({"s": ["alphabet", "alphabet"], "n": [-6, 6]})
    query = """
        SELECT
          RIGHT("s",-50) AS l0, -- empty string
          RIGHT("s",-3) AS l1, -- all but first three chars
          RIGHT("s",SIGN(-1)) AS l2, -- all but first char (expr => -1)
          RIGHT("s",0) AS l3, -- empty string
          RIGHT("s",NULL) AS l4, -- null
          RIGHT("s",1) AS l5, -- last char
          RIGHT("s",SIGN(1)) AS l6, -- last char (expr => 1)
          RIGHT("s",3) AS l7, -- last three chars
          RIGHT("s",50) AS l8, -- entire string
          RIGHT("s","n") AS l9, -- from other col
        FROM df
        """
    expected = {
        "l0": ["", ""],
        "l1": ["habet", "habet"],
        "l2": ["lphabet", "lphabet"],
        "l3": ["", ""],
        "l4": [None, None],
        "l5": ["t", "t"],
        "l6": ["t", "t"],
        "l7": ["bet", "bet"],
        "l8": ["alphabet", "alphabet"],
        "l9": ["et", "phabet"],
    }
    with pl.SQLContext(df=frame, eager_execution=True) as ctx:
        result = ctx.execute(query)
    assert result.to_dict(as_series=False) == expected


def test_string_lengths() -> None:
df = pl.DataFrame({"words": ["Café", None, "東京", ""]})

Expand Down Expand Up @@ -254,38 +322,53 @@ def test_string_replace() -> None:


def test_string_substr() -> None:
    """SUBSTR with literal, non-positive and column-valued start/length arguments."""
    frame = pl.DataFrame(
        {"scol": ["abcdefg", "abcde", "abc", None], "n": [-2, 3, 2, None]}
    )
    query = """
        SELECT
          -- note: sql is 1-indexed
          SUBSTR(scol,1) AS s1,
          SUBSTR(scol,2) AS s2,
          SUBSTR(scol,3) AS s3,
          SUBSTR(scol,1,5) AS s1_5,
          SUBSTR(scol,2,2) AS s2_2,
          SUBSTR(scol,3,1) AS s3_1,
          SUBSTR(scol,-3) AS "s-3",
          SUBSTR(scol,-3,3) AS "s-3_3",
          SUBSTR(scol,-3,4) AS "s-3_4",
          SUBSTR(scol,-3,5) AS "s-3_5",
          SUBSTR(scol,-10,13) AS "s-10_13",
          SUBSTR(scol,"n",2) AS "s-n2",
          SUBSTR(scol,2,"n"+3) AS "s-2n3"
        FROM df
        """
    expected = {
        "s1": ["abcdefg", "abcde", "abc", None],
        "s2": ["bcdefg", "bcde", "bc", None],
        "s3": ["cdefg", "cde", "c", None],
        "s1_5": ["abcde", "abcde", "abc", None],
        "s2_2": ["bc", "bc", "bc", None],
        "s3_1": ["c", "c", "c", None],
        "s-3": ["abcdefg", "abcde", "abc", None],
        "s-3_3": ["", "", "", None],
        "s-3_4": ["", "", "", None],
        "s-3_5": ["a", "a", "a", None],
        "s-10_13": ["ab", "ab", "ab", None],
        "s-n2": ["", "cd", "bc", None],
        "s-2n3": ["b", "bcde", "bc", None],
    }
    with pl.SQLContext(df=frame) as ctx:
        result = ctx.execute(query).collect()

        # a negative literal length is rejected when the query is translated
        with pytest.raises(
            InvalidOperationError,
            match="Substring does not support negative length: -99",
        ):
            ctx.execute("SELECT SUBSTR(scol,2,-99) FROM df")

    assert result.to_dict(as_series=False) == expected


def test_string_trim(foods_ipc_path: Path) -> None:
lf = pl.scan_ipc(foods_ipc_path)
Expand Down