Skip to content

Commit

Permalink
feat(rust, python, cli)!: Add ignore_nulls for list.join
Browse files Browse the repository at this point in the history
  • Loading branch information
reswqa committed Jan 13, 2024
1 parent b27fe94 commit 8549df6
Show file tree
Hide file tree
Showing 9 changed files with 108 additions and 32 deletions.
40 changes: 29 additions & 11 deletions crates/polars-ops/src/chunked_array/list/namespace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,21 +78,25 @@ fn cast_rhs(
pub trait ListNameSpaceImpl: AsList {
/// In case the inner dtype [`DataType::String`], the individual items will be joined into a
/// single string separated by `separator`.
fn lst_join(&self, separator: &StringChunked) -> PolarsResult<StringChunked> {
fn lst_join(
&self,
separator: &StringChunked,
ignore_nulls: bool,
) -> PolarsResult<StringChunked> {
let ca = self.as_list();
match ca.inner_dtype() {
DataType::String => match separator.len() {
1 => match separator.get(0) {
Some(separator) => self.join_literal(separator),
Some(separator) => self.join_literal(separator, ignore_nulls),
_ => Ok(StringChunked::full_null(ca.name(), ca.len())),
},
_ => self.join_many(separator),
_ => self.join_many(separator, ignore_nulls),
},
dt => polars_bail!(op = "`lst.join`", got = dt, expected = "String"),
}
}

fn join_literal(&self, separator: &str) -> PolarsResult<StringChunked> {
fn join_literal(&self, separator: &str, ignore_nulls: bool) -> PolarsResult<StringChunked> {
let ca = self.as_list();
// used to amortize heap allocs
let mut buf = String::with_capacity(128);
Expand All @@ -103,26 +107,35 @@ pub trait ListNameSpaceImpl: AsList {
);

ca.for_each_amortized(|opt_s| {
let opt_val = opt_s.map(|s| {
let opt_val = opt_s.and_then(|s| {
// make sure that we don't write values of previous iteration
buf.clear();
let ca = s.as_ref().str().unwrap();
let iter = ca.into_iter().map(|opt_v| opt_v.unwrap_or("null"));

if ca.null_count() != 0 && !ignore_nulls {
return None;
}

let iter = ca.into_iter().flatten();

for val in iter {
buf.write_str(val).unwrap();
buf.write_str(separator).unwrap();
}
// last value should not have a separator, so slice that off
// saturating sub because there might have been nothing written.
&buf[..buf.len().saturating_sub(separator.len())]
Some(&buf[..buf.len().saturating_sub(separator.len())])
});
builder.append_option(opt_val)
});
Ok(builder.finish())
}

fn join_many(&self, separator: &StringChunked) -> PolarsResult<StringChunked> {
fn join_many(
&self,
separator: &StringChunked,
ignore_nulls: bool,
) -> PolarsResult<StringChunked> {
let ca = self.as_list();
// used to amortize heap allocs
let mut buf = String::with_capacity(128);
Expand All @@ -134,19 +147,24 @@ pub trait ListNameSpaceImpl: AsList {
.zip(separator)
.for_each(|(opt_s, opt_sep)| match opt_sep {
Some(separator) => {
let opt_val = opt_s.map(|s| {
let opt_val = opt_s.and_then(|s| {
// make sure that we don't write values of previous iteration
buf.clear();
let ca = s.as_ref().str().unwrap();
let iter = ca.into_iter().map(|opt_v| opt_v.unwrap_or("null"));

if ca.null_count() != 0 && !ignore_nulls {
return None;
}

let iter = ca.into_iter().flatten();

for val in iter {
buf.write_str(val).unwrap();
buf.write_str(separator).unwrap();
}
// last value should not have a separator, so slice that off
// saturating sub because there might have been nothing written.
&buf[..buf.len().saturating_sub(separator.len())]
Some(&buf[..buf.len().saturating_sub(separator.len())])
});
builder.append_option(opt_val)
},
Expand Down
12 changes: 6 additions & 6 deletions crates/polars-plan/src/dsl/function_expr/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ pub enum ListFunction {
Any,
#[cfg(feature = "list_any_all")]
All,
Join,
Join(bool),
#[cfg(feature = "dtype-array")]
ToArray(usize),
}
Expand Down Expand Up @@ -88,7 +88,7 @@ impl ListFunction {
Any => mapper.with_dtype(DataType::Boolean),
#[cfg(feature = "list_any_all")]
All => mapper.with_dtype(DataType::Boolean),
Join => mapper.with_dtype(DataType::String),
Join(_) => mapper.with_dtype(DataType::String),
#[cfg(feature = "dtype-array")]
ToArray(width) => mapper.try_map_dtype(|dt| map_list_dtype_to_array_dtype(dt, *width)),
}
Expand Down Expand Up @@ -153,7 +153,7 @@ impl Display for ListFunction {
Any => "any",
#[cfg(feature = "list_any_all")]
All => "all",
Join => "join",
Join(_) => "join",
#[cfg(feature = "dtype-array")]
ToArray(_) => "to_array",
};
Expand Down Expand Up @@ -208,7 +208,7 @@ impl From<ListFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
Any => map!(lst_any),
#[cfg(feature = "list_any_all")]
All => map!(lst_all),
Join => map_as_slice!(join),
Join(ignore_nulls) => map_as_slice!(join, ignore_nulls),
#[cfg(feature = "dtype-array")]
ToArray(width) => map!(to_array, width),
}
Expand Down Expand Up @@ -553,10 +553,10 @@ pub(super) fn lst_all(s: &Series) -> PolarsResult<Series> {
s.list()?.lst_all()
}

pub(super) fn join(s: &[Series]) -> PolarsResult<Series> {
pub(super) fn join(s: &[Series], ignore_nulls: bool) -> PolarsResult<Series> {
let ca = s[0].list()?;
let separator = s[1].str()?;
Ok(ca.lst_join(separator)?.into_series())
Ok(ca.lst_join(separator, ignore_nulls)?.into_series())
}

#[cfg(feature = "dtype-array")]
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-plan/src/dsl/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,9 @@ impl ListNameSpace {
/// Join all string items in a sublist and place a separator between them.
/// # Error
/// This errors if inner type of list `!= DataType::String`.
pub fn join(self, separator: Expr) -> Expr {
pub fn join(self, separator: Expr, ignore_nulls: bool) -> Expr {
self.0.map_many_private(
FunctionExpr::ListExpr(ListFunction::Join),
FunctionExpr::ListExpr(ListFunction::Join(ignore_nulls)),
&[separator],
false,
false,
Expand Down
40 changes: 33 additions & 7 deletions crates/polars-sql/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -519,9 +519,14 @@ pub(crate) enum PolarsSQLFunctions {
Explode,
/// SQL 'array_to_string' function
/// Takes all elements of the array and joins them into one string.
/// It will ignore null values by default.
/// ```sql
/// SELECT ARRAY_TO_STRING(column_1, ', ') from df;
/// ```
/// Set the optional third argument (`ignore_nulls`) to `false` to propagate null values,
/// so that any array containing a null element yields a null result.
/// ```sql
/// SELECT ARRAY_TO_STRING(column_1, ', ', false) from df;
/// ```
ArrayToString,
/// SQL 'array_get' function
/// Returns the value at the given index in the array.
Expand Down Expand Up @@ -969,9 +974,15 @@ impl SQLFunctionVisitor<'_> {
ArrayMin => self.visit_unary(|e| e.list().min()),
ArrayReverse => self.visit_unary(|e| e.list().reverse()),
ArraySum => self.visit_unary(|e| e.list().sum()),
ArrayToString => self.try_visit_binary(|e, s| {
Ok(e.list().join(s))
}),
ArrayToString => match function.args.len(){
2 => self.try_visit_binary(|e, s| {
Ok(e.list().join(s, true))
}),
3 => self.try_visit_ternary(|e, s, ignore_nulls| {
Ok(e.list().join(s, ignore_nulls))
}),
_ => polars_bail!(InvalidOperation: "Invalid number of arguments for ArrayToString: {}", function.args.len()),
}
ArrayUnique => self.visit_unary(|e| e.list().unique()),
Explode => self.visit_unary(|e| e.explode()),
Udf(func_name) => self.visit_udf(&func_name)
Expand Down Expand Up @@ -1108,17 +1119,17 @@ impl SQLFunctionVisitor<'_> {
f(&expr_args)
}

fn try_visit_ternary<Arg: FromSQLExpr>(
fn try_visit_ternary<Arg1: FromSQLExpr, Arg2: FromSQLExpr>(
&mut self,
f: impl Fn(Expr, Arg, Arg) -> PolarsResult<Expr>,
f: impl Fn(Expr, Arg1, Arg2) -> PolarsResult<Expr>,
) -> PolarsResult<Expr> {
let args = extract_args(self.func);
match args.as_slice() {
[FunctionArgExpr::Expr(sql_expr1), FunctionArgExpr::Expr(sql_expr2), FunctionArgExpr::Expr(sql_expr3)] =>
{
let expr1 = parse_sql_expr(sql_expr1, self.ctx)?;
let expr2 = Arg::from_sql_expr(sql_expr2, self.ctx)?;
let expr3 = Arg::from_sql_expr(sql_expr3, self.ctx)?;
let expr2 = Arg1::from_sql_expr(sql_expr2, self.ctx)?;
let expr3 = Arg2::from_sql_expr(sql_expr3, self.ctx)?;
f(expr1, expr2, expr3)
},
_ => self.not_supported_error(),
Expand Down Expand Up @@ -1237,6 +1248,21 @@ impl FromSQLExpr for f64 {
}
}

/// Parses a SQL boolean literal into a Rust `bool`.
///
/// Only `SQLExpr::Value(SQLValue::Boolean(_))` is accepted; any other
/// expression is rejected with a `ComputeError`.
impl FromSQLExpr for bool {
    fn from_sql_expr(expr: &SQLExpr, _ctx: &mut SQLContext) -> PolarsResult<Self>
    where
        Self: Sized,
    {
        match expr {
            SQLExpr::Value(SQLValue::Boolean(flag)) => Ok(*flag),
            // A literal of some other kind: the message shows only the inner value.
            SQLExpr::Value(other) => {
                polars_bail!(ComputeError: "can't parse boolean {:?}", other)
            },
            // Not a literal at all: the message shows the whole expression.
            _ => polars_bail!(ComputeError: "can't parse boolean {:?}", expr),
        }
    }
}

impl FromSQLExpr for String {
fn from_sql_expr(expr: &SQLExpr, _: &mut SQLContext) -> PolarsResult<Self>
where
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-sql/tests/functions_string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ fn array_to_string() {
.lazy()
.group_by([col("b")])
.agg([col("a")])
.select(&[col("b"), col("a").list().join(lit(", ")).alias("as")])
.select(&[col("b"), col("a").list().join(lit(", "), true).alias("as")])
.sort_by_exprs(vec![col("b"), col("as")], vec![false, false], false, true)
.collect()
.unwrap();
Expand Down
9 changes: 7 additions & 2 deletions py-polars/polars/expr/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,7 @@ def contains(
item = parse_as_expression(item, str_as_lit=True)
return wrap_expr(self._pyexpr.list_contains(item))

def join(self, separator: IntoExprColumn) -> Expr:
def join(self, separator: IntoExprColumn, *, ignore_nulls: bool = True) -> Expr:
"""
Join all string items in a sublist and place a separator between them.
Expand All @@ -564,6 +564,11 @@ def join(self, separator: IntoExprColumn) -> Expr:
----------
separator
string to separate the items with
ignore_nulls
Ignore null values (default).
If set to ``False``, null values will be propagated:
if a sublist contains any null values, the output for that sublist is ``None``.
Returns
-------
Expand Down Expand Up @@ -599,7 +604,7 @@ def join(self, separator: IntoExprColumn) -> Expr:
└─────────────────┴───────────┴───────┘
"""
separator = parse_as_expression(separator, str_as_lit=True)
return wrap_expr(self._pyexpr.list_join(separator))
return wrap_expr(self._pyexpr.list_join(separator, ignore_nulls))

def arg_min(self) -> Expr:
"""
Expand Down
7 changes: 6 additions & 1 deletion py-polars/polars/series/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ def gather(
def __getitem__(self, item: int) -> Series:
return self.get(item)

def join(self, separator: IntoExprColumn) -> Series:
def join(self, separator: IntoExprColumn, *, ignore_nulls: bool = True) -> Series:
"""
Join all string items in a sublist and place a separator between them.
Expand All @@ -393,6 +393,11 @@ def join(self, separator: IntoExprColumn) -> Series:
----------
separator
string to separate the items with
ignore_nulls
Ignore null values (default).
If set to ``False``, null values will be propagated:
if a sublist contains any null values, the output for that sublist is ``None``.
Returns
-------
Expand Down
8 changes: 6 additions & 2 deletions py-polars/src/expr/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,12 @@ impl PyExpr {
self.inner.clone().list().get(index.inner).into()
}

fn list_join(&self, separator: PyExpr) -> Self {
self.inner.clone().list().join(separator.inner).into()
fn list_join(&self, separator: PyExpr, ignore_nulls: bool) -> Self {
self.inner
.clone()
.list()
.join(separator.inner, ignore_nulls)
.into()
}

fn list_len(&self) -> Self {
Expand Down
18 changes: 18 additions & 0 deletions py-polars/tests/unit/namespaces/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,24 @@ def test_list_join() -> None:
out = df.select(pl.col("a").list.join(pl.col("separator")))
assert out.to_dict(as_series=False) == {"a": ["ab&c&d", None, "g", "", None]}

# test ignore_nulls argument
df = pl.DataFrame(
{
"a": [["a", None, "b", None], None, [None, None], ["c", "d"], []],
"separator": ["-", "&", " ", "@", "/"],
}
)
# ignore nulls
out = df.select(pl.col("a").list.join("-", ignore_nulls=True))
assert out.to_dict(as_series=False) == {"a": ["a-b", None, "", "c-d", ""]}
out = df.select(pl.col("a").list.join(pl.col("separator"), ignore_nulls=True))
assert out.to_dict(as_series=False) == {"a": ["a-b", None, "", "c@d", ""]}
# propagate nulls
out = df.select(pl.col("a").list.join("-", ignore_nulls=False))
assert out.to_dict(as_series=False) == {"a": [None, None, None, "c-d", ""]}
out = df.select(pl.col("a").list.join(pl.col("separator"), ignore_nulls=False))
assert out.to_dict(as_series=False) == {"a": [None, None, None, "c@d", ""]}


def test_list_arr_empty() -> None:
df = pl.DataFrame({"cars": [[1, 2, 3], [2, 3], [4], []]})
Expand Down

0 comments on commit 8549df6

Please sign in to comment.