Skip to content

Commit

Permalink
feat: Add ignore_nulls for pl.concat_str (pola-rs#13877)
Browse files Browse the repository at this point in the history
  • Loading branch information
reswqa authored and r-brink committed Jan 24, 2024
1 parent 6319b12 commit 1b70f68
Show file tree
Hide file tree
Showing 14 changed files with 185 additions and 85 deletions.
2 changes: 1 addition & 1 deletion crates/polars-lazy/src/tests/projection_queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ fn concat_str_regex_expansion() -> PolarsResult<()> {
]?
.lazy();
let out = df
.select([concat_str([col(r"^b_a_\d$")], ";").alias("concatenated")])
.select([concat_str([col(r"^b_a_\d$")], ";", false).alias("concatenated")])
.collect()?;
let s = out.column("concatenated")?;
assert_eq!(s, &Series::new("concatenated", ["a--;;", ";b--;", ";;c--"]));
Expand Down
37 changes: 27 additions & 10 deletions crates/polars-ops/src/chunked_array/strings/concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,21 @@ enum ColumnIter<I, T> {
/// Horizontally concatenate all strings.
///
/// Each array should have length 1 or a length equal to the maximum length.
pub fn hor_str_concat(cas: &[&StringChunked], delimiter: &str) -> PolarsResult<StringChunked> {
pub fn hor_str_concat(
cas: &[&StringChunked],
delimiter: &str,
ignore_nulls: bool,
) -> PolarsResult<StringChunked> {
if cas.is_empty() {
return Ok(StringChunked::full_null("", 0));
}
if cas.len() == 1 {
return Ok(cas[0].clone());
let ca = cas[0];
return if !ignore_nulls || ca.null_count() == 0 {
Ok(ca.clone())
} else {
Ok(ca.apply_generic(|val| Some(val.unwrap_or(""))))
};
}

// Calculate the post-broadcast length and ensure everything is consistent.
Expand Down Expand Up @@ -93,23 +102,31 @@ pub fn hor_str_concat(cas: &[&StringChunked], delimiter: &str) -> PolarsResult<S
let mut buf = String::with_capacity(1024);
for _row in 0..len {
let mut has_null = false;
for (i, col) in cols.iter_mut().enumerate() {
if i > 0 {
buf.push_str(delimiter);
}

let mut found_not_null_value = false;
for col in cols.iter_mut() {
let val = match col {
ColumnIter::Iter(i) => i.next().unwrap(),
ColumnIter::Broadcast(s) => *s,
};

if has_null && !ignore_nulls {
// We know that the result must be null, but we can't just break out of the loop,
// because all cols iterator has to be moved correctly.
continue;
}

if let Some(s) = val {
if found_not_null_value {
buf.push_str(delimiter);
}
buf.push_str(s);
found_not_null_value = true;
} else {
has_null = true;
}
}

if has_null {
if !ignore_nulls && has_null {
builder.append_null();
} else {
builder.append_value(&buf)
Expand Down Expand Up @@ -139,11 +156,11 @@ mod test {
let a = StringChunked::new("a", &["foo", "bar"]);
let b = StringChunked::new("b", &["spam", "ham"]);

let out = hor_str_concat(&[&a, &b], "_").unwrap();
let out = hor_str_concat(&[&a, &b], "_", true).unwrap();
assert_eq!(Vec::from(&out), &[Some("foo_spam"), Some("bar_ham")]);

let c = StringChunked::new("b", &["literal"]);
let out = hor_str_concat(&[&a, &b, &c], "_").unwrap();
let out = hor_str_concat(&[&a, &b, &c], "_", true).unwrap();
assert_eq!(
Vec::from(&out),
&[Some("foo_spam_literal"), Some("bar_ham_literal")]
Expand Down
22 changes: 16 additions & 6 deletions crates/polars-plan/src/dsl/function_expr/strings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ use crate::{map, map_as_slice};
#[derive(Clone, PartialEq, Debug, Eq, Hash)]
pub enum StringFunction {
#[cfg(feature = "concat_str")]
ConcatHorizontal(String),
ConcatHorizontal {
delimiter: String,
ignore_nulls: bool,
},
#[cfg(feature = "concat_str")]
ConcatVertical {
delimiter: String,
Expand Down Expand Up @@ -124,7 +127,7 @@ impl StringFunction {
use StringFunction::*;
match self {
#[cfg(feature = "concat_str")]
ConcatVertical { .. } | ConcatHorizontal(_) => mapper.with_dtype(DataType::String),
ConcatVertical { .. } | ConcatHorizontal { .. } => mapper.with_dtype(DataType::String),
#[cfg(feature = "regex")]
Contains { .. } => mapper.with_dtype(DataType::Boolean),
CountMatches(_) => mapper.with_dtype(DataType::UInt32),
Expand Down Expand Up @@ -194,7 +197,7 @@ impl Display for StringFunction {
EndsWith { .. } => "ends_with",
Extract(_) => "extract",
#[cfg(feature = "concat_str")]
ConcatHorizontal(_) => "concat_horizontal",
ConcatHorizontal { .. } => "concat_horizontal",
#[cfg(feature = "concat_str")]
ConcatVertical { .. } => "concat_vertical",
Explode => "explode",
Expand Down Expand Up @@ -318,7 +321,10 @@ impl From<StringFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
ignore_nulls,
} => map!(strings::concat, &delimiter, ignore_nulls),
#[cfg(feature = "concat_str")]
ConcatHorizontal(delimiter) => map_as_slice!(strings::concat_hor, &delimiter),
ConcatHorizontal {
delimiter,
ignore_nulls,
} => map_as_slice!(strings::concat_hor, &delimiter, ignore_nulls),
#[cfg(feature = "regex")]
Replace { n, literal } => map_as_slice!(strings::replace, literal, n),
#[cfg(feature = "string_reverse")]
Expand Down Expand Up @@ -696,13 +702,17 @@ pub(super) fn concat(s: &Series, delimiter: &str, ignore_nulls: bool) -> PolarsR
}

#[cfg(feature = "concat_str")]
pub(super) fn concat_hor(series: &[Series], delimiter: &str) -> PolarsResult<Series> {
pub(super) fn concat_hor(
series: &[Series],
delimiter: &str,
ignore_nulls: bool,
) -> PolarsResult<Series> {
let str_series: Vec<_> = series
.iter()
.map(|s| s.cast(&DataType::String))
.collect::<PolarsResult<_>>()?;
let cas: Vec<_> = str_series.iter().map(|s| s.str().unwrap()).collect();
Ok(polars_ops::chunked_array::hor_str_concat(&cas, delimiter)?.into_series())
Ok(polars_ops::chunked_array::hor_str_concat(&cas, delimiter, ignore_nulls)?.into_series())
}

impl From<StringFunction> for FunctionExpr {
Expand Down
10 changes: 7 additions & 3 deletions crates/polars-plan/src/dsl/functions/concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@ use super::*;

#[cfg(all(feature = "concat_str", feature = "strings"))]
/// Horizontally concat string columns in linear time
pub fn concat_str<E: AsRef<[Expr]>>(s: E, separator: &str) -> Expr {
pub fn concat_str<E: AsRef<[Expr]>>(s: E, separator: &str, ignore_nulls: bool) -> Expr {
let input = s.as_ref().to_vec();
let separator = separator.to_string();

Expr::Function {
input,
function: StringFunction::ConcatHorizontal(separator).into(),
function: StringFunction::ConcatHorizontal {
delimiter: separator,
ignore_nulls,
}
.into(),
options: FunctionOptions {
collect_groups: ApplyOptions::ElementWise,
input_wildcard_expansion: true,
Expand Down Expand Up @@ -45,7 +49,7 @@ pub fn format_str<E: AsRef<[Expr]>>(format: &str, args: E) -> PolarsResult<Expr>
}
}

Ok(concat_str(exprs, ""))
Ok(concat_str(exprs, "", false))
}

/// Concat lists entries.
Expand Down
34 changes: 25 additions & 9 deletions crates/polars-plan/src/logical_plan/optimizer/simplify_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -284,17 +284,23 @@ fn string_addition_to_linear_concat(
AExpr::Function {
input: input_left,
function:
ref
fun_l @ FunctionExpr::StringExpr(StringFunction::ConcatHorizontal(sep_l)),
ref fun_l @ FunctionExpr::StringExpr(StringFunction::ConcatHorizontal {
delimiter: sep_l,
ignore_nulls: ignore_nulls_l,
}),
options,
},
AExpr::Function {
input: input_right,
function: FunctionExpr::StringExpr(StringFunction::ConcatHorizontal(sep_r)),
function:
FunctionExpr::StringExpr(StringFunction::ConcatHorizontal {
delimiter: sep_r,
ignore_nulls: ignore_nulls_r,
}),
..
},
) => {
if sep_l.is_empty() && sep_r.is_empty() {
if sep_l.is_empty() && sep_r.is_empty() && ignore_nulls_l == ignore_nulls_r {
let mut input = Vec::with_capacity(input_left.len() + input_right.len());
input.extend_from_slice(input_left);
input.extend_from_slice(input_right);
Expand All @@ -312,12 +318,15 @@ fn string_addition_to_linear_concat(
AExpr::Function {
input,
function:
ref fun @ FunctionExpr::StringExpr(StringFunction::ConcatHorizontal(sep)),
ref fun @ FunctionExpr::StringExpr(StringFunction::ConcatHorizontal {
delimiter: sep,
ignore_nulls,
}),
options,
},
_,
) => {
if sep.is_empty() {
if sep.is_empty() && !ignore_nulls {
let mut input = input.clone();
input.push(right_ae);
Some(AExpr::Function {
Expand All @@ -335,11 +344,14 @@ fn string_addition_to_linear_concat(
AExpr::Function {
input: input_right,
function:
ref fun @ FunctionExpr::StringExpr(StringFunction::ConcatHorizontal(sep)),
ref fun @ FunctionExpr::StringExpr(StringFunction::ConcatHorizontal {
delimiter: sep,
ignore_nulls,
}),
options,
},
) => {
if sep.is_empty() {
if sep.is_empty() && !ignore_nulls {
let mut input = Vec::with_capacity(1 + input_right.len());
input.push(left_ae);
input.extend_from_slice(input_right);
Expand All @@ -354,7 +366,11 @@ fn string_addition_to_linear_concat(
},
_ => Some(AExpr::Function {
input: vec![left_ae, right_ae],
function: StringFunction::ConcatHorizontal("".to_string()).into(),
function: StringFunction::ConcatHorizontal {
delimiter: "".to_string(),
ignore_nulls: false,
}
.into(),
options: FunctionOptions {
collect_groups: ApplyOptions::ElementWise,
input_wildcard_expansion: true,
Expand Down
31 changes: 20 additions & 11 deletions crates/polars-plan/src/logical_plan/optimizer/simplify_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,18 @@ pub(super) fn optimize_functions(
},
// flatten nested concat_str calls
#[cfg(all(feature = "strings", feature = "concat_str"))]
function @ FunctionExpr::StringExpr(StringFunction::ConcatHorizontal(sep))
if sep.is_empty() =>
{
function @ FunctionExpr::StringExpr(StringFunction::ConcatHorizontal {
delimiter: sep,
ignore_nulls,
}) if sep.is_empty() => {
if input
.iter()
.any(|node| is_string_concat(expr_arena.get(*node)))
.any(|node| is_string_concat(expr_arena.get(*node), *ignore_nulls))
{
let mut new_inputs = Vec::with_capacity(input.len() * 2);

for node in input {
match get_string_concat_input(*node, expr_arena) {
match get_string_concat_input(*node, expr_arena, *ignore_nulls) {
Some(inp) => new_inputs.extend_from_slice(inp),
None => new_inputs.push(*node),
}
Expand Down Expand Up @@ -89,23 +90,31 @@ pub(super) fn optimize_functions(
}

#[cfg(all(feature = "strings", feature = "concat_str"))]
fn is_string_concat(ae: &AExpr) -> bool {
fn is_string_concat(ae: &AExpr, ignore_nulls: bool) -> bool {
matches!(ae, AExpr::Function {
function:FunctionExpr::StringExpr(
StringFunction::ConcatHorizontal(sep),
StringFunction::ConcatHorizontal{delimiter: sep, ignore_nulls: func_inore_nulls},
),
..
} if sep.is_empty())
} if sep.is_empty() && *func_inore_nulls == ignore_nulls)
}

#[cfg(all(feature = "strings", feature = "concat_str"))]
fn get_string_concat_input(node: Node, expr_arena: &Arena<AExpr>) -> Option<&[Node]> {
fn get_string_concat_input(
node: Node,
expr_arena: &Arena<AExpr>,
ignore_nulls: bool,
) -> Option<&[Node]> {
match expr_arena.get(node) {
AExpr::Function {
input,
function: FunctionExpr::StringExpr(StringFunction::ConcatHorizontal(sep)),
function:
FunctionExpr::StringExpr(StringFunction::ConcatHorizontal {
delimiter: sep,
ignore_nulls: func_ignore_nulls,
}),
..
} if sep.is_empty() => Some(input),
} if sep.is_empty() && *func_ignore_nulls == ignore_nulls => Some(input),
_ => None,
}
}
4 changes: 2 additions & 2 deletions crates/polars-sql/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -839,14 +839,14 @@ impl SQLFunctionVisitor<'_> {
Concat => if function.args.is_empty() {
polars_bail!(InvalidOperation: "Invalid number of arguments for Concat: 0");
} else {
self.visit_variadic(|exprs: &[Expr]| concat_str(exprs, ""))
self.visit_variadic(|exprs: &[Expr]| concat_str(exprs, "", true))
},
ConcatWS => if function.args.len() < 2 {
polars_bail!(InvalidOperation: "Invalid number of arguments for ConcatWS: {}", function.args.len());
} else {
self.try_visit_variadic(|exprs: &[Expr]| {
match &exprs[0] {
Expr::Literal(LiteralValue::String(s)) => Ok(concat_str(&exprs[1..], s)),
Expr::Literal(LiteralValue::String(s)) => Ok(concat_str(&exprs[1..], s, true)),
_ => polars_bail!(InvalidOperation: "ConcatWS 'separator' must be a literal string; found {:?}", exprs[0]),
}
})
Expand Down
2 changes: 1 addition & 1 deletion docs/src/rust/user-guide/expressions/folds.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {

let out = df
.lazy()
.select([concat_str([col("a"), col("b")], "")])
.select([concat_str([col("a"), col("b")], "", false)])
.collect()?;
println!("{:?}", out);
// --8<-- [end:string]
Expand Down
8 changes: 7 additions & 1 deletion py-polars/polars/functions/as_datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,7 @@ def concat_str(
exprs: IntoExpr | Iterable[IntoExpr],
*more_exprs: IntoExpr,
separator: str = "",
ignore_nulls: bool = False,
) -> Expr:
"""
Horizontally concatenate columns into a single string column.
Expand All @@ -492,6 +493,11 @@ def concat_str(
positional arguments.
separator
String that will be used to separate the values of each column.
ignore_nulls
Ignore null values (default).
If set to ``False``, null values will be propagated.
if the row contains any null values, the output is ``None``.
Examples
--------
Expand Down Expand Up @@ -524,7 +530,7 @@ def concat_str(
└─────┴──────┴──────┴───────────────┘
"""
exprs = parse_as_list_of_expressions(exprs, *more_exprs)
return wrap_expr(plr.concat_str(exprs, separator))
return wrap_expr(plr.concat_str(exprs, separator, ignore_nulls))


def format(f_string: str, *args: Expr | str) -> Expr:
Expand Down
4 changes: 2 additions & 2 deletions py-polars/src/functions/lazy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,9 +177,9 @@ pub fn concat_list(s: Vec<PyExpr>) -> PyResult<PyExpr> {
}

#[pyfunction]
pub fn concat_str(s: Vec<PyExpr>, separator: &str) -> PyExpr {
pub fn concat_str(s: Vec<PyExpr>, separator: &str, ignore_nulls: bool) -> PyExpr {
let s = s.into_iter().map(|e| e.inner).collect::<Vec<_>>();
dsl::concat_str(s, separator).into()
dsl::concat_str(s, separator, ignore_nulls).into()
}

#[pyfunction]
Expand Down
Loading

0 comments on commit 1b70f68

Please sign in to comment.