From aa1950c226f00a90092aea23cdbee247240775b9 Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Tue, 30 Jul 2024 13:35:27 +0400 Subject: [PATCH] feat: Add `SQL` interface support for PostgreSQL dollar-quoted string literals (#17940) --- crates/polars-sql/src/sql_expr.rs | 28 ++++++++++++----------- py-polars/tests/unit/sql/test_literals.py | 18 ++++++++++++++- 2 files changed, 32 insertions(+), 14 deletions(-) diff --git a/crates/polars-sql/src/sql_expr.rs b/crates/polars-sql/src/sql_expr.rs index ffcb3dbc4998..9374a5fd3229 100644 --- a/crates/polars-sql/src/sql_expr.rs +++ b/crates/polars-sql/src/sql_expr.rs @@ -907,9 +907,10 @@ impl SQLExprVisitor<'_> { /// /// See [SQLValue] and [LiteralValue] for more details fn visit_literal(&self, value: &SQLValue) -> PolarsResult { + // note: double-quoted strings will be parsed as identifiers, not literals Ok(match value { SQLValue::Boolean(b) => lit(*b), - SQLValue::DoubleQuotedString(s) => lit(s.clone()), + SQLValue::DollarQuotedString(s) => lit(s.value.clone()), #[cfg(feature = "binary_encoding")] SQLValue::HexStringLiteral(x) => { if x.len() % 2 != 0 { @@ -929,12 +930,14 @@ impl SQLExprVisitor<'_> { }, SQLValue::SingleQuotedByteStringLiteral(b) => { // note: for PostgreSQL this represents a BIT string literal (eg: b'10101') not a BYTE string - // literal (see https://www.postgresql.org/docs/current/datatype-bit.html), but sqlparser + // literal (see https://www.postgresql.org/docs/current/datatype-bit.html), but sqlparser-rs // patterned the token name after BigQuery (where b'str' really IS a byte string) bitstring_to_bytes_literal(b)? }, SQLValue::SingleQuotedString(s) => lit(s.clone()), - other => polars_bail!(SQLInterface: "value {:?} is not supported", other), + other => { + polars_bail!(SQLInterface: "value {:?} is not a supported literal type", other) + }, }) } @@ -946,6 +949,14 @@ impl SQLExprVisitor<'_> { ) -> PolarsResult { Ok(match value { SQLValue::Boolean(b) => AnyValue::Boolean(*b), + SQLValue::DollarQuotedString(s) => AnyValue::StringOwned(s.clone().value.into()), + #[cfg(feature = "binary_encoding")] + SQLValue::HexStringLiteral(x) => { + if x.len() % 2 != 0 { + polars_bail!(SQLSyntax: "hex string literal must have an even number of digits; found '{}'", x) + }; + AnyValue::BinaryOwned(hex::decode(x.clone()).unwrap()) + }, SQLValue::Null => AnyValue::Null, SQLValue::Number(s, _) => { let negate = match op { @@ -968,13 +979,6 @@ impl SQLExprVisitor<'_> { } .map_err(|_| polars_err!(SQLInterface: "cannot parse literal: {:?}", s))? }, - #[cfg(feature = "binary_encoding")] - SQLValue::HexStringLiteral(x) => { - if x.len() % 2 != 0 { - polars_bail!(SQLSyntax: "hex string literal must have an even number of digits; found '{}'", x) - }; - AnyValue::BinaryOwned(hex::decode(x.clone()).unwrap()) - }, SQLValue::SingleQuotedByteStringLiteral(b) => { // note: for PostgreSQL this represents a BIT literal (eg: b'10101') not BYTE let bytes_literal = bitstring_to_bytes_literal(b)?; @@ -985,9 +989,7 @@ impl SQLExprVisitor<'_> { }, } }, - SQLValue::SingleQuotedString(s) | SQLValue::DoubleQuotedString(s) => { - AnyValue::StringOwned(s.into()) - }, + SQLValue::SingleQuotedString(s) => AnyValue::StringOwned(s.into()), other => polars_bail!(SQLInterface: "value {:?} is not currently supported", other), }) } diff --git a/py-polars/tests/unit/sql/test_literals.py b/py-polars/tests/unit/sql/test_literals.py index 4bed3f432565..805169039476 100644 --- a/py-polars/tests/unit/sql/test_literals.py +++ b/py-polars/tests/unit/sql/test_literals.py @@ -72,7 +72,7 @@ def test_bit_hex_errors() -> None: with pytest.raises( SQLInterfaceError, - match=r'NationalStringLiteral\("hmmm"\) is not supported', + match=r'NationalStringLiteral\("hmmm"\) is not a supported literal', ): pl.sql_expr("N'hmmm'") @@ -93,6 +93,22 @@ def test_bit_hex_membership() -> None: assert dff["y"].to_list() == [1, 4] +def test_dollar_quoted_literals() -> None: + df = pl.sql( + """ + SELECT + $$xyz$$ AS dq1, + $q$xyz$q$ AS dq2, + $tag$xyz$tag$ AS dq3, + $QUOTE$xyz$QUOTE$ AS dq4, + """ + ).collect() + assert df.to_dict(as_series=False) == {f"dq{n}": ["xyz"] for n in range(1, 5)} + + df = pl.sql("SELECT $$x$z$$ AS dq").collect() + assert df.item() == "x$z" + + def test_intervals() -> None: with pl.SQLContext(df=None, eager=True) as ctx: out = ctx.execute(