diff --git a/crates/polars-core/src/chunked_array/ops/arity.rs b/crates/polars-core/src/chunked_array/ops/arity.rs index cf630903fbcb..4e5e8e589f08 100644 --- a/crates/polars-core/src/chunked_array/ops/arity.rs +++ b/crates/polars-core/src/chunked_array/ops/arity.rs @@ -623,9 +623,7 @@ where match (lhs.len(), rhs.len()) { (1, _) => { let a = unsafe { lhs.get_unchecked(0) }; - let mut out = unary_elementwise(rhs, |b| op(a.clone(), b)); - out.rename(lhs.name()); - out + unary_elementwise(rhs, |b| op(a.clone(), b)).with_name(lhs.name()) }, (_, 1) => { let b = unsafe { rhs.get_unchecked(0) }; @@ -650,9 +648,7 @@ where match (lhs.len(), rhs.len()) { (1, _) => { let a = unsafe { lhs.get_unchecked(0) }; - let mut out = try_unary_elementwise(rhs, |b| op(a.clone(), b))?; - out.rename(lhs.name()); - Ok(out) + Ok(try_unary_elementwise(rhs, |b| op(a.clone(), b))?.with_name(lhs.name())) }, (_, 1) => { let b = unsafe { rhs.get_unchecked(0) }; @@ -686,9 +682,7 @@ where match (lhs.len(), rhs.len()) { (1, _) => { let a = unsafe { lhs.value_unchecked(0) }; - let mut out = unary_elementwise_values(rhs, |b| op(a.clone(), b)); - out.rename(lhs.name()); - out + unary_elementwise_values(rhs, |b| op(a.clone(), b)).with_name(lhs.name()) }, (_, 1) => { let b = unsafe { rhs.value_unchecked(0) }; @@ -722,9 +716,7 @@ where match (lhs.len(), rhs.len()) { (1, _) => { let a = unsafe { lhs.value_unchecked(0) }; - let mut out = try_unary_elementwise_values(rhs, |b| op(a.clone(), b))?; - out.rename(lhs.name()); - Ok(out) + Ok(try_unary_elementwise_values(rhs, |b| op(a.clone(), b))?.with_name(lhs.name())) }, (_, 1) => { let b = unsafe { rhs.value_unchecked(0) }; diff --git a/crates/polars-ops/src/chunked_array/strings/namespace.rs b/crates/polars-ops/src/chunked_array/strings/namespace.rs index 35d3fff667fe..c1309658e138 100644 --- a/crates/polars-ops/src/chunked_array/strings/namespace.rs +++ b/crates/polars-ops/src/chunked_array/strings/namespace.rs @@ -605,14 +605,15 @@ pub trait StringNameSpaceImpl: AsString { /// Slice the string values. /// - /// Determines a substring starting from `start` and with optional length `length` of each of the elements in `array`. - /// `start` can be negative, in which case the start counts from the end of the string. - fn str_slice(&self, start: i64, length: Option) -> StringChunked { - let ca = self.as_string(); - let iter = ca - .downcast_iter() - .map(|c| substring::utf8_substring(c, start, &length)); - StringChunked::from_chunk_iter_like(ca, iter) + /// Determines a substring starting from `offset` and with length `length` of each of the elements in `array`. + /// `offset` can be negative, in which case the start counts from the end of the string. + fn str_slice(&self, offset: &Series, length: &Series) -> PolarsResult { + let ca = self.as_string(); + let offset = offset.cast(&DataType::Int64)?; + // We strict cast, otherwise negative value will be treated as a valid length. + let length = length.strict_cast(&DataType::UInt64)?; + + Ok(substring::substring(ca, offset.i64()?, length.u64()?)) } } diff --git a/crates/polars-ops/src/chunked_array/strings/substring.rs b/crates/polars-ops/src/chunked_array/strings/substring.rs index e485e25dd216..690567396fb8 100644 --- a/crates/polars-ops/src/chunked_array/strings/substring.rs +++ b/crates/polars-ops/src/chunked_array/strings/substring.rs @@ -1,51 +1,117 @@ -use arrow::array::Utf8Array; +use polars_core::prelude::arity::{binary_elementwise, ternary_elementwise, unary_elementwise}; +use polars_core::prelude::{Int64Chunked, StringChunked, UInt64Chunked}; -/// Returns a Utf8Array with a substring starting from `start` and with optional length `length` of each of the elements in `array`. -/// `start` can be negative, in which case the start counts from the end of the string. -pub(super) fn utf8_substring( - array: &Utf8Array, - start: i64, - length: &Option, -) -> Utf8Array { - let length = length.map(|v| v as usize); +fn substring_ternary( + opt_str_val: Option<&str>, + opt_offset: Option, + opt_length: Option, +) -> Option<&str> { + match (opt_str_val, opt_offset) { + (Some(str_val), Some(offset)) => { + // If `offset` is negative, it counts from the end of the string. + let offset = if offset >= 0 { + offset as usize + } else { + let offset = (0i64 - offset) as usize; + str_val + .char_indices() + .rev() + .nth(offset) + .map(|(idx, _)| idx + 1) + .unwrap_or(0) + }; - let iter = array.values_iter().map(|str_val| { - // compute where we should start slicing this entry. - let start = if start >= 0 { - start as usize - } else { - let start = (0i64 - start) as usize; - str_val - .char_indices() - .rev() - .nth(start) - .map(|(idx, _)| idx + 1) - .unwrap_or(0) - }; + let mut iter_chars = str_val.char_indices(); + if let Some((offset_idx, _)) = iter_chars.nth(offset) { + let len_end = str_val.len() - offset_idx; - let mut iter_chars = str_val.char_indices(); - if let Some((start_idx, _)) = iter_chars.nth(start) { - // length of the str - let len_end = str_val.len() - start_idx; + // Slice to end of str if no length given. + let length = if let Some(length) = opt_length { + length as usize + } else { + len_end + }; - // length to slice - let length = length.unwrap_or(len_end); + if length == 0 { + return Some(""); + } - if length == 0 { - return ""; - } - // compute - let end_idx = iter_chars - .nth(length.saturating_sub(1)) - .map(|(idx, _)| idx) - .unwrap_or(str_val.len()); + let end_idx = iter_chars + .nth(length.saturating_sub(1)) + .map(|(idx, _)| idx) + .unwrap_or(str_val.len()); - &str_val[start_idx..end_idx] - } else { - "" - } - }); + Some(&str_val[offset_idx..end_idx]) + } else { + Some("") + } + }, + _ => None, + } +} - let new = Utf8Array::::from_trusted_len_values_iter(iter); - new.with_validity(array.validity().cloned()) +pub(super) fn substring( + ca: &StringChunked, + offset: &Int64Chunked, + length: &UInt64Chunked, +) -> StringChunked { + match (ca.len(), offset.len(), length.len()) { + (1, 1, _) => { + // SAFETY: index `0` is in bound. + let str_val = unsafe { ca.get_unchecked(0) }; + // SAFETY: index `0` is in bound. + let offset = unsafe { offset.get_unchecked(0) }; + unary_elementwise(length, |length| substring_ternary(str_val, offset, length)) + .with_name(ca.name()) + }, + (_, 1, 1) => { + // SAFETY: index `0` is in bound. + let offset = unsafe { offset.get_unchecked(0) }; + // SAFETY: index `0` is in bound. + let length = unsafe { length.get_unchecked(0) }; + unary_elementwise(ca, |str_val| substring_ternary(str_val, offset, length)) + }, + (1, _, 1) => { + // SAFETY: index `0` is in bound. + let str_val = unsafe { ca.get_unchecked(0) }; + // SAFETY: index `0` is in bound. + let length = unsafe { length.get_unchecked(0) }; + unary_elementwise(offset, |offset| substring_ternary(str_val, offset, length)) + .with_name(ca.name()) + }, + (1, len_b, len_c) if len_b == len_c => { + // SAFETY: index `0` is in bound. + let str_val = unsafe { ca.get_unchecked(0) }; + binary_elementwise(offset, length, |offset, length| { + substring_ternary(str_val, offset, length) + }) + }, + (len_a, 1, len_c) if len_a == len_c => { + fn infer FnMut(Option<&'a str>, Option) -> Option<&'a str>>(f: F) -> F where + { + f + } + // SAFETY: index `0` is in bound. + let offset = unsafe { offset.get_unchecked(0) }; + binary_elementwise( + ca, + length, + infer(|str_val, length| substring_ternary(str_val, offset, length)), + ) + }, + (len_a, len_b, 1) if len_a == len_b => { + fn infer FnMut(Option<&'a str>, Option) -> Option<&'a str>>(f: F) -> F where + { + f + } + // SAFETY: index `0` is in bound. + let length = unsafe { length.get_unchecked(0) }; + binary_elementwise( + ca, + offset, + infer(|str_val, offset| substring_ternary(str_val, offset, length)), + ) + }, + _ => ternary_elementwise(ca, offset, length, substring_ternary), + } } diff --git a/crates/polars-plan/src/dsl/function_expr/strings.rs b/crates/polars-plan/src/dsl/function_expr/strings.rs index 0b63c717c97c..2b150398e425 100644 --- a/crates/polars-plan/src/dsl/function_expr/strings.rs +++ b/crates/polars-plan/src/dsl/function_expr/strings.rs @@ -77,7 +77,7 @@ pub enum StringFunction { length: usize, fill_char: char, }, - Slice(i64, Option), + Slice, #[cfg(feature = "string_encoding")] HexEncode, #[cfg(feature = "binary_encoding")] @@ -160,14 +160,8 @@ impl StringFunction { Base64Encode => mapper.with_same_dtype(), #[cfg(feature = "binary_encoding")] Base64Decode(_) => mapper.with_dtype(DataType::Binary), - Uppercase - | Lowercase - | StripChars - | StripCharsStart - | StripCharsEnd - | StripPrefix - | StripSuffix - | Slice(_, _) => mapper.with_same_dtype(), + Uppercase | Lowercase | StripChars | StripCharsStart | StripCharsEnd | StripPrefix + | StripSuffix | Slice => mapper.with_same_dtype(), #[cfg(feature = "string_pad")] PadStart { .. } | PadEnd { .. } | ZFill { .. } => mapper.with_same_dtype(), #[cfg(feature = "dtype-struct")] @@ -231,7 +225,7 @@ impl Display for StringFunction { Base64Encode => "base64_encode", #[cfg(feature = "binary_encoding")] Base64Decode(_) => "base64_decode", - Slice(_, _) => "slice", + Slice => "slice", StartsWith { .. } => "starts_with", StripChars => "strip_chars", StripCharsStart => "strip_chars_start", @@ -340,7 +334,7 @@ impl From for SpecialEq> { StripSuffix => map_as_slice!(strings::strip_suffix), #[cfg(feature = "string_to_integer")] ToInteger(base, strict) => map!(strings::to_integer, base, strict), - Slice(start, length) => map!(strings::str_slice, start, length), + Slice => map_as_slice!(strings::str_slice), #[cfg(feature = "string_encoding")] HexEncode => map!(strings::hex_encode), #[cfg(feature = "binary_encoding")] @@ -879,9 +873,22 @@ pub(super) fn to_integer(s: &Series, base: u32, strict: bool) -> PolarsResult) -> PolarsResult { - let ca = s.str()?; - Ok(ca.str_slice(start, length).into_series()) +pub(super) fn str_slice(s: &[Series]) -> PolarsResult { + // Calculate the post-broadcast length and ensure everything is consistent. + let len = s + .iter() + .map(|series| series.len()) + .filter(|l| *l != 1) + .max() + .unwrap_or(1); + polars_ensure!( + s.iter().all(|series| series.len() == 1 || series.len() == len), + ComputeError: "all series in `str_slice` should have equal or unit length" + ); + let ca = s[0].str()?; + let offset = &s[1]; + let length = &s[2]; + Ok(ca.str_slice(offset, length)?.into_series()) } #[cfg(feature = "string_encoding")] diff --git a/crates/polars-plan/src/dsl/string.rs b/crates/polars-plan/src/dsl/string.rs index 742df67ca773..5b4ae8a1f05e 100644 --- a/crates/polars-plan/src/dsl/string.rs +++ b/crates/polars-plan/src/dsl/string.rs @@ -516,11 +516,13 @@ impl StringNameSpace { } /// Slice the string values. - pub fn slice(self, start: i64, length: Option) -> Expr { - self.0 - .map_private(FunctionExpr::StringExpr(StringFunction::Slice( - start, length, - ))) + pub fn slice(self, offset: Expr, length: Expr) -> Expr { + self.0.map_many_private( + FunctionExpr::StringExpr(StringFunction::Slice), + &[offset, length], + false, + false, + ) } pub fn explode(self) -> Expr { diff --git a/crates/polars-sql/src/functions.rs b/crates/polars-sql/src/functions.rs index 1bf567c6189d..bc3138b006d3 100644 --- a/crates/polars-sql/src/functions.rs +++ b/crates/polars-sql/src/functions.rs @@ -855,8 +855,8 @@ impl SQLFunctionVisitor<'_> { #[cfg(feature = "nightly")] InitCap => self.visit_unary(|e| e.str().to_titlecase()), Left => self.try_visit_binary(|e, length| { - Ok(e.str().slice(0, match length { - Expr::Literal(LiteralValue::Int64(n)) => Some(n as u64), + Ok(e.str().slice(lit(0), match length { + Expr::Literal(LiteralValue::Int64(n)) => lit(n as u64), _ => { polars_bail!(InvalidOperation: "Invalid 'length' for Left: {}", function.args[1]); } @@ -905,11 +905,11 @@ impl SQLFunctionVisitor<'_> { Reverse => self.visit_unary(|e| e.str().reverse()), Right => self.try_visit_binary(|e, length| { Ok(e.str().slice( match length { - Expr::Literal(LiteralValue::Int64(n)) => -n, + Expr::Literal(LiteralValue::Int64(n)) => lit(-n), _ => { polars_bail!(InvalidOperation: "Invalid 'length' for Right: {}", function.args[1]); } - }, None)) + }, lit(Null))) }), RTrim => match function.args.len() { 1 => self.visit_unary(|e| e.str().strip_chars_end(lit(Null))), @@ -925,19 +925,19 @@ impl SQLFunctionVisitor<'_> { 2 => self.try_visit_binary(|e, start| { Ok(e.str().slice( match start { - Expr::Literal(LiteralValue::Int64(n)) => n - 1 , + Expr::Literal(LiteralValue::Int64(n)) => lit(n - 1) , _ => polars_bail!(InvalidOperation: "Invalid 'start' for Substring: {}", function.args[1]), - }, None)) + }, lit(Null))) }), 3 => self.try_visit_ternary(|e, start, length| { Ok(e.str().slice( match start { - Expr::Literal(LiteralValue::Int64(n)) => n - 1, + Expr::Literal(LiteralValue::Int64(n)) => lit(n - 1), _ => { polars_bail!(InvalidOperation: "Invalid 'start' for Substring: {}", function.args[1]); } }, match length { - Expr::Literal(LiteralValue::Int64(n)) => Some(n as u64), + Expr::Literal(LiteralValue::Int64(n)) => lit(n as u64), _ => { polars_bail!(InvalidOperation: "Invalid 'length' for Substring: {}", function.args[2]); } diff --git a/py-polars/polars/expr/string.py b/py-polars/polars/expr/string.py index 6e9e372d90a9..386fff4ae88a 100644 --- a/py-polars/polars/expr/string.py +++ b/py-polars/polars/expr/string.py @@ -2027,7 +2027,9 @@ def reverse(self) -> Expr: """ return wrap_expr(self._pyexpr.str_reverse()) - def slice(self, offset: int, length: int | None = None) -> Expr: + def slice( + self, offset: int | IntoExprColumn, length: int | IntoExprColumn | None = None + ) -> Expr: """ Create subslices of the string values of a String Series. @@ -2079,6 +2081,8 @@ def slice(self, offset: int, length: int | None = None) -> Expr: │ dragonfruit ┆ onf │ └─────────────┴──────────┘ """ + offset = parse_as_expression(offset) + length = parse_as_expression(length) return wrap_expr(self._pyexpr.str_slice(offset, length)) def explode(self) -> Expr: diff --git a/py-polars/polars/series/string.py b/py-polars/polars/series/string.py index 607d7aa4fa29..7ab2c1a9ce19 100644 --- a/py-polars/polars/series/string.py +++ b/py-polars/polars/series/string.py @@ -1516,7 +1516,9 @@ def reverse(self) -> Series: ] """ - def slice(self, offset: int, length: int | None = None) -> Series: + def slice( + self, offset: int | IntoExprColumn, length: int | IntoExprColumn | None = None + ) -> Series: """ Create subslices of the string values of a String Series. diff --git a/py-polars/src/expr/string.rs b/py-polars/src/expr/string.rs index 3a5602a8a598..aba3cc748244 100644 --- a/py-polars/src/expr/string.rs +++ b/py-polars/src/expr/string.rs @@ -94,8 +94,12 @@ impl PyExpr { self.inner.clone().str().strip_suffix(suffix.inner).into() } - fn str_slice(&self, start: i64, length: Option) -> Self { - self.inner.clone().str().slice(start, length).into() + fn str_slice(&self, offset: Self, length: Self) -> Self { + self.inner + .clone() + .str() + .slice(offset.inner, length.inner) + .into() } fn str_explode(&self) -> Self { diff --git a/py-polars/tests/unit/namespaces/string/test_string.py b/py-polars/tests/unit/namespaces/string/test_string.py index 3538a5986cc4..daeb157a276c 100644 --- a/py-polars/tests/unit/namespaces/string/test_string.py +++ b/py-polars/tests/unit/namespaces/string/test_string.py @@ -16,6 +16,39 @@ def test_str_slice() -> None: assert df.select([pl.col("a").str.slice(2, 4)])["a"].to_list() == ["obar", "rfoo"] +def test_str_slice_expr() -> None: + df = pl.DataFrame( + { + "a": ["foobar", None, "barfoo", "abcd", ""], + "offset": [1, 3, None, -3, 2], + "length": [3, 4, 2, None, 2], + } + ) + out = df.select( + all_expr=pl.col("a").str.slice("offset", "length"), + offset_expr=pl.col("a").str.slice("offset", 2), + length_expr=pl.col("a").str.slice(0, "length"), + length_none=pl.col("a").str.slice("offset", None), + offset_length_lit=pl.col("a").str.slice(-3, 3), + str_lit=pl.lit("qwert").str.slice("offset", "length"), + ) + expected = pl.DataFrame( + { + "all_expr": ["oob", None, None, "bcd", ""], + "offset_expr": ["oo", None, None, "bc", ""], + "length_expr": ["foo", None, "ba", "abcd", ""], + "length_none": ["oobar", None, None, "bcd", ""], + "offset_length_lit": ["bar", None, "foo", "bcd", ""], + "str_lit": ["wer", "rt", None, "ert", "er"], + } + ) + assert_frame_equal(out, expected) + + # negative length is not allowed + with pytest.raises(pl.ComputeError): + df.select(pl.col("a").str.slice(0, -1)) + + def test_str_concat() -> None: s = pl.Series(["1", None, "2", None]) # propagate null