Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Expressify str.slice #13747

Merged
merged 2 commits into from
Jan 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 4 additions & 12 deletions crates/polars-core/src/chunked_array/ops/arity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -623,9 +623,7 @@ where
match (lhs.len(), rhs.len()) {
(1, _) => {
let a = unsafe { lhs.get_unchecked(0) };
let mut out = unary_elementwise(rhs, |b| op(a.clone(), b));
out.rename(lhs.name());
out
unary_elementwise(rhs, |b| op(a.clone(), b)).with_name(lhs.name())
},
(_, 1) => {
let b = unsafe { rhs.get_unchecked(0) };
Expand All @@ -650,9 +648,7 @@ where
match (lhs.len(), rhs.len()) {
(1, _) => {
let a = unsafe { lhs.get_unchecked(0) };
let mut out = try_unary_elementwise(rhs, |b| op(a.clone(), b))?;
out.rename(lhs.name());
Ok(out)
Ok(try_unary_elementwise(rhs, |b| op(a.clone(), b))?.with_name(lhs.name()))
},
(_, 1) => {
let b = unsafe { rhs.get_unchecked(0) };
Expand Down Expand Up @@ -686,9 +682,7 @@ where
match (lhs.len(), rhs.len()) {
(1, _) => {
let a = unsafe { lhs.value_unchecked(0) };
let mut out = unary_elementwise_values(rhs, |b| op(a.clone(), b));
out.rename(lhs.name());
out
unary_elementwise_values(rhs, |b| op(a.clone(), b)).with_name(lhs.name())
},
(_, 1) => {
let b = unsafe { rhs.value_unchecked(0) };
Expand Down Expand Up @@ -722,9 +716,7 @@ where
match (lhs.len(), rhs.len()) {
(1, _) => {
let a = unsafe { lhs.value_unchecked(0) };
let mut out = try_unary_elementwise_values(rhs, |b| op(a.clone(), b))?;
out.rename(lhs.name());
Ok(out)
Ok(try_unary_elementwise_values(rhs, |b| op(a.clone(), b))?.with_name(lhs.name()))
},
(_, 1) => {
let b = unsafe { rhs.value_unchecked(0) };
Expand Down
17 changes: 9 additions & 8 deletions crates/polars-ops/src/chunked_array/strings/namespace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -605,14 +605,15 @@ pub trait StringNameSpaceImpl: AsString {

/// Slice the string values.
///
/// Determines a substring starting from `start` and with optional length `length` of each of the elements in `array`.
/// `start` can be negative, in which case the start counts from the end of the string.
fn str_slice(&self, start: i64, length: Option<u64>) -> StringChunked {
let ca = self.as_string();
let iter = ca
.downcast_iter()
.map(|c| substring::utf8_substring(c, start, &length));
StringChunked::from_chunk_iter_like(ca, iter)
/// Takes a substring of every string value, starting at `offset` and spanning `length`.
///
/// A negative `offset` is counted from the end of the string.
///
/// # Errors
///
/// Returns an error when `offset` cannot be cast to `Int64` or when `length` cannot be
/// strictly cast to `UInt64`.
fn str_slice(&self, offset: &Series, length: &Series) -> PolarsResult<StringChunked> {
    let strings = self.as_string();

    // Offsets stay signed so that a negative start can address the string from its end.
    let offsets = offset.cast(&DataType::Int64)?;
    // A strict cast is used on purpose: otherwise a negative value would be treated as a
    // valid length.
    let lengths = length.strict_cast(&DataType::UInt64)?;

    let (offsets, lengths) = (offsets.i64()?, lengths.u64()?);
    Ok(substring::substring(strings, offsets, lengths))
}
}

Expand Down
152 changes: 109 additions & 43 deletions crates/polars-ops/src/chunked_array/strings/substring.rs
Original file line number Diff line number Diff line change
@@ -1,51 +1,117 @@
use arrow::array::Utf8Array;
use polars_core::prelude::arity::{binary_elementwise, ternary_elementwise, unary_elementwise};
use polars_core::prelude::{Int64Chunked, StringChunked, UInt64Chunked};

/// Returns a Utf8Array<O> with a substring starting from `start` and with optional length `length` of each of the elements in `array`.
/// `start` can be negative, in which case the start counts from the end of the string.
pub(super) fn utf8_substring(
array: &Utf8Array<i64>,
start: i64,
length: &Option<u64>,
) -> Utf8Array<i64> {
let length = length.map(|v| v as usize);
/// Slices a single string value by character position.
///
/// * `opt_str_val` - the string to slice; `None` yields `None`.
/// * `opt_offset` - index of the first character to keep; `None` yields `None`. A negative
///   value counts from the end of the string, and a negative value that reaches past the
///   beginning is clamped to the start of the string.
/// * `opt_length` - maximum number of characters to keep; `None` slices to the end.
///
/// Returns `Some("")` when a non-negative offset lies at or beyond the end of the string.
///
/// All positions are counted in characters and converted to byte indices exactly once, so
/// strings containing multi-byte characters are sliced correctly for negative offsets too.
fn substring_ternary(
    opt_str_val: Option<&str>,
    opt_offset: Option<i64>,
    opt_length: Option<u64>,
) -> Option<&str> {
    let (str_val, offset) = match (opt_str_val, opt_offset) {
        (Some(str_val), Some(offset)) => (str_val, offset),
        _ => return None,
    };

    // `unsigned_abs` cannot overflow (unlike `0 - offset` for `i64::MIN`); the clamp keeps
    // the conversion lossless on targets where `usize` is narrower than 64 bits.
    let magnitude = offset.unsigned_abs().min(usize::MAX as u64) as usize;

    // Byte index at which the slice starts.
    let start_idx = if offset >= 0 {
        match str_val.char_indices().nth(magnitude) {
            Some((idx, _)) => idx,
            // The offset is past the last character.
            None => return Some(""),
        }
    } else {
        // `offset < 0` implies `magnitude >= 1`, so the subtraction cannot underflow.
        // The `magnitude`-th character from the end is where the slice starts; if the
        // string is shorter than that, start from the beginning.
        str_val
            .char_indices()
            .rev()
            .nth(magnitude - 1)
            .map_or(0, |(idx, _)| idx)
    };

    let tail = &str_val[start_idx..];

    // Byte index (relative to `tail`) at which the slice ends.
    let end_idx = match opt_length {
        // Slice to the end of the string if no length is given.
        None => tail.len(),
        Some(length) => {
            let n_chars = length.min(usize::MAX as u64) as usize;
            tail.char_indices()
                .nth(n_chars)
                .map_or(tail.len(), |(idx, _)| idx)
        },
    };

    Some(&tail[..end_idx])
}

let new = Utf8Array::<i64>::from_trusted_len_values_iter(iter);
new.with_validity(array.validity().cloned())
/// Slices every string in `ca` using the corresponding `offset` / `length` values,
/// broadcasting unit-length inputs against the others.
///
/// Each arm handles one combination of unit-length inputs so that the scalar values are
/// read once and a cheaper unary or binary kernel can be used; every other combination
/// falls through to the fully element-wise ternary kernel. The input lengths are not
/// validated here.
///
/// Whenever the result is produced by iterating over `offset` and/or `length` rather than
/// over `ca`, it is explicitly renamed to `ca`'s name so the output is named after the
/// string column in every arm.
pub(super) fn substring(
    ca: &StringChunked,
    offset: &Int64Chunked,
    length: &UInt64Chunked,
) -> StringChunked {
    match (ca.len(), offset.len(), length.len()) {
        (1, 1, _) => {
            // SAFETY: index `0` is in bound.
            let str_val = unsafe { ca.get_unchecked(0) };
            // SAFETY: index `0` is in bound.
            let offset = unsafe { offset.get_unchecked(0) };
            unary_elementwise(length, |length| substring_ternary(str_val, offset, length))
                .with_name(ca.name())
        },
        (_, 1, 1) => {
            // SAFETY: index `0` is in bound.
            let offset = unsafe { offset.get_unchecked(0) };
            // SAFETY: index `0` is in bound.
            let length = unsafe { length.get_unchecked(0) };
            unary_elementwise(ca, |str_val| substring_ternary(str_val, offset, length))
        },
        (1, _, 1) => {
            // SAFETY: index `0` is in bound.
            let str_val = unsafe { ca.get_unchecked(0) };
            // SAFETY: index `0` is in bound.
            let length = unsafe { length.get_unchecked(0) };
            unary_elementwise(offset, |offset| substring_ternary(str_val, offset, length))
                .with_name(ca.name())
        },
        (1, len_b, len_c) if len_b == len_c => {
            // SAFETY: index `0` is in bound.
            let str_val = unsafe { ca.get_unchecked(0) };
            // The kernel iterates over `offset` and `length`, so rename the result to
            // `ca`'s name, consistent with the other arms that broadcast a unit-length `ca`.
            binary_elementwise(offset, length, |offset, length| {
                substring_ternary(str_val, offset, length)
            })
            .with_name(ca.name())
        },
        (len_a, 1, len_c) if len_a == len_c => {
            // The compiler does not accept the bare closure here, so `infer` pins its
            // signature (returned `&str` borrows from the input `&str`) to keep it happy.
            fn infer<F: for<'a> FnMut(Option<&'a str>, Option<u64>) -> Option<&'a str>>(
                f: F,
            ) -> F {
                f
            }
            // SAFETY: index `0` is in bound.
            let offset = unsafe { offset.get_unchecked(0) };
            binary_elementwise(
                ca,
                length,
                infer(|str_val, length| substring_ternary(str_val, offset, length)),
            )
        },
        (len_a, len_b, 1) if len_a == len_b => {
            // Same closure-signature workaround as above, for an `i64` second argument.
            fn infer<F: for<'a> FnMut(Option<&'a str>, Option<i64>) -> Option<&'a str>>(
                f: F,
            ) -> F {
                f
            }
            // SAFETY: index `0` is in bound.
            let length = unsafe { length.get_unchecked(0) };
            binary_elementwise(
                ca,
                offset,
                infer(|str_val, offset| substring_ternary(str_val, offset, length)),
            )
        },
        _ => ternary_elementwise(ca, offset, length, substring_ternary),
    }
}
35 changes: 21 additions & 14 deletions crates/polars-plan/src/dsl/function_expr/strings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ pub enum StringFunction {
length: usize,
fill_char: char,
},
Slice(i64, Option<u64>),
Slice,
#[cfg(feature = "string_encoding")]
HexEncode,
#[cfg(feature = "binary_encoding")]
Expand Down Expand Up @@ -160,14 +160,8 @@ impl StringFunction {
Base64Encode => mapper.with_same_dtype(),
#[cfg(feature = "binary_encoding")]
Base64Decode(_) => mapper.with_dtype(DataType::Binary),
Uppercase
| Lowercase
| StripChars
| StripCharsStart
| StripCharsEnd
| StripPrefix
| StripSuffix
| Slice(_, _) => mapper.with_same_dtype(),
Uppercase | Lowercase | StripChars | StripCharsStart | StripCharsEnd | StripPrefix
| StripSuffix | Slice => mapper.with_same_dtype(),
#[cfg(feature = "string_pad")]
PadStart { .. } | PadEnd { .. } | ZFill { .. } => mapper.with_same_dtype(),
#[cfg(feature = "dtype-struct")]
Expand Down Expand Up @@ -231,7 +225,7 @@ impl Display for StringFunction {
Base64Encode => "base64_encode",
#[cfg(feature = "binary_encoding")]
Base64Decode(_) => "base64_decode",
Slice(_, _) => "slice",
Slice => "slice",
StartsWith { .. } => "starts_with",
StripChars => "strip_chars",
StripCharsStart => "strip_chars_start",
Expand Down Expand Up @@ -340,7 +334,7 @@ impl From<StringFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
StripSuffix => map_as_slice!(strings::strip_suffix),
#[cfg(feature = "string_to_integer")]
ToInteger(base, strict) => map!(strings::to_integer, base, strict),
Slice(start, length) => map!(strings::str_slice, start, length),
Slice => map_as_slice!(strings::str_slice),
#[cfg(feature = "string_encoding")]
HexEncode => map!(strings::hex_encode),
#[cfg(feature = "binary_encoding")]
Expand Down Expand Up @@ -879,9 +873,22 @@ pub(super) fn to_integer(s: &Series, base: u32, strict: bool) -> PolarsResult<Se
let ca = s.str()?;
ca.to_integer(base, strict).map(|ok| ok.into_series())
}
pub(super) fn str_slice(s: &Series, start: i64, length: Option<u64>) -> PolarsResult<Series> {
let ca = s.str()?;
Ok(ca.str_slice(start, length).into_series())
/// Applies `str.slice` to its inputs: `s[0]` holds the strings, `s[1]` the offsets and
/// `s[2]` the lengths.
///
/// # Errors
///
/// Returns a `ComputeError` unless every input has either unit length or the common
/// broadcast length, and propagates any error from the string cast or the slice kernel.
pub(super) fn str_slice(s: &[Series]) -> PolarsResult<Series> {
    // The post-broadcast length is the largest non-unit length; when every input has unit
    // length it is simply 1.
    let broadcast_len = s
        .iter()
        .filter_map(|series| match series.len() {
            1 => None,
            n => Some(n),
        })
        .max()
        .unwrap_or(1);

    // Every input must either broadcast (unit length) or already have the final length.
    let lengths_consistent = s
        .iter()
        .map(|series| series.len())
        .all(|n| n == 1 || n == broadcast_len);
    polars_ensure!(
        lengths_consistent,
        ComputeError: "all series in `str_slice` should have equal or unit length"
    );

    let (strings, offset, length) = (s[0].str()?, &s[1], &s[2]);
    strings
        .str_slice(offset, length)
        .map(|sliced| sliced.into_series())
}

#[cfg(feature = "string_encoding")]
Expand Down
12 changes: 7 additions & 5 deletions crates/polars-plan/src/dsl/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -516,11 +516,13 @@ impl StringNameSpace {
}

/// Slice the string values.
pub fn slice(self, start: i64, length: Option<u64>) -> Expr {
self.0
.map_private(FunctionExpr::StringExpr(StringFunction::Slice(
start, length,
)))
/// Slice the string values, taking both the `offset` and the `length` as expressions.
pub fn slice(self, offset: Expr, length: Expr) -> Expr {
    let function = FunctionExpr::StringExpr(StringFunction::Slice);
    // The string expression itself is the first input; these are the additional ones.
    let arguments = [offset, length];
    // NOTE(review): the two `false` flags are passed through unchanged; their meaning is
    // defined by `map_many_private`, which is not visible here — confirm before changing.
    self.0.map_many_private(function, &arguments, false, false)
}

pub fn explode(self) -> Expr {
Expand Down
16 changes: 8 additions & 8 deletions crates/polars-sql/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -855,8 +855,8 @@ impl SQLFunctionVisitor<'_> {
#[cfg(feature = "nightly")]
InitCap => self.visit_unary(|e| e.str().to_titlecase()),
Left => self.try_visit_binary(|e, length| {
Ok(e.str().slice(0, match length {
Expr::Literal(LiteralValue::Int64(n)) => Some(n as u64),
Ok(e.str().slice(lit(0), match length {
Expr::Literal(LiteralValue::Int64(n)) => lit(n as u64),
_ => {
polars_bail!(InvalidOperation: "Invalid 'length' for Left: {}", function.args[1]);
}
Expand Down Expand Up @@ -905,11 +905,11 @@ impl SQLFunctionVisitor<'_> {
Reverse => self.visit_unary(|e| e.str().reverse()),
Right => self.try_visit_binary(|e, length| {
Ok(e.str().slice( match length {
Expr::Literal(LiteralValue::Int64(n)) => -n,
Expr::Literal(LiteralValue::Int64(n)) => lit(-n),
Copy link
Collaborator Author

@reswqa reswqa Jan 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@alexander-beedie I've mostly left the SQL part as it is, so feel free to push a new commit or open a new PR if you want to improve it to work better with expressions.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will do; thx! :))

_ => {
polars_bail!(InvalidOperation: "Invalid 'length' for Right: {}", function.args[1]);
}
}, None))
}, lit(Null)))
}),
RTrim => match function.args.len() {
1 => self.visit_unary(|e| e.str().strip_chars_end(lit(Null))),
Expand All @@ -925,19 +925,19 @@ impl SQLFunctionVisitor<'_> {
2 => self.try_visit_binary(|e, start| {
Ok(e.str().slice(
match start {
Expr::Literal(LiteralValue::Int64(n)) => n - 1 ,
Expr::Literal(LiteralValue::Int64(n)) => lit(n - 1) ,
_ => polars_bail!(InvalidOperation: "Invalid 'start' for Substring: {}", function.args[1]),
}, None))
}, lit(Null)))
}),
3 => self.try_visit_ternary(|e, start, length| {
Ok(e.str().slice(
match start {
Expr::Literal(LiteralValue::Int64(n)) => n - 1,
Expr::Literal(LiteralValue::Int64(n)) => lit(n - 1),
_ => {
polars_bail!(InvalidOperation: "Invalid 'start' for Substring: {}", function.args[1]);
}
}, match length {
Expr::Literal(LiteralValue::Int64(n)) => Some(n as u64),
Expr::Literal(LiteralValue::Int64(n)) => lit(n as u64),
_ => {
polars_bail!(InvalidOperation: "Invalid 'length' for Substring: {}", function.args[2]);
}
Expand Down
6 changes: 5 additions & 1 deletion py-polars/polars/expr/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -2027,7 +2027,9 @@ def reverse(self) -> Expr:
"""
return wrap_expr(self._pyexpr.str_reverse())

def slice(self, offset: int, length: int | None = None) -> Expr:
def slice(
self, offset: int | IntoExprColumn, length: int | IntoExprColumn | None = None
) -> Expr:
"""
Create subslices of the string values of a String Series.

Expand Down Expand Up @@ -2079,6 +2081,8 @@ def slice(self, offset: int, length: int | None = None) -> Expr:
│ dragonfruit ┆ onf │
└─────────────┴──────────┘
"""
offset = parse_as_expression(offset)
length = parse_as_expression(length)
return wrap_expr(self._pyexpr.str_slice(offset, length))

def explode(self) -> Expr:
Expand Down
4 changes: 3 additions & 1 deletion py-polars/polars/series/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -1516,7 +1516,9 @@ def reverse(self) -> Series:
]
"""

def slice(self, offset: int, length: int | None = None) -> Series:
def slice(
self, offset: int | IntoExprColumn, length: int | IntoExprColumn | None = None
) -> Series:
"""
Create subslices of the string values of a String Series.

Expand Down
Loading