Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(rust, python, cli): Add ignore_nulls for list.join #13701

Merged
merged 1 commit into from
Jan 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 29 additions & 11 deletions crates/polars-ops/src/chunked_array/list/namespace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,47 +78,60 @@ fn cast_rhs(
pub trait ListNameSpaceImpl: AsList {
/// In case the inner dtype [`DataType::String`], the individual items will be joined into a
/// single string separated by `separator`.
fn lst_join(&self, separator: &StringChunked) -> PolarsResult<StringChunked> {
fn lst_join(
&self,
separator: &StringChunked,
ignore_nulls: bool,
) -> PolarsResult<StringChunked> {
let ca = self.as_list();
match ca.inner_dtype() {
DataType::String => match separator.len() {
1 => match separator.get(0) {
Some(separator) => self.join_literal(separator),
Some(separator) => self.join_literal(separator, ignore_nulls),
_ => Ok(StringChunked::full_null(ca.name(), ca.len())),
},
_ => self.join_many(separator),
_ => self.join_many(separator, ignore_nulls),
},
dt => polars_bail!(op = "`lst.join`", got = dt, expected = "String"),
}
}

fn join_literal(&self, separator: &str) -> PolarsResult<StringChunked> {
fn join_literal(&self, separator: &str, ignore_nulls: bool) -> PolarsResult<StringChunked> {
let ca = self.as_list();
// used to amortize heap allocs
let mut buf = String::with_capacity(128);
let mut builder = StringChunkedBuilder::new(ca.name(), ca.len());

ca.for_each_amortized(|opt_s| {
let opt_val = opt_s.map(|s| {
let opt_val = opt_s.and_then(|s| {
// make sure that we don't write values of previous iteration
buf.clear();
let ca = s.as_ref().str().unwrap();
let iter = ca.into_iter().map(|opt_v| opt_v.unwrap_or("null"));

if ca.null_count() != 0 && !ignore_nulls {
return None;
}

let iter = ca.into_iter().flatten();

for val in iter {
buf.write_str(val).unwrap();
buf.write_str(separator).unwrap();
}
// last value should not have a separator, so slice that off
// saturating sub because there might have been nothing written.
&buf[..buf.len().saturating_sub(separator.len())]
Some(&buf[..buf.len().saturating_sub(separator.len())])
});
builder.append_option(opt_val)
});
Ok(builder.finish())
}

fn join_many(&self, separator: &StringChunked) -> PolarsResult<StringChunked> {
fn join_many(
&self,
separator: &StringChunked,
ignore_nulls: bool,
) -> PolarsResult<StringChunked> {
let ca = self.as_list();
// used to amortize heap allocs
let mut buf = String::with_capacity(128);
Expand All @@ -129,19 +142,24 @@ pub trait ListNameSpaceImpl: AsList {
.zip(separator)
.for_each(|(opt_s, opt_sep)| match opt_sep {
Some(separator) => {
let opt_val = opt_s.map(|s| {
let opt_val = opt_s.and_then(|s| {
// make sure that we don't write values of previous iteration
buf.clear();
let ca = s.as_ref().str().unwrap();
let iter = ca.into_iter().map(|opt_v| opt_v.unwrap_or("null"));

if ca.null_count() != 0 && !ignore_nulls {
return None;
}

let iter = ca.into_iter().flatten();

for val in iter {
buf.write_str(val).unwrap();
buf.write_str(separator).unwrap();
}
// last value should not have a separator, so slice that off
// saturating sub because there might have been nothing written.
&buf[..buf.len().saturating_sub(separator.len())]
Some(&buf[..buf.len().saturating_sub(separator.len())])
});
builder.append_option(opt_val)
},
Expand Down
12 changes: 6 additions & 6 deletions crates/polars-plan/src/dsl/function_expr/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ pub enum ListFunction {
Any,
#[cfg(feature = "list_any_all")]
All,
Join,
Join(bool),
#[cfg(feature = "dtype-array")]
ToArray(usize),
}
Expand Down Expand Up @@ -88,7 +88,7 @@ impl ListFunction {
Any => mapper.with_dtype(DataType::Boolean),
#[cfg(feature = "list_any_all")]
All => mapper.with_dtype(DataType::Boolean),
Join => mapper.with_dtype(DataType::String),
Join(_) => mapper.with_dtype(DataType::String),
#[cfg(feature = "dtype-array")]
ToArray(width) => mapper.try_map_dtype(|dt| map_list_dtype_to_array_dtype(dt, *width)),
}
Expand Down Expand Up @@ -153,7 +153,7 @@ impl Display for ListFunction {
Any => "any",
#[cfg(feature = "list_any_all")]
All => "all",
Join => "join",
Join(_) => "join",
#[cfg(feature = "dtype-array")]
ToArray(_) => "to_array",
};
Expand Down Expand Up @@ -208,7 +208,7 @@ impl From<ListFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
Any => map!(lst_any),
#[cfg(feature = "list_any_all")]
All => map!(lst_all),
Join => map_as_slice!(join),
Join(ignore_nulls) => map_as_slice!(join, ignore_nulls),
#[cfg(feature = "dtype-array")]
ToArray(width) => map!(to_array, width),
}
Expand Down Expand Up @@ -553,10 +553,10 @@ pub(super) fn lst_all(s: &Series) -> PolarsResult<Series> {
s.list()?.lst_all()
}

pub(super) fn join(s: &[Series]) -> PolarsResult<Series> {
pub(super) fn join(s: &[Series], ignore_nulls: bool) -> PolarsResult<Series> {
let ca = s[0].list()?;
let separator = s[1].str()?;
Ok(ca.lst_join(separator)?.into_series())
Ok(ca.lst_join(separator, ignore_nulls)?.into_series())
}

#[cfg(feature = "dtype-array")]
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-plan/src/dsl/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,9 @@ impl ListNameSpace {
/// Join all string items in a sublist and place a separator between them.
/// # Error
/// This errors if inner type of list `!= DataType::String`.
pub fn join(self, separator: Expr) -> Expr {
pub fn join(self, separator: Expr, ignore_nulls: bool) -> Expr {
self.0.map_many_private(
FunctionExpr::ListExpr(ListFunction::Join),
FunctionExpr::ListExpr(ListFunction::Join(ignore_nulls)),
&[separator],
false,
false,
Expand Down
13 changes: 7 additions & 6 deletions crates/polars-sql/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ description = "SQL transpiler for Polars. Converts SQL to Polars logical plans"
arrow = { workspace = true }
polars-core = { workspace = true }
polars-error = { workspace = true }
polars-lazy = { workspace = true, features = ["abs", "binary_encoding", "concat_str", "cross_join", "cum_agg", "dtype-date", "dtype-decimal", "is_in", "log", "meta", "regex", "round_series", "sign", "string_reverse", "strings", "timezones", "trigonometry"] }
polars-lazy = { workspace = true, features = ["abs", "binary_encoding", "concat_str", "cross_join", "cum_agg", "dtype-date", "dtype-decimal", "is_in", "list_eval", "log", "meta", "regex", "round_series", "sign", "string_reverse", "strings", "timezones", "trigonometry"] }
polars-plan = { workspace = true }

hex = { workspace = true }
Expand All @@ -27,14 +27,15 @@ sqlparser = { workspace = true }
polars-core = { workspace = true, features = ["fmt"] }

[features]
default = []
nightly = []
csv = ["polars-lazy/csv"]
ipc = ["polars-lazy/ipc"]
json = ["polars-lazy/json"]
default = []
binary_encoding = ["polars-lazy/binary_encoding"]
diagonal_concat = ["polars-lazy/diagonal_concat"]
dtype-decimal = ["polars-lazy/dtype-decimal"]
ipc = ["polars-lazy/ipc"]
list_eval = ["polars-lazy/list_eval"]
parquet = ["polars-lazy/parquet"]
semi_anti_join = ["polars-lazy/semi_anti_join"]
diagonal_concat = ["polars-lazy/diagonal_concat"]
binary_encoding = ["polars-lazy/binary_encoding"]
timezones = ["polars-lazy/timezones"]
nightly = []
47 changes: 39 additions & 8 deletions crates/polars-sql/src/functions.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
use polars_core::prelude::{polars_bail, polars_err, PolarsResult};
use polars_lazy::dsl::Expr;
#[cfg(feature = "list_eval")]
use polars_lazy::dsl::ListNameSpaceExtension;
use polars_plan::dsl::{coalesce, concat_str, len, when};
use polars_plan::logical_plan::LiteralValue;
#[cfg(feature = "list_eval")]
use polars_plan::prelude::col;
use polars_plan::prelude::LiteralValue::Null;
use polars_plan::prelude::{lit, StrptimeOptions};
use sqlparser::ast::{
Expand Down Expand Up @@ -520,7 +524,8 @@ pub(crate) enum PolarsSQLFunctions {
/// SQL 'array_to_string' function
/// Takes all elements of the array and joins them into one string.
/// ```sql
/// SELECT ARRAY_TO_STRING(column_1, ', ') from df;
/// SELECT ARRAY_TO_STRING(column_1, ',') from df;
/// SELECT ARRAY_TO_STRING(column_1, ',', 'n/a') from df;
/// ```
ArrayToString,
/// SQL 'array_get' function
Expand Down Expand Up @@ -969,9 +974,23 @@ impl SQLFunctionVisitor<'_> {
ArrayMin => self.visit_unary(|e| e.list().min()),
ArrayReverse => self.visit_unary(|e| e.list().reverse()),
ArraySum => self.visit_unary(|e| e.list().sum()),
ArrayToString => self.try_visit_binary(|e, s| {
Ok(e.list().join(s))
}),
ArrayToString => match function.args.len() {
2 => self.try_visit_binary(|e, sep| { Ok(e.list().join(sep, true)) }),
#[cfg(feature = "list_eval")]
3 => self.try_visit_ternary(|e, sep, null_value| {
match null_value {
Expr::Literal(LiteralValue::String(v)) => {
Ok(if v.is_empty() {
e.list().join(sep, true)
} else {
e.list().eval(col("").fill_null(lit(v)), false).list().join(sep, false)
})
},
_ => polars_bail!(InvalidOperation: "Invalid null value for ArrayToString: {}", function.args[2]),
}
}),
_ => polars_bail!(InvalidOperation: "Invalid number of arguments for ArrayToString: {}", function.args.len()),
}
ArrayUnique => self.visit_unary(|e| e.list().unique()),
Explode => self.visit_unary(|e| e.explode()),
Udf(func_name) => self.visit_udf(&func_name)
Expand Down Expand Up @@ -1169,10 +1188,7 @@ impl SQLFunctionVisitor<'_> {
.iter()
.map(|o| {
let e = parse_sql_expr(&o.expr, self.ctx)?;
match o.asc {
Some(b) => Ok(e.sort(!b)),
None => Ok(e),
}
Ok(o.asc.map_or(e.clone(), |b| e.sort(!b)))
})
.collect::<PolarsResult<Vec<_>>>()?;
expr.over(exprs)
Expand Down Expand Up @@ -1237,6 +1253,21 @@ impl FromSQLExpr for f64 {
}
}

impl FromSQLExpr for bool {
    /// Extracts a Rust `bool` from a SQL boolean literal.
    ///
    /// # Errors
    /// Bails with a `ComputeError` when `expr` is a SQL value other than a
    /// boolean, or when it is not a plain value expression at all.
    fn from_sql_expr(expr: &SQLExpr, _ctx: &mut SQLContext) -> PolarsResult<Self>
    where
        Self: Sized,
    {
        match expr {
            // The only accepted shape: a literal boolean value.
            SQLExpr::Value(SQLValue::Boolean(flag)) => Ok(*flag),
            // A literal of some other kind: report the offending value itself.
            SQLExpr::Value(other) => {
                polars_bail!(ComputeError: "can't parse boolean {:?}", other)
            },
            // Anything that is not a literal: report the whole expression.
            _ => polars_bail!(ComputeError: "can't parse boolean {:?}", expr),
        }
    }
}

impl FromSQLExpr for String {
fn from_sql_expr(expr: &SQLExpr, _: &mut SQLContext) -> PolarsResult<Self>
where
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-sql/tests/functions_string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ fn array_to_string() {
.lazy()
.group_by([col("b")])
.agg([col("a")])
.select(&[col("b"), col("a").list().join(lit(", ")).alias("as")])
.select(&[col("b"), col("a").list().join(lit(", "), true).alias("as")])
.sort_by_exprs(vec![col("b"), col("as")], vec![false, false], false, true)
.collect()
.unwrap();
Expand Down
2 changes: 1 addition & 1 deletion crates/polars/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ list_any_all = ["polars-lazy?/list_any_all"]
list_count = ["polars-ops/list_count", "polars-lazy?/list_count"]
array_count = ["polars-ops/array_count", "polars-lazy?/array_count", "dtype-array"]
list_drop_nulls = ["polars-lazy?/list_drop_nulls"]
list_eval = ["polars-lazy?/list_eval"]
list_eval = ["polars-lazy?/list_eval", "polars-sql?/list_eval"]
list_gather = ["polars-ops/list_gather", "polars-lazy?/list_gather"]
list_sample = ["polars-lazy?/list_sample"]
list_sets = ["polars-lazy?/list_sets"]
Expand Down
9 changes: 7 additions & 2 deletions py-polars/polars/expr/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,7 @@ def contains(
item = parse_as_expression(item, str_as_lit=True)
return wrap_expr(self._pyexpr.list_contains(item))

def join(self, separator: IntoExprColumn) -> Expr:
def join(self, separator: IntoExprColumn, *, ignore_nulls: bool = True) -> Expr:
"""
Join all string items in a sublist and place a separator between them.

Expand All @@ -566,6 +566,11 @@ def join(self, separator: IntoExprColumn) -> Expr:
----------
separator
string to separate the items with
ignore_nulls
Ignore null values (default).

If set to ``False``, null values will be propagated: if the sub-list
contains any null values, the output is ``None``.

Returns
-------
Expand Down Expand Up @@ -601,7 +606,7 @@ def join(self, separator: IntoExprColumn) -> Expr:
└─────────────────┴───────────┴───────┘
"""
separator = parse_as_expression(separator, str_as_lit=True)
return wrap_expr(self._pyexpr.list_join(separator))
return wrap_expr(self._pyexpr.list_join(separator, ignore_nulls))

def arg_min(self) -> Expr:
"""
Expand Down
7 changes: 6 additions & 1 deletion py-polars/polars/series/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ def gather(
def __getitem__(self, item: int) -> Series:
return self.get(item)

def join(self, separator: IntoExprColumn) -> Series:
def join(self, separator: IntoExprColumn, *, ignore_nulls: bool = True) -> Series:
"""
Join all string items in a sublist and place a separator between them.

Expand All @@ -396,6 +396,11 @@ def join(self, separator: IntoExprColumn) -> Series:
----------
separator
string to separate the items with
ignore_nulls
Ignore null values (default).

If set to ``False``, null values will be propagated: if the sub-list
contains any null values, the output is ``None``.

Returns
-------
Expand Down
8 changes: 6 additions & 2 deletions py-polars/src/expr/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,12 @@ impl PyExpr {
self.inner.clone().list().get(index.inner).into()
}

fn list_join(&self, separator: PyExpr) -> Self {
self.inner.clone().list().join(separator.inner).into()
fn list_join(&self, separator: PyExpr, ignore_nulls: bool) -> Self {
self.inner
.clone()
.list()
.join(separator.inner, ignore_nulls)
.into()
}

fn list_len(&self) -> Self {
Expand Down
18 changes: 18 additions & 0 deletions py-polars/tests/unit/namespaces/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,24 @@ def test_list_join() -> None:
out = df.select(pl.col("a").list.join(pl.col("separator")))
assert out.to_dict(as_series=False) == {"a": ["ab&c&d", None, "g", "", None]}

# test ignore_nulls argument
df = pl.DataFrame(
{
"a": [["a", None, "b", None], None, [None, None], ["c", "d"], []],
"separator": ["-", "&", " ", "@", "/"],
}
)
# ignore nulls
out = df.select(pl.col("a").list.join("-", ignore_nulls=True))
assert out.to_dict(as_series=False) == {"a": ["a-b", None, "", "c-d", ""]}
out = df.select(pl.col("a").list.join(pl.col("separator"), ignore_nulls=True))
assert out.to_dict(as_series=False) == {"a": ["a-b", None, "", "c@d", ""]}
# propagate nulls
out = df.select(pl.col("a").list.join("-", ignore_nulls=False))
assert out.to_dict(as_series=False) == {"a": [None, None, None, "c-d", ""]}
out = df.select(pl.col("a").list.join(pl.col("separator"), ignore_nulls=False))
assert out.to_dict(as_series=False) == {"a": [None, None, None, "c@d", ""]}


def test_list_arr_empty() -> None:
df = pl.DataFrame({"cars": [[1, 2, 3], [2, 3], [4], []]})
Expand Down
Loading