Skip to content

Commit

Permalink
feat: Initial support for SQL ARRAY literals and the UNNEST table function (#16330)

Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie authored May 31, 2024
1 parent 03e00d5 commit 1cafcbc
Show file tree
Hide file tree
Showing 4 changed files with 255 additions and 55 deletions.
64 changes: 63 additions & 1 deletion crates/polars-sql/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use sqlparser::dialect::GenericDialect;
use sqlparser::parser::{Parser, ParserOptions};

use crate::function_registry::{DefaultFunctionRegistry, FunctionRegistry};
use crate::sql_expr::{parse_sql_expr, process_join_constraint};
use crate::sql_expr::{parse_sql_array, parse_sql_expr, process_join_constraint};
use crate::table_functions::PolarsTableFunctions;

/// The SQLContext is the main entry point for executing SQL queries.
Expand Down Expand Up @@ -749,6 +749,7 @@ impl SQLContext {
alias,
} => {
polars_ensure!(!(*lateral), ComputeError: "LATERAL not supported");

if let Some(alias) = alias {
let lf = self.execute_query_no_ctes(subquery)?;
self.table_map.insert(alias.name.value.clone(), lf.clone());
Expand All @@ -757,6 +758,67 @@ impl SQLContext {
polars_bail!(ComputeError: "derived tables must have aliases");
}
},
TableFactor::UNNEST {
    alias,
    array_exprs,
    with_offset,
    with_offset_alias: _,
} => {
    if let Some(alias) = alias {
        let table_name = alias.name.value.clone();
        // an empty column ident means "no name declared" for that position
        let column_names: Vec<Option<&str>> = alias
            .columns
            .iter()
            .map(|c| {
                if c.value.is_empty() {
                    None
                } else {
                    Some(c.value.as_str())
                }
            })
            .collect();

        // each UNNEST argument must be a parseable ARRAY literal; the
        // first failure short-circuits via `collect::<Result<_, _>>`
        let column_values: Vec<Series> = array_exprs
            .iter()
            .map(|arr| parse_sql_array(arr, self))
            .collect::<Result<_, _>>()?;

        polars_ensure!(!column_names.is_empty(),
            ComputeError:
            "UNNEST table alias must also declare column names, eg: {} (a,b,c)", alias.name.to_string()
        );
        if column_names.len() != column_values.len() {
            let plural = if column_values.len() > 1 { "s" } else { "" };
            polars_bail!(
                ComputeError:
                "UNNEST table alias requires {} column name{}, found {}", column_values.len(), plural, column_names.len()
            );
        }
        // apply declared names to the parsed series (positions with no
        // declared name keep the series' existing name)
        let column_series: Vec<Series> = column_values
            .iter()
            .zip(column_names.iter())
            .map(|(s, name)| {
                if let Some(name) = name {
                    s.clone().with_name(name)
                } else {
                    s.clone()
                }
            })
            .collect();

        if *with_offset {
            // fail fast, before materialising the frame.
            // TODO: make a PR to `sqlparser-rs` to support 'ORDINALITY'
            // (note that 'OFFSET' is BigQuery-specific syntax, not PostgreSQL)
            polars_bail!(ComputeError: "UNNEST tables do not (yet) support WITH OFFSET/ORDINALITY");
        }
        let lf = DataFrame::new(column_series)?.lazy();
        self.table_map.insert(table_name.clone(), lf.clone());
        // `table_name` was cloned into the map above; move it here
        // instead of cloning a second time
        Ok((table_name, lf))
    } else {
        polars_bail!(ComputeError: "UNNEST table must have an alias");
    }
},

// Support bare table, optional with alias for now
_ => polars_bail!(ComputeError: "not yet implemented: {}", relation),
}
Expand Down
129 changes: 75 additions & 54 deletions crates/polars-sql/src/sql_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use polars_core::export::regex;
use polars_core::prelude::*;
use polars_error::to_compute_err;
use polars_lazy::prelude::*;
use polars_ops::series::SeriesReshape;
use polars_plan::prelude::typed_lit;
use polars_plan::prelude::LiteralValue::Null;
use rand::distributions::Alphanumeric;
Expand Down Expand Up @@ -185,6 +186,28 @@ pub(crate) struct SQLExprVisitor<'a> {
}

impl SQLExprVisitor<'_> {
/// Materialise a slice of SQL (scalar) expressions as a single Series.
///
/// Only plain values, and unary-op (eg: +/-) prefixed values, are accepted
/// as array elements; nested arrays and arbitrary expressions are rejected.
fn array_expr_to_series(&mut self, elements: &[SQLExpr]) -> PolarsResult<Series> {
    let mut values = Vec::with_capacity(elements.len());
    for e in elements {
        let av = match e {
            SQLExpr::Value(v) => self.visit_any_value(v, None)?,
            SQLExpr::UnaryOp { op, expr } => match expr.as_ref() {
                SQLExpr::Value(v) => self.visit_any_value(v, Some(op))?,
                _ => {
                    return Err(polars_err!(ComputeError: "SQL expression {:?} is not yet supported", e))
                },
            },
            SQLExpr::Array(_) => {
                // TODO: nested arrays (handle FnMut issues)
                return Err(polars_err!(ComputeError: "SQL interface does not yet support nested array literals:\n{:?}", e));
            },
            _ => {
                return Err(polars_err!(ComputeError: "SQL expression {:?} is not yet supported", e))
            },
        };
        values.push(av);
    }
    Series::from_any_values("", &values, true)
}

fn visit_expr(&mut self, expr: &SQLExpr) -> PolarsResult<Expr> {
match expr {
SQLExpr::AllOp {
Expand All @@ -197,6 +220,7 @@ impl SQLExprVisitor<'_> {
compare_op,
right,
} => self.visit_any(left, compare_op, right),
SQLExpr::Array(arr) => self.visit_array_expr(&arr.elem, true, None),
SQLExpr::ArrayAgg(expr) => self.visit_arr_agg(expr),
SQLExpr::Between {
expr,
Expand All @@ -220,7 +244,12 @@ impl SQLExprVisitor<'_> {
expr,
list,
negated,
} => self.visit_in_list(expr, list, *negated),
} => {
let expr = self.visit_expr(expr)?;
let elems = self.visit_array_expr(list, false, Some(&expr))?;
let is_in = expr.is_in(elems);
Ok(if *negated { is_in.not() } else { is_in })
},
SQLExpr::InSubquery {
expr,
subquery,
Expand Down Expand Up @@ -615,6 +644,38 @@ impl SQLExprVisitor<'_> {
}
}

/// Visit a SQL `ARRAY` list (including `IN` values).
///
/// * `result_as_element` — when true, the series is imploded so the whole
///   array becomes a single list element (eg: an ARRAY literal in a SELECT);
///   when false the series is returned as-is (eg: the RHS of `IN`).
/// * `dtype_expr_match` — optional column expression; if it names a temporal
///   column in the active schema, string elements are implicitly cast to it.
fn visit_array_expr(
    &mut self,
    elements: &[SQLExpr],
    result_as_element: bool,
    dtype_expr_match: Option<&Expr>,
) -> PolarsResult<Expr> {
    let mut elems = self.array_expr_to_series(elements)?;

    // handle implicit temporal strings, eg: "dt IN ('2024-04-30','2024-05-01')".
    // (not yet as versatile as the temporal string conversions in visit_binary_op)
    if let (Some(Expr::Column(name)), Some(schema)) =
        (dtype_expr_match, self.active_schema.as_ref())
    {
        if elems.dtype() == &DataType::String {
            // single schema lookup: bind the matched dtype directly instead
            // of fetching it a second time with `.unwrap().clone()`
            if let Some(dtype @ (DataType::Date | DataType::Time | DataType::Datetime(_, _))) =
                schema.get(name)
            {
                elems = elems.strict_cast(dtype)?;
            }
        }
    }
    // if we are parsing the list as an element in a series, implode.
    // otherwise, return the series as-is.
    let res = if result_as_element {
        elems.implode()?.into_series()
    } else {
        elems
    };
    Ok(lit(res))
}

/// Visit a SQL `CAST` or `TRY_CAST` expression.
///
/// e.g. `CAST(col AS INT)`, `col::int4`, or `TRY_CAST(col AS VARCHAR)`,
Expand Down Expand Up @@ -810,59 +871,6 @@ impl SQLExprVisitor<'_> {
Ok(base.implode())
}

/// Visit a SQL `IN` expression
///
/// Builds a literal Series from the list values and emits an `is_in`
/// expression against the left-hand expression, negated for `NOT IN`.
/// (NOTE: these are the pre-change lines removed by this commit; the
/// replacement logic lives in `visit_array_expr`.)
fn visit_in_list(
    &mut self,
    expr: &SQLExpr,
    list: &[SQLExpr],
    negated: bool,
) -> PolarsResult<Expr> {
    let expr = self.visit_expr(expr)?;
    // only plain values (and unary-op prefixed values) are accepted as list
    // elements; anything else is rejected as unsupported
    let list = list
        .iter()
        .map(|e| {
            if let SQLExpr::Value(v) = e {
                let av = self.visit_any_value(v, None)?;
                Ok(av)
            } else if let SQLExpr::UnaryOp {op, expr} = e {
                match expr.as_ref() {
                    SQLExpr::Value(v) => {
                        let av = self.visit_any_value(v, Some(op))?;
                        Ok(av)
                    },
                    _ => Err(polars_err!(ComputeError: "SQL expression {:?} is not yet supported", e))
                }
            }else{
                Err(polars_err!(ComputeError: "SQL expression {:?} is not yet supported", e))
            }
        })
        .collect::<PolarsResult<Vec<_>>>()?;

    let mut s = Series::from_any_values("", &list, true)?;

    // handle implicit temporal strings, eg: "dt IN ('2024-04-30','2024-05-01')".
    // (not yet as versatile as the temporal string conversions in visit_binary_op)
    if s.dtype() == &DataType::String {
        // handle implicit temporal string comparisons, eg: "dt >= '2024-04-30'"
        if let Expr::Column(name) = &expr {
            if self.active_schema.is_some() {
                let schema = self.active_schema.as_ref().unwrap();
                let left_dtype = schema.get(name);
                if let Some(DataType::Date | DataType::Time | DataType::Datetime(_, _)) =
                    left_dtype
                {
                    s = s.strict_cast(&left_dtype.unwrap().clone())?;
                }
            }
        }
    }
    if negated {
        Ok(expr.is_in(lit(s)).not())
    } else {
        Ok(expr.is_in(lit(s)))
    }
}

/// Visit a SQL subquery inside and `IN` expression.
fn visit_in_subquery(
&mut self,
Expand Down Expand Up @@ -1115,6 +1123,19 @@ pub(crate) fn parse_sql_expr(
visitor.visit_expr(expr)
}

pub(crate) fn parse_sql_array(expr: &SQLExpr, ctx: &mut SQLContext) -> PolarsResult<Series> {
match expr {
SQLExpr::Array(arr) => {
let mut visitor = SQLExprVisitor {
ctx,
active_schema: None,
};
visitor.array_expr_to_series(arr.elem.as_slice())
},
_ => polars_bail!(ComputeError: "Expected array expression, found {:?}", expr),
}
}

fn parse_extract(expr: Expr, field: &DateTimeField) -> PolarsResult<Expr> {
Ok(match field {
DateTimeField::Millennium => expr.dt().millennium(),
Expand Down
20 changes: 20 additions & 0 deletions crates/polars-sql/tests/functions_string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,23 @@ fn test_array_to_string() {
.unwrap();
assert!(df_sql.equals(&df_expected));
}

#[test]
fn test_array_literal() {
    // an array literal in a SELECT should yield a single-row list column
    let mut context = SQLContext::new();
    context.register("df", DataFrame::empty().lazy());

    let df_expected = df! {
        "arr" => &[100i64, 200, 300],
    }
    .unwrap()
    .lazy()
    .select(&[col("arr").implode()])
    .collect()
    .unwrap();

    let df_sql = context
        .execute("SELECT [100,200,300] AS arr FROM df")
        .unwrap()
        .collect()
        .unwrap();

    assert!(df_sql.equals(&df_expected));
    assert_eq!(df_sql.height(), 1);
}
97 changes: 97 additions & 0 deletions py-polars/tests/unit/sql/test_array.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,39 @@
from __future__ import annotations

import pytest

import polars as pl
from polars.exceptions import ComputeError
from polars.testing import assert_frame_equal


def test_array_literals() -> None:
    # note: ARRAY_AGG wraps the (already-list) columns in one more level of nesting
    expected = pl.DataFrame(
        {
            "a1": [[10, 20, 30]],
            "a2": [["a", "b", "c"]],
            "a3": [[[10, 20, 30]]],
            "a4": [[["a", "b", "c"]]],
        }
    )
    with pl.SQLContext(df=None, eager=True) as ctx:
        res = ctx.execute(
            """
            SELECT
              a1, a2, ARRAY_AGG(a1) AS a3, ARRAY_AGG(a2) AS a4
            FROM (
              SELECT
                [10,20,30] AS a1,
                ['a','b','c'] AS a2,
              FROM df
            ) tbl
            """
        )
    assert_frame_equal(res, expected)


def test_array_to_string() -> None:
data = {"values": [["aa", "bb"], [None, "cc"], ["dd", None]]}
res = pl.DataFrame(data).sql(
Expand All @@ -25,3 +55,70 @@ def test_array_to_string() -> None:
}
),
)


@pytest.mark.parametrize(
    "array_keyword",
    ["ARRAY", ""],
)
def test_unnest_table_function(array_keyword: str) -> None:
    # UNNEST as a table function, with and without the explicit ARRAY keyword
    expected = pl.DataFrame(
        {
            "x": [1, 2, 3, 4],
            "y": ["ww", "xx", "yy", "zz"],
            "z": [23.0, 24.5, 28.0, 27.5],
        }
    )
    with pl.SQLContext(df=None, eager=True) as ctx:
        res = ctx.execute(
            f"""
            SELECT * FROM
            UNNEST(
                {array_keyword}[1, 2, 3, 4],
                {array_keyword}['ww','xx','yy','zz'],
                {array_keyword}[23.0, 24.5, 28.0, 27.5]
            ) AS tbl (x,y,z);
            """
        )
    assert_frame_equal(res, expected)


def test_unnest_table_function_errors() -> None:
    # each (query, error-pattern) pair must raise ComputeError on execution
    failing_queries = [
        (
            'SELECT * FROM UNNEST([1, 2, 3]) AS "frame data"',
            r'UNNEST table alias must also declare column names, eg: "frame data" \(a,b,c\)',
        ),
        (
            "SELECT * FROM UNNEST([1, 2, 3]) AS tbl (a, b)",
            "UNNEST table alias requires 1 column name, found 2",
        ),
        (
            "SELECT * FROM UNNEST([1,2,3], [3,4,5]) AS tbl (a)",
            "UNNEST table alias requires 2 column names, found 1",
        ),
        (
            "SELECT * FROM UNNEST([1, 2, 3])",
            r"UNNEST table must have an alias",
        ),
        (
            "SELECT * FROM UNNEST([1, 2, 3]) tbl (colx) WITH OFFSET",
            r"UNNEST tables do not \(yet\) support WITH OFFSET/ORDINALITY",
        ),
    ]
    with pl.SQLContext(df=None, eager=True) as ctx:
        for query, pattern in failing_queries:
            with pytest.raises(ComputeError, match=pattern):
                ctx.execute(query)

    # nested array literals are rejected at expression-parse time
    with pytest.raises(
        ComputeError,
        match="SQL interface does not yet support nested array literals",
    ):
        pl.sql_expr("[[1,2,3]] AS nested")

0 comments on commit 1cafcbc

Please sign in to comment.