From 5e8ac6d2c1bfa9e7e59cb8b9b2a0ccc671507901 Mon Sep 17 00:00:00 2001 From: Weijie Guo Date: Sun, 21 Jan 2024 23:17:46 +0800 Subject: [PATCH] feat(rust, python, cli): Add `ignore_nulls` for `list.join` (#13701) Co-authored-by: alexander-beedie --- .../src/chunked_array/list/namespace.rs | 40 +++++++++++----- .../polars-plan/src/dsl/function_expr/list.rs | 12 ++--- crates/polars-plan/src/dsl/list.rs | 4 +- crates/polars-sql/Cargo.toml | 13 ++--- crates/polars-sql/src/functions.rs | 47 +++++++++++++++---- crates/polars-sql/tests/functions_string.rs | 2 +- crates/polars/Cargo.toml | 2 +- py-polars/polars/expr/list.py | 9 +++- py-polars/polars/series/list.py | 7 ++- py-polars/src/expr/list.rs | 8 +++- py-polars/tests/unit/namespaces/test_list.py | 18 +++++++ py-polars/tests/unit/sql/test_array.py | 29 ++++++++++++ 12 files changed, 151 insertions(+), 40 deletions(-) create mode 100644 py-polars/tests/unit/sql/test_array.py diff --git a/crates/polars-ops/src/chunked_array/list/namespace.rs b/crates/polars-ops/src/chunked_array/list/namespace.rs index e53fb830361d..37027e637a14 100644 --- a/crates/polars-ops/src/chunked_array/list/namespace.rs +++ b/crates/polars-ops/src/chunked_array/list/namespace.rs @@ -78,32 +78,41 @@ fn cast_rhs( pub trait ListNameSpaceImpl: AsList { /// In case the inner dtype [`DataType::String`], the individual items will be joined into a /// single string separated by `separator`. - fn lst_join(&self, separator: &StringChunked) -> PolarsResult { + fn lst_join( + &self, + separator: &StringChunked, + ignore_nulls: bool, + ) -> PolarsResult { let ca = self.as_list(); match ca.inner_dtype() { DataType::String => match separator.len() { 1 => match separator.get(0) { - Some(separator) => self.join_literal(separator), + Some(separator) => self.join_literal(separator, ignore_nulls), _ => Ok(StringChunked::full_null(ca.name(), ca.len())), }, - _ => self.join_many(separator), + _ => self.join_many(separator, ignore_nulls), }, dt => polars_bail!(op = "`lst.join`", got = dt, expected = "String"), } } - fn join_literal(&self, separator: &str) -> PolarsResult { + fn join_literal(&self, separator: &str, ignore_nulls: bool) -> PolarsResult { let ca = self.as_list(); // used to amortize heap allocs let mut buf = String::with_capacity(128); let mut builder = StringChunkedBuilder::new(ca.name(), ca.len()); ca.for_each_amortized(|opt_s| { - let opt_val = opt_s.map(|s| { + let opt_val = opt_s.and_then(|s| { // make sure that we don't write values of previous iteration buf.clear(); let ca = s.as_ref().str().unwrap(); - let iter = ca.into_iter().map(|opt_v| opt_v.unwrap_or("null")); + + if ca.null_count() != 0 && !ignore_nulls { + return None; + } + + let iter = ca.into_iter().flatten(); for val in iter { buf.write_str(val).unwrap(); @@ -111,14 +120,18 @@ pub trait ListNameSpaceImpl: AsList { } // last value should not have a separator, so slice that off // saturating sub because there might have been nothing written. - &buf[..buf.len().saturating_sub(separator.len())] + Some(&buf[..buf.len().saturating_sub(separator.len())]) }); builder.append_option(opt_val) }); Ok(builder.finish()) } - fn join_many(&self, separator: &StringChunked) -> PolarsResult { + fn join_many( + &self, + separator: &StringChunked, + ignore_nulls: bool, + ) -> PolarsResult { let ca = self.as_list(); // used to amortize heap allocs let mut buf = String::with_capacity(128); @@ -129,11 +142,16 @@ pub trait ListNameSpaceImpl: AsList { .zip(separator) .for_each(|(opt_s, opt_sep)| match opt_sep { Some(separator) => { - let opt_val = opt_s.map(|s| { + let opt_val = opt_s.and_then(|s| { // make sure that we don't write values of previous iteration buf.clear(); let ca = s.as_ref().str().unwrap(); - let iter = ca.into_iter().map(|opt_v| opt_v.unwrap_or("null")); + + if ca.null_count() != 0 && !ignore_nulls { + return None; + } + + let iter = ca.into_iter().flatten(); for val in iter { buf.write_str(val).unwrap(); @@ -141,7 +159,7 @@ pub trait ListNameSpaceImpl: AsList { } // last value should not have a separator, so slice that off // saturating sub because there might have been nothing written. - &buf[..buf.len().saturating_sub(separator.len())] + Some(&buf[..buf.len().saturating_sub(separator.len())]) }); builder.append_option(opt_val) }, diff --git a/crates/polars-plan/src/dsl/function_expr/list.rs b/crates/polars-plan/src/dsl/function_expr/list.rs index 544a867a05d4..b26726baed12 100644 --- a/crates/polars-plan/src/dsl/function_expr/list.rs +++ b/crates/polars-plan/src/dsl/function_expr/list.rs @@ -47,7 +47,7 @@ pub enum ListFunction { Any, #[cfg(feature = "list_any_all")] All, - Join, + Join(bool), #[cfg(feature = "dtype-array")] ToArray(usize), } @@ -88,7 +88,7 @@ impl ListFunction { Any => mapper.with_dtype(DataType::Boolean), #[cfg(feature = "list_any_all")] All => mapper.with_dtype(DataType::Boolean), - Join => mapper.with_dtype(DataType::String), + Join(_) => mapper.with_dtype(DataType::String), #[cfg(feature = "dtype-array")] ToArray(width) => mapper.try_map_dtype(|dt| map_list_dtype_to_array_dtype(dt, *width)), } @@ -153,7 +153,7 @@ impl Display for ListFunction { Any => "any", #[cfg(feature = "list_any_all")] All => "all", - Join => "join", + Join(_) => "join", #[cfg(feature = "dtype-array")] ToArray(_) => "to_array", }; @@ -208,7 +208,7 @@ impl From for SpecialEq> { Any => map!(lst_any), #[cfg(feature = "list_any_all")] All => map!(lst_all), - Join => map_as_slice!(join), + Join(ignore_nulls) => map_as_slice!(join, ignore_nulls), #[cfg(feature = "dtype-array")] ToArray(width) => map!(to_array, width), } @@ -553,10 +553,10 @@ pub(super) fn lst_all(s: &Series) -> PolarsResult { s.list()?.lst_all() } -pub(super) fn join(s: &[Series]) -> PolarsResult { +pub(super) fn join(s: &[Series], ignore_nulls: bool) -> PolarsResult { let ca = s[0].list()?; let separator = s[1].str()?; - Ok(ca.lst_join(separator)?.into_series()) + Ok(ca.lst_join(separator, ignore_nulls)?.into_series()) } #[cfg(feature = "dtype-array")] diff --git a/crates/polars-plan/src/dsl/list.rs b/crates/polars-plan/src/dsl/list.rs index 96c0303818a5..9f0219e566a8 100644 --- a/crates/polars-plan/src/dsl/list.rs +++ b/crates/polars-plan/src/dsl/list.rs @@ -169,9 +169,9 @@ impl ListNameSpace { /// Join all string items in a sublist and place a separator between them. /// # Error /// This errors if inner type of list `!= DataType::String`. - pub fn join(self, separator: Expr) -> Expr { + pub fn join(self, separator: Expr, ignore_nulls: bool) -> Expr { self.0.map_many_private( - FunctionExpr::ListExpr(ListFunction::Join), + FunctionExpr::ListExpr(ListFunction::Join(ignore_nulls)), &[separator], false, false, diff --git a/crates/polars-sql/Cargo.toml b/crates/polars-sql/Cargo.toml index 82331e185800..162eec9b5277 100644 --- a/crates/polars-sql/Cargo.toml +++ b/crates/polars-sql/Cargo.toml @@ -12,7 +12,7 @@ description = "SQL transpiler for Polars. Converts SQL to Polars logical plans" arrow = { workspace = true } polars-core = { workspace = true } polars-error = { workspace = true } -polars-lazy = { workspace = true, features = ["abs", "binary_encoding", "concat_str", "cross_join", "cum_agg", "dtype-date", "dtype-decimal", "is_in", "log", "meta", "regex", "round_series", "sign", "string_reverse", "strings", "timezones", "trigonometry"] } +polars-lazy = { workspace = true, features = ["abs", "binary_encoding", "concat_str", "cross_join", "cum_agg", "dtype-date", "dtype-decimal", "is_in", "list_eval", "log", "meta", "regex", "round_series", "sign", "string_reverse", "strings", "timezones", "trigonometry"] } polars-plan = { workspace = true } hex = { workspace = true } @@ -27,14 +27,15 @@ sqlparser = { workspace = true } polars-core = { workspace = true, features = ["fmt"] } [features] +default = [] +nightly = [] csv = ["polars-lazy/csv"] +ipc = ["polars-lazy/ipc"] json = ["polars-lazy/json"] -default = [] +binary_encoding = ["polars-lazy/binary_encoding"] +diagonal_concat = ["polars-lazy/diagonal_concat"] dtype-decimal = ["polars-lazy/dtype-decimal"] -ipc = ["polars-lazy/ipc"] +list_eval = ["polars-lazy/list_eval"] parquet = ["polars-lazy/parquet"] semi_anti_join = ["polars-lazy/semi_anti_join"] -diagonal_concat = ["polars-lazy/diagonal_concat"] -binary_encoding = ["polars-lazy/binary_encoding"] timezones = ["polars-lazy/timezones"] -nightly = [] diff --git a/crates/polars-sql/src/functions.rs b/crates/polars-sql/src/functions.rs index db4adfa30917..6fc429d6aab7 100644 --- a/crates/polars-sql/src/functions.rs +++ b/crates/polars-sql/src/functions.rs @@ -1,7 +1,11 @@ use polars_core::prelude::{polars_bail, polars_err, PolarsResult}; use polars_lazy::dsl::Expr; +#[cfg(feature = "list_eval")] +use polars_lazy::dsl::ListNameSpaceExtension; use polars_plan::dsl::{coalesce, concat_str, len, when}; use polars_plan::logical_plan::LiteralValue; +#[cfg(feature = "list_eval")] +use polars_plan::prelude::col; use polars_plan::prelude::LiteralValue::Null; use polars_plan::prelude::{lit, StrptimeOptions}; use sqlparser::ast::{ @@ -520,7 +524,8 @@ pub(crate) enum PolarsSQLFunctions { /// SQL 'array_to_string' function /// Takes all elements of the array and joins them into one string. /// ```sql - /// SELECT ARRAY_TO_STRING(column_1, ', ') from df; + /// SELECT ARRAY_TO_STRING(column_1, ',') from df; + /// SELECT ARRAY_TO_STRING(column_1, ',', 'n/a') from df; /// ``` ArrayToString, /// SQL 'array_get' function @@ -969,9 +974,23 @@ impl SQLFunctionVisitor<'_> { ArrayMin => self.visit_unary(|e| e.list().min()), ArrayReverse => self.visit_unary(|e| e.list().reverse()), ArraySum => self.visit_unary(|e| e.list().sum()), - ArrayToString => self.try_visit_binary(|e, s| { - Ok(e.list().join(s)) - }), + ArrayToString => match function.args.len() { + 2 => self.try_visit_binary(|e, sep| { Ok(e.list().join(sep, true)) }), + #[cfg(feature = "list_eval")] + 3 => self.try_visit_ternary(|e, sep, null_value| { + match null_value { + Expr::Literal(LiteralValue::String(v)) => { + Ok(if v.is_empty() { + e.list().join(sep, true) + } else { + e.list().eval(col("").fill_null(lit(v)), false).list().join(sep, false) + }) + }, + _ => polars_bail!(InvalidOperation: "Invalid null value for ArrayToString: {}", function.args[2]), + } + }), + _ => polars_bail!(InvalidOperation: "Invalid number of arguments for ArrayToString: {}", function.args.len()), + } ArrayUnique => self.visit_unary(|e| e.list().unique()), Explode => self.visit_unary(|e| e.explode()), Udf(func_name) => self.visit_udf(&func_name) @@ -1169,10 +1188,7 @@ impl SQLFunctionVisitor<'_> { .iter() .map(|o| { let e = parse_sql_expr(&o.expr, self.ctx)?; - match o.asc { - Some(b) => Ok(e.sort(!b)), - None => Ok(e), - } + Ok(o.asc.map_or(e.clone(), |b| e.sort(!b))) }) .collect::>>()?; expr.over(exprs) @@ -1237,6 +1253,21 @@ impl FromSQLExpr for f64 { } } +impl FromSQLExpr for bool { + fn from_sql_expr(expr: &SQLExpr, _ctx: &mut SQLContext) -> PolarsResult + where + Self: Sized, + { + match expr { + SQLExpr::Value(v) => match v { + SQLValue::Boolean(v) => Ok(*v), + _ => polars_bail!(ComputeError: "can't parse boolean {:?}", v), + }, + _ => polars_bail!(ComputeError: "can't parse boolean {:?}", expr), + } + } +} + impl FromSQLExpr for String { fn from_sql_expr(expr: &SQLExpr, _: &mut SQLContext) -> PolarsResult where diff --git a/crates/polars-sql/tests/functions_string.rs b/crates/polars-sql/tests/functions_string.rs index ae4fb2641c8a..f6ea4314a5de 100644 --- a/crates/polars-sql/tests/functions_string.rs +++ b/crates/polars-sql/tests/functions_string.rs @@ -117,7 +117,7 @@ fn array_to_string() { .lazy() .group_by([col("b")]) .agg([col("a")]) - .select(&[col("b"), col("a").list().join(lit(", ")).alias("as")]) + .select(&[col("b"), col("a").list().join(lit(", "), true).alias("as")]) .sort_by_exprs(vec![col("b"), col("as")], vec![false, false], false, true) .collect() .unwrap(); diff --git a/crates/polars/Cargo.toml b/crates/polars/Cargo.toml index 4705a4c0ac28..9f340dc9b5bd 100644 --- a/crates/polars/Cargo.toml +++ b/crates/polars/Cargo.toml @@ -157,7 +157,7 @@ list_any_all = ["polars-lazy?/list_any_all"] list_count = ["polars-ops/list_count", "polars-lazy?/list_count"] array_count = ["polars-ops/array_count", "polars-lazy?/array_count", "dtype-array"] list_drop_nulls = ["polars-lazy?/list_drop_nulls"] -list_eval = ["polars-lazy?/list_eval"] +list_eval = ["polars-lazy?/list_eval", "polars-sql?/list_eval"] list_gather = ["polars-ops/list_gather", "polars-lazy?/list_gather"] list_sample = ["polars-lazy?/list_sample"] list_sets = ["polars-lazy?/list_sets"] diff --git a/py-polars/polars/expr/list.py b/py-polars/polars/expr/list.py index 05a0148dcb75..0341e0fc08a8 100644 --- a/py-polars/polars/expr/list.py +++ b/py-polars/polars/expr/list.py @@ -556,7 +556,7 @@ def contains( item = parse_as_expression(item, str_as_lit=True) return wrap_expr(self._pyexpr.list_contains(item)) - def join(self, separator: IntoExprColumn) -> Expr: + def join(self, separator: IntoExprColumn, *, ignore_nulls: bool = True) -> Expr: """ Join all string items in a sublist and place a separator between them. @@ -566,6 +566,11 @@ def join(self, separator: IntoExprColumn) -> Expr: ---------- separator string to separate the items with + ignore_nulls + Ignore null values (default). + + If set to ``False``, null values will be propagated. + if the sub-list contains any null values, the output is ``None``. Returns ------- @@ -601,7 +606,7 @@ def join(self, separator: IntoExprColumn) -> Expr: └─────────────────┴───────────┴───────┘ """ separator = parse_as_expression(separator, str_as_lit=True) - return wrap_expr(self._pyexpr.list_join(separator)) + return wrap_expr(self._pyexpr.list_join(separator, ignore_nulls)) def arg_min(self) -> Expr: """ diff --git a/py-polars/polars/series/list.py b/py-polars/polars/series/list.py index 56a52064838b..9df3cf61ac1e 100644 --- a/py-polars/polars/series/list.py +++ b/py-polars/polars/series/list.py @@ -386,7 +386,7 @@ def gather( def __getitem__(self, item: int) -> Series: return self.get(item) - def join(self, separator: IntoExprColumn) -> Series: + def join(self, separator: IntoExprColumn, *, ignore_nulls: bool = True) -> Series: """ Join all string items in a sublist and place a separator between them. @@ -396,6 +396,11 @@ def join(self, separator: IntoExprColumn) -> Series: ---------- separator string to separate the items with + ignore_nulls + Ignore null values (default). + + If set to ``False``, null values will be propagated. + if the sub-list contains any null values, the output is ``None``. Returns ------- diff --git a/py-polars/src/expr/list.rs b/py-polars/src/expr/list.rs index 75c1b4a782e9..e17550f8ba9e 100644 --- a/py-polars/src/expr/list.rs +++ b/py-polars/src/expr/list.rs @@ -49,8 +49,12 @@ impl PyExpr { self.inner.clone().list().get(index.inner).into() } - fn list_join(&self, separator: PyExpr) -> Self { - self.inner.clone().list().join(separator.inner).into() + fn list_join(&self, separator: PyExpr, ignore_nulls: bool) -> Self { + self.inner + .clone() + .list() + .join(separator.inner, ignore_nulls) + .into() } fn list_len(&self) -> Self { diff --git a/py-polars/tests/unit/namespaces/test_list.py b/py-polars/tests/unit/namespaces/test_list.py index cefea46ef69d..c86e52bbd42e 100644 --- a/py-polars/tests/unit/namespaces/test_list.py +++ b/py-polars/tests/unit/namespaces/test_list.py @@ -104,6 +104,24 @@ def test_list_join() -> None: out = df.select(pl.col("a").list.join(pl.col("separator"))) assert out.to_dict(as_series=False) == {"a": ["ab&c&d", None, "g", "", None]} + # test ignore_nulls argument + df = pl.DataFrame( + { + "a": [["a", None, "b", None], None, [None, None], ["c", "d"], []], + "separator": ["-", "&", " ", "@", "/"], + } + ) + # ignore nulls + out = df.select(pl.col("a").list.join("-", ignore_nulls=True)) + assert out.to_dict(as_series=False) == {"a": ["a-b", None, "", "c-d", ""]} + out = df.select(pl.col("a").list.join(pl.col("separator"), ignore_nulls=True)) + assert out.to_dict(as_series=False) == {"a": ["a-b", None, "", "c@d", ""]} + # propagate nulls + out = df.select(pl.col("a").list.join("-", ignore_nulls=False)) + assert out.to_dict(as_series=False) == {"a": [None, None, None, "c-d", ""]} + out = df.select(pl.col("a").list.join(pl.col("separator"), ignore_nulls=False)) + assert out.to_dict(as_series=False) == {"a": [None, None, None, "c@d", ""]} + def test_list_arr_empty() -> None: df = pl.DataFrame({"cars": [[1, 2, 3], [2, 3], [4], []]}) diff --git a/py-polars/tests/unit/sql/test_array.py b/py-polars/tests/unit/sql/test_array.py new file mode 100644 index 000000000000..a62cd6ffd984 --- /dev/null +++ b/py-polars/tests/unit/sql/test_array.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +import polars as pl +from polars.testing import assert_frame_equal + + +def test_array_to_string() -> None: + df = pl.DataFrame({"values": [["aa", "bb"], [None, "cc"], ["dd", None]]}) + + with pl.SQLContext(df=df, eager_execution=True) as ctx: + res = ctx.execute( + """ + SELECT + ARRAY_TO_STRING(values, '') AS v1, + ARRAY_TO_STRING(values, ':') AS v2, + ARRAY_TO_STRING(values, ':', 'NA') AS v3 + FROM df + """ + ) + assert_frame_equal( + res, + pl.DataFrame( + { + "v1": ["aabb", "cc", "dd"], + "v2": ["aa:bb", "cc", "dd"], + "v3": ["aa:bb", "NA:cc", "dd:NA"], + } + ), + )