Skip to content

Commit

Permalink
feat(rust,python,cli): add SQL engine support for REPLACE string fu…
Browse files Browse the repository at this point in the history
…nction (#13431)
  • Loading branch information
alexander-beedie authored Jan 5, 2024
1 parent 9cef297 commit cbfa5cf
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 13 deletions.
29 changes: 23 additions & 6 deletions crates/polars-sql/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,8 @@ pub(crate) enum PolarsSQLFunctions {
/// SELECT LTRIM(column_1) from df;
/// ```
LTrim,
/// SQL 'octet_length' function (bytes)
/// SQL 'octet_length' function
/// Returns the length of a given string in bytes.
/// ```sql
/// SELECT OCTET_LENGTH(column_1) from df;
/// ```
Expand All @@ -296,6 +297,12 @@ pub(crate) enum PolarsSQLFunctions {
/// SELECT REGEXP_LIKE(column_1, 'xyz', 'i') from df;
/// ```
RegexpLike,
/// SQL 'replace' function
/// Replace a given substring with another string.
/// ```sql
/// SELECT REPLACE(column_1,'old','new') from df;
/// ```
Replace,
/// SQL 'rtrim' function
/// Strip whitespaces from the right
/// ```sql
Expand Down Expand Up @@ -622,6 +629,7 @@ impl PolarsSQLFunctions {
"ltrim" => Self::LTrim,
"octet_length" => Self::OctetLength,
"regexp_like" => Self::RegexpLike,
"replace" => Self::Replace,
"rtrim" => Self::RTrim,
"starts_with" => Self::StartsWith,
"substr" => Self::Substring,
Expand Down Expand Up @@ -738,6 +746,14 @@ impl SQLFunctionVisitor<'_> {
// String functions
// ----
BitLength => self.visit_unary(|e| e.str().len_bytes() * lit(8)),
Date => match function.args.len() {
1 => self.visit_unary(|e| e.str().to_date(StrptimeOptions::default())),
2 => self.visit_binary(|e, fmt| e.str().to_date(fmt)),
_ => polars_bail!(InvalidOperation:
"Invalid number of arguments for Date: {}",
function.args.len()
),
},
EndsWith => self.visit_binary(|e, s| e.str().ends_with(s)),
#[cfg(feature = "nightly")]
InitCap => self.visit_unary(|e| e.str().to_titlecase()),
Expand Down Expand Up @@ -777,14 +793,15 @@ impl SQLFunctionVisitor<'_> {
}),
_ => polars_bail!(InvalidOperation:"Invalid number of arguments for RegexpLike: {}",function.args.len()),
},
Date => match function.args.len() {
1 => self.visit_unary(|e| e.str().to_date(StrptimeOptions::default())),
2 => self.visit_binary(|e, fmt| e.str().to_date(fmt)),
Replace => match function.args.len() {
3 => self.try_visit_ternary(|e, old, new| {
Ok(e.str().replace_all(old, new, true))
}),
_ => polars_bail!(InvalidOperation:
"Invalid number of arguments for Date: {}",
"Invalid number of arguments for Replace: {}",
function.args.len()
),
},
}
RTrim => match function.args.len() {
1 => self.visit_unary(|e| e.str().strip_chars_end(lit(Null))),
2 => self.visit_binary(|e, s| e.str().strip_chars_end(s)),
Expand Down
36 changes: 29 additions & 7 deletions py-polars/tests/unit/sql/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -1013,7 +1013,7 @@ def test_sql_string_case() -> None:


def test_sql_string_lengths() -> None:
df = pl.DataFrame({"words": ["Café", None, "東京"]})
df = pl.DataFrame({"words": ["Café", None, "東京", ""]})

with pl.SQLContext(frame=df) as ctx:
res = ctx.execute(
Expand All @@ -1030,15 +1030,37 @@ def test_sql_string_lengths() -> None:
).collect()

assert res.to_dict(as_series=False) == {
"words": ["Café", None, "東京"],
"n_chrs1": [4, None, 2],
"n_chrs2": [4, None, 2],
"n_chrs3": [4, None, 2],
"n_bytes": [5, None, 6],
"n_bits": [40, None, 48],
"words": ["Café", None, "東京", ""],
"n_chrs1": [4, None, 2, 0],
"n_chrs2": [4, None, 2, 0],
"n_chrs3": [4, None, 2, 0],
"n_bytes": [5, None, 6, 0],
"n_bits": [40, None, 48, 0],
}


def test_sql_string_replace() -> None:
df = pl.DataFrame({"words": ["Yemeni coffee is the best coffee", "", None]})
with pl.SQLContext(df=df) as ctx:
out = ctx.execute(
"""
SELECT
REPLACE(
REPLACE(words, 'coffee', 'tea'),
'Yemeni',
'English breakfast'
)
FROM df
"""
).collect()

res = out["words"].to_list()
assert res == ["English breakfast tea is the best tea", "", None]

with pytest.raises(InvalidOperationError, match="Invalid number of arguments"):
ctx.execute("SELECT REPLACE(words,'coffee') FROM df")


def test_sql_substr() -> None:
df = pl.DataFrame({"scol": ["abcdefg", "abcde", "abc", None]})
with pl.SQLContext(df=df) as ctx:
Expand Down

0 comments on commit cbfa5cf

Please sign in to comment.