Skip to content

Commit

Permalink
feat(rust, python, cli)!: Add ignore_nulls for list.join
Browse files Browse the repository at this point in the history
Co-authored-by: alexander-beedie <alexander.m.beedie@icloud.com>
  • Loading branch information
reswqa and alexander-beedie committed Jan 19, 2024
1 parent 62ecdd3 commit 198fa2d
Show file tree
Hide file tree
Showing 12 changed files with 152 additions and 41 deletions.
40 changes: 29 additions & 11 deletions crates/polars-ops/src/chunked_array/list/namespace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,47 +78,60 @@ fn cast_rhs(
pub trait ListNameSpaceImpl: AsList {
/// Join the string items of every sub-list into a single string, placing
/// `separator` between consecutive items.
///
/// The inner dtype must be [`DataType::String`]; any other inner dtype is reported
/// as an error. `ignore_nulls` is forwarded unchanged to `join_literal` /
/// `join_many`, which decide how null items inside a sub-list are treated.
fn lst_join(&self, separator: &StringChunked) -> PolarsResult<StringChunked> {
fn lst_join(
&self,
separator: &StringChunked,
ignore_nulls: bool,
) -> PolarsResult<StringChunked> {
let ca = self.as_list();
match ca.inner_dtype() {
DataType::String => match separator.len() {
// a single separator value is applied to every sub-list
1 => match separator.get(0) {
Some(separator) => self.join_literal(separator),
Some(separator) => self.join_literal(separator, ignore_nulls),
// a null separator makes every output row null
_ => Ok(StringChunked::full_null(ca.name(), ca.len())),
},
// any other separator length is delegated to `join_many`
_ => self.join_many(separator),
_ => self.join_many(separator, ignore_nulls),
},
dt => polars_bail!(op = "`lst.join`", got = dt, expected = "String"),
}
}

/// Join every sub-list using the same literal `separator`.
///
/// A null sub-list produces a null output. When a sub-list contains a null item
/// and `ignore_nulls` is `false`, the output for that sub-list is null as well;
/// otherwise null items are skipped and contribute neither text nor a separator.
fn join_literal(&self, separator: &str) -> PolarsResult<StringChunked> {
fn join_literal(&self, separator: &str, ignore_nulls: bool) -> PolarsResult<StringChunked> {
let ca = self.as_list();
// used to amortize heap allocs
let mut buf = String::with_capacity(128);
let mut builder = StringChunkedBuilder::new(ca.name(), ca.len());

ca.for_each_amortized(|opt_s| {
let opt_val = opt_s.map(|s| {
let opt_val = opt_s.and_then(|s| {
// make sure that we don't write values of previous iteration
buf.clear();
let ca = s.as_ref().str().unwrap();
let iter = ca.into_iter().map(|opt_v| opt_v.unwrap_or("null"));

// any null item makes the whole result null unless nulls are ignored
if ca.null_count() != 0 && !ignore_nulls {
return None;
}

// `flatten` drops the null items, so only real values are written
let iter = ca.into_iter().flatten();

for val in iter {
buf.write_str(val).unwrap();
buf.write_str(separator).unwrap();
}
// last value should not have a separator, so slice that off
// saturating sub because there might have been nothing written.
&buf[..buf.len().saturating_sub(separator.len())]
Some(&buf[..buf.len().saturating_sub(separator.len())])
});
builder.append_option(opt_val)
});
Ok(builder.finish())
}

fn join_many(&self, separator: &StringChunked) -> PolarsResult<StringChunked> {
fn join_many(
&self,
separator: &StringChunked,
ignore_nulls: bool,
) -> PolarsResult<StringChunked> {
let ca = self.as_list();
// used to amortize heap allocs
let mut buf = String::with_capacity(128);
Expand All @@ -129,19 +142,24 @@ pub trait ListNameSpaceImpl: AsList {
.zip(separator)
.for_each(|(opt_s, opt_sep)| match opt_sep {
Some(separator) => {
let opt_val = opt_s.map(|s| {
let opt_val = opt_s.and_then(|s| {
// make sure that we don't write values of previous iteration
buf.clear();
let ca = s.as_ref().str().unwrap();
let iter = ca.into_iter().map(|opt_v| opt_v.unwrap_or("null"));

if ca.null_count() != 0 && !ignore_nulls {
return None;
}

let iter = ca.into_iter().flatten();

for val in iter {
buf.write_str(val).unwrap();
buf.write_str(separator).unwrap();
}
// last value should not have a separator, so slice that off
// saturating sub because there might have been nothing written.
&buf[..buf.len().saturating_sub(separator.len())]
Some(&buf[..buf.len().saturating_sub(separator.len())])
});
builder.append_option(opt_val)
},
Expand Down
12 changes: 6 additions & 6 deletions crates/polars-plan/src/dsl/function_expr/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ pub enum ListFunction {
Any,
#[cfg(feature = "list_any_all")]
All,
Join,
Join(bool),
#[cfg(feature = "dtype-array")]
ToArray(usize),
}
Expand Down Expand Up @@ -88,7 +88,7 @@ impl ListFunction {
Any => mapper.with_dtype(DataType::Boolean),
#[cfg(feature = "list_any_all")]
All => mapper.with_dtype(DataType::Boolean),
Join => mapper.with_dtype(DataType::String),
Join(_) => mapper.with_dtype(DataType::String),
#[cfg(feature = "dtype-array")]
ToArray(width) => mapper.try_map_dtype(|dt| map_list_dtype_to_array_dtype(dt, *width)),
}
Expand Down Expand Up @@ -153,7 +153,7 @@ impl Display for ListFunction {
Any => "any",
#[cfg(feature = "list_any_all")]
All => "all",
Join => "join",
Join(_) => "join",
#[cfg(feature = "dtype-array")]
ToArray(_) => "to_array",
};
Expand Down Expand Up @@ -208,7 +208,7 @@ impl From<ListFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
Any => map!(lst_any),
#[cfg(feature = "list_any_all")]
All => map!(lst_all),
Join => map_as_slice!(join),
Join(ignore_nulls) => map_as_slice!(join, ignore_nulls),
#[cfg(feature = "dtype-array")]
ToArray(width) => map!(to_array, width),
}
Expand Down Expand Up @@ -553,10 +553,10 @@ pub(super) fn lst_all(s: &Series) -> PolarsResult<Series> {
s.list()?.lst_all()
}

/// Expression-level entry point for `list.join`.
///
/// Expects the list column in `s[0]` and the string separator in `s[1]`;
/// `ignore_nulls` is passed through to `lst_join`. Fails if `s[0]` is not a
/// list or `s[1]` is not a string column.
pub(super) fn join(s: &[Series]) -> PolarsResult<Series> {
pub(super) fn join(s: &[Series], ignore_nulls: bool) -> PolarsResult<Series> {
let ca = s[0].list()?;
let separator = s[1].str()?;
Ok(ca.lst_join(separator)?.into_series())
Ok(ca.lst_join(separator, ignore_nulls)?.into_series())
}

#[cfg(feature = "dtype-array")]
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-plan/src/dsl/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,9 @@ impl ListNameSpace {
/// Join all string items in a sublist and place a separator between them.
/// # Error
/// This errors if inner type of list `!= DataType::String`.
pub fn join(self, separator: Expr) -> Expr {
pub fn join(self, separator: Expr, ignore_nulls: bool) -> Expr {
self.0.map_many_private(
FunctionExpr::ListExpr(ListFunction::Join),
FunctionExpr::ListExpr(ListFunction::Join(ignore_nulls)),
&[separator],
false,
false,
Expand Down
13 changes: 7 additions & 6 deletions crates/polars-sql/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ description = "SQL transpiler for Polars. Converts SQL to Polars logical plans"
arrow = { workspace = true }
polars-core = { workspace = true }
polars-error = { workspace = true }
polars-lazy = { workspace = true, features = ["abs", "binary_encoding", "concat_str", "cross_join", "cum_agg", "dtype-date", "dtype-decimal", "is_in", "log", "meta", "regex", "round_series", "sign", "string_reverse", "strings", "timezones", "trigonometry"] }
polars-lazy = { workspace = true, features = ["abs", "binary_encoding", "concat_str", "cross_join", "cum_agg", "dtype-date", "dtype-decimal", "is_in", "list_eval", "log", "meta", "regex", "round_series", "sign", "string_reverse", "strings", "timezones", "trigonometry"] }
polars-plan = { workspace = true }

hex = { workspace = true }
Expand All @@ -27,14 +27,15 @@ sqlparser = { workspace = true }
polars-core = { workspace = true, features = ["fmt"] }

[features]
default = []
nightly = []
csv = ["polars-lazy/csv"]
ipc = ["polars-lazy/ipc"]
json = ["polars-lazy/json"]
default = []
binary_encoding = ["polars-lazy/binary_encoding"]
diagonal_concat = ["polars-lazy/diagonal_concat"]
dtype-decimal = ["polars-lazy/dtype-decimal"]
ipc = ["polars-lazy/ipc"]
list_eval = ["polars-lazy/list_eval"]
parquet = ["polars-lazy/parquet"]
semi_anti_join = ["polars-lazy/semi_anti_join"]
diagonal_concat = ["polars-lazy/diagonal_concat"]
binary_encoding = ["polars-lazy/binary_encoding"]
timezones = ["polars-lazy/timezones"]
nightly = []
49 changes: 40 additions & 9 deletions crates/polars-sql/src/functions.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
use polars_core::prelude::{polars_bail, polars_err, PolarsResult};
use polars_lazy::dsl::Expr;
#[cfg(feature = "list_eval")]
use polars_lazy::dsl::ListNameSpaceExtension;
use polars_plan::dsl::{coalesce, concat_str, len, when};
use polars_plan::logical_plan::LiteralValue;
#[cfg(feature = "list_eval")]
use polars_plan::prelude::col;
use polars_plan::prelude::LiteralValue::Null;
use polars_plan::prelude::{lit, StrptimeOptions};
use sqlparser::ast::{
Expand Down Expand Up @@ -520,7 +524,8 @@ pub(crate) enum PolarsSQLFunctions {
/// SQL 'array_to_string' function
/// Takes all elements of the array and joins them into one string.
/// ```sql
/// SELECT ARRAY_TO_STRING(column_1, ', ') from df;
/// SELECT ARRAY_TO_STRING(column_1, ',') from df;
/// SELECT ARRAY_TO_STRING(column_1, ',', 'n/a') from df;
/// ```
ArrayToString,
/// SQL 'array_get' function
Expand Down Expand Up @@ -846,7 +851,7 @@ impl SQLFunctionVisitor<'_> {
} else {
self.try_visit_variadic(|exprs: &[Expr]| {
match &exprs[0] {
Expr::Literal(LiteralValue::String(s)) => Ok(concat_str(&exprs[1..], s)),
Expr::Literal(LiteralValue::String(sep)) => Ok(concat_str(&exprs[1..], sep)),
_ => polars_bail!(InvalidOperation: "ConcatWS 'separator' must be a literal string; found {:?}", exprs[0]),
}
})
Expand Down Expand Up @@ -969,9 +974,23 @@ impl SQLFunctionVisitor<'_> {
ArrayMin => self.visit_unary(|e| e.list().min()),
ArrayReverse => self.visit_unary(|e| e.list().reverse()),
ArraySum => self.visit_unary(|e| e.list().sum()),
ArrayToString => self.try_visit_binary(|e, s| {
Ok(e.list().join(s))
}),
ArrayToString => match function.args.len() {
2 => self.try_visit_binary(|e, sep| { Ok(e.list().join(sep, true)) }),
#[cfg(feature = "list_eval")]
3 => self.try_visit_ternary(|e, sep, null_value| {
match null_value {
Expr::Literal(LiteralValue::String(v)) => {
Ok(if v.is_empty() {
e.list().join(sep, true)
} else {
e.list().eval(col("").fill_null(lit(v)), false).list().join(sep, false)
})
},
_ => polars_bail!(InvalidOperation: "Invalid null value for ArrayToString: {}", function.args[2]),
}
}),
_ => polars_bail!(InvalidOperation: "Invalid number of arguments for ArrayToString: {}", function.args.len()),
}
ArrayUnique => self.visit_unary(|e| e.list().unique()),
Explode => self.visit_unary(|e| e.explode()),
Udf(func_name) => self.visit_udf(&func_name)
Expand Down Expand Up @@ -1169,10 +1188,7 @@ impl SQLFunctionVisitor<'_> {
.iter()
.map(|o| {
let e = parse_sql_expr(&o.expr, self.ctx)?;
match o.asc {
Some(b) => Ok(e.sort(!b)),
None => Ok(e),
}
Ok(o.asc.map_or(e.clone(), |b| e.sort(!b)))
})
.collect::<PolarsResult<Vec<_>>>()?;
expr.over(exprs)
Expand Down Expand Up @@ -1237,6 +1253,21 @@ impl FromSQLExpr for f64 {
}
}

/// Parses a SQL boolean literal into a Rust `bool`.
///
/// Only a `SQLExpr::Value` wrapping `SQLValue::Boolean` is accepted; anything
/// else is rejected with a `ComputeError`.
impl FromSQLExpr for bool {
    fn from_sql_expr(expr: &SQLExpr, _ctx: &mut SQLContext) -> PolarsResult<Self>
    where
        Self: Sized,
    {
        match expr {
            SQLExpr::Value(SQLValue::Boolean(flag)) => Ok(*flag),
            // a literal of some other kind: report just the literal
            SQLExpr::Value(other) => {
                polars_bail!(ComputeError: "can't parse boolean {:?}", other)
            },
            // not a literal at all: report the whole expression
            _ => polars_bail!(ComputeError: "can't parse boolean {:?}", expr),
        }
    }
}

impl FromSQLExpr for String {
fn from_sql_expr(expr: &SQLExpr, _: &mut SQLContext) -> PolarsResult<Self>
where
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-sql/tests/functions_string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ fn array_to_string() {
.lazy()
.group_by([col("b")])
.agg([col("a")])
.select(&[col("b"), col("a").list().join(lit(", ")).alias("as")])
.select(&[col("b"), col("a").list().join(lit(", "), true).alias("as")])
.sort_by_exprs(vec![col("b"), col("as")], vec![false, false], false, true)
.collect()
.unwrap();
Expand Down
2 changes: 1 addition & 1 deletion crates/polars/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ list_any_all = ["polars-lazy?/list_any_all"]
list_count = ["polars-ops/list_count", "polars-lazy?/list_count"]
array_count = ["polars-ops/array_count", "polars-lazy?/array_count", "dtype-array"]
list_drop_nulls = ["polars-lazy?/list_drop_nulls"]
list_eval = ["polars-lazy?/list_eval"]
list_eval = ["polars-lazy?/list_eval", "polars-sql?/list_eval"]
list_gather = ["polars-ops/list_gather", "polars-lazy?/list_gather"]
list_sample = ["polars-lazy?/list_sample"]
list_sets = ["polars-lazy?/list_sets"]
Expand Down
9 changes: 7 additions & 2 deletions py-polars/polars/expr/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,7 @@ def contains(
item = parse_as_expression(item, str_as_lit=True)
return wrap_expr(self._pyexpr.list_contains(item))

def join(self, separator: IntoExprColumn) -> Expr:
def join(self, separator: IntoExprColumn, *, ignore_nulls: bool = True) -> Expr:
"""
Join all string items in a sublist and place a separator between them.
Expand All @@ -566,6 +566,11 @@ def join(self, separator: IntoExprColumn) -> Expr:
----------
separator
string to separate the items with
ignore_nulls
Ignore null values (default).
If set to ``False``, null values will be propagated:
if the sub-list contains any null values, the output is ``None``.
Returns
-------
Expand Down Expand Up @@ -601,7 +606,7 @@ def join(self, separator: IntoExprColumn) -> Expr:
└─────────────────┴───────────┴───────┘
"""
separator = parse_as_expression(separator, str_as_lit=True)
return wrap_expr(self._pyexpr.list_join(separator))
return wrap_expr(self._pyexpr.list_join(separator, ignore_nulls))

def arg_min(self) -> Expr:
"""
Expand Down
7 changes: 6 additions & 1 deletion py-polars/polars/series/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ def gather(
def __getitem__(self, item: int) -> Series:
return self.get(item)

def join(self, separator: IntoExprColumn) -> Series:
def join(self, separator: IntoExprColumn, *, ignore_nulls: bool = True) -> Series:
"""
Join all string items in a sublist and place a separator between them.
Expand All @@ -396,6 +396,11 @@ def join(self, separator: IntoExprColumn) -> Series:
----------
separator
string to separate the items with
ignore_nulls
Ignore null values (default).
If set to ``False``, null values will be propagated:
if the sub-list contains any null values, the output is ``None``.
Returns
-------
Expand Down
8 changes: 6 additions & 2 deletions py-polars/src/expr/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,12 @@ impl PyExpr {
self.inner.clone().list().get(index.inner).into()
}

/// Build the `list.join` expression for this expression, using `separator`
/// between items; `ignore_nulls` is passed through to the list namespace `join`.
fn list_join(&self, separator: PyExpr) -> Self {
self.inner.clone().list().join(separator.inner).into()
fn list_join(&self, separator: PyExpr, ignore_nulls: bool) -> Self {
self.inner
.clone()
.list()
.join(separator.inner, ignore_nulls)
.into()
}

fn list_len(&self) -> Self {
Expand Down
18 changes: 18 additions & 0 deletions py-polars/tests/unit/namespaces/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,24 @@ def test_list_join() -> None:
out = df.select(pl.col("a").list.join(pl.col("separator")))
assert out.to_dict(as_series=False) == {"a": ["ab&c&d", None, "g", "", None]}

# test ignore_nulls argument
df = pl.DataFrame(
{
"a": [["a", None, "b", None], None, [None, None], ["c", "d"], []],
"separator": ["-", "&", " ", "@", "/"],
}
)
# ignore nulls
out = df.select(pl.col("a").list.join("-", ignore_nulls=True))
assert out.to_dict(as_series=False) == {"a": ["a-b", None, "", "c-d", ""]}
out = df.select(pl.col("a").list.join(pl.col("separator"), ignore_nulls=True))
assert out.to_dict(as_series=False) == {"a": ["a-b", None, "", "c@d", ""]}
# propagate nulls
out = df.select(pl.col("a").list.join("-", ignore_nulls=False))
assert out.to_dict(as_series=False) == {"a": [None, None, None, "c-d", ""]}
out = df.select(pl.col("a").list.join(pl.col("separator"), ignore_nulls=False))
assert out.to_dict(as_series=False) == {"a": [None, None, None, "c@d", ""]}


def test_list_arr_empty() -> None:
df = pl.DataFrame({"cars": [[1, 2, 3], [2, 3], [4], []]})
Expand Down
Loading

0 comments on commit 198fa2d

Please sign in to comment.