Skip to content

Commit

Permalink
feat: Expressify str.slice
Browse files Browse the repository at this point in the history
  • Loading branch information
reswqa committed Jan 15, 2024
1 parent de405f0 commit c866553
Show file tree
Hide file tree
Showing 10 changed files with 178 additions and 95 deletions.
18 changes: 5 additions & 13 deletions crates/polars-core/src/chunked_array/ops/arity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,7 @@ where
}

#[inline]
pub fn ternary_elementwise<T, U, V, G, F>(
pub fn ternary_elementwise<T, U, G, V, F>(
ca1: &ChunkedArray<T>,
ca2: &ChunkedArray<U>,
ca3: &ChunkedArray<G>,
Expand Down Expand Up @@ -623,9 +623,7 @@ where
match (lhs.len(), rhs.len()) {
(1, _) => {
let a = unsafe { lhs.get_unchecked(0) };
let mut out = unary_elementwise(rhs, |b| op(a.clone(), b));
out.rename(lhs.name());
out
unary_elementwise(rhs, |b| op(a.clone(), b)).with_name(lhs.name())
},
(_, 1) => {
let b = unsafe { rhs.get_unchecked(0) };
Expand All @@ -650,9 +648,7 @@ where
match (lhs.len(), rhs.len()) {
(1, _) => {
let a = unsafe { lhs.get_unchecked(0) };
let mut out = try_unary_elementwise(rhs, |b| op(a.clone(), b))?;
out.rename(lhs.name());
Ok(out)
Ok(try_unary_elementwise(rhs, |b| op(a.clone(), b))?.with_name(lhs.name()))
},
(_, 1) => {
let b = unsafe { rhs.get_unchecked(0) };
Expand Down Expand Up @@ -686,9 +682,7 @@ where
match (lhs.len(), rhs.len()) {
(1, _) => {
let a = unsafe { lhs.value_unchecked(0) };
let mut out = unary_elementwise_values(rhs, |b| op(a.clone(), b));
out.rename(lhs.name());
out
unary_elementwise_values(rhs, |b| op(a.clone(), b)).with_name(lhs.name())
},
(_, 1) => {
let b = unsafe { rhs.value_unchecked(0) };
Expand Down Expand Up @@ -722,9 +716,7 @@ where
match (lhs.len(), rhs.len()) {
(1, _) => {
let a = unsafe { lhs.value_unchecked(0) };
let mut out = try_unary_elementwise_values(rhs, |b| op(a.clone(), b))?;
out.rename(lhs.name());
Ok(out)
Ok(try_unary_elementwise_values(rhs, |b| op(a.clone(), b))?.with_name(lhs.name()))
},
(_, 1) => {
let b = unsafe { rhs.value_unchecked(0) };
Expand Down
17 changes: 9 additions & 8 deletions crates/polars-ops/src/chunked_array/strings/namespace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -605,14 +605,15 @@ pub trait StringNameSpaceImpl: AsString {

/// Slice the string values.
///
/// Determines a substring starting from `start` and with optional length `length` of each of the elements in `array`.
/// `start` can be negative, in which case the start counts from the end of the string.
fn str_slice(&self, start: i64, length: Option<u64>) -> StringChunked {
let ca = self.as_string();
let iter = ca
.downcast_iter()
.map(|c| substring::utf8_substring(c, start, &length));
StringChunked::from_chunk_iter_like(ca, iter)
/// Determines a substring starting from `offset` and with length `length` of each of the elements in `array`.
/// `offset` can be negative, in which case the start counts from the end of the string.
fn str_slice(&self, offset: &Series, length: &Series) -> PolarsResult<StringChunked> {
let ca = self.as_string();
let offset = offset.cast(&DataType::Int64)?;
// we strict cast, otherwise negative value will be treated as a valid length.
let length = length.strict_cast(&DataType::UInt64)?;

Ok(substring::substring(ca, offset.i64()?, length.u64()?))
}
}

Expand Down
156 changes: 113 additions & 43 deletions crates/polars-ops/src/chunked_array/strings/substring.rs
Original file line number Diff line number Diff line change
@@ -1,51 +1,121 @@
use arrow::array::Utf8Array;
use polars_core::prelude::arity::{binary_elementwise, ternary_elementwise, unary_elementwise};
use polars_core::prelude::{Int64Chunked, StringChunked, UInt64Chunked};

/// Returns a Utf8Array<O> with a substring starting from `start` and with optional length `length` of each of the elements in `array`.
/// `start` can be negative, in which case the start counts from the end of the string.
pub(super) fn utf8_substring(
array: &Utf8Array<i64>,
start: i64,
length: &Option<u64>,
) -> Utf8Array<i64> {
let length = length.map(|v| v as usize);
fn substring_ternary(
opt_str_val: Option<&str>,
opt_start: Option<i64>,
opt_length: Option<u64>,
) -> Option<&str> {
match (opt_str_val, opt_start) {
(Some(str_val), Some(start)) => {
// if `offset` is negative, it counts from the end of the string
let offset = if start >= 0 {
start as usize
} else {
let offset = (0i64 - start) as usize;
str_val
.char_indices()
.rev()
.nth(offset)
.map(|(idx, _)| idx + 1)
.unwrap_or(0)
};

let iter = array.values_iter().map(|str_val| {
// compute where we should start slicing this entry.
let start = if start >= 0 {
start as usize
} else {
let start = (0i64 - start) as usize;
str_val
.char_indices()
.rev()
.nth(start)
.map(|(idx, _)| idx + 1)
.unwrap_or(0)
};
let mut iter_chars = str_val.char_indices();
if let Some((offset_idx, _)) = iter_chars.nth(offset) {
let len_end = str_val.len() - offset_idx;

let mut iter_chars = str_val.char_indices();
if let Some((start_idx, _)) = iter_chars.nth(start) {
// length of the str
let len_end = str_val.len() - start_idx;
// slice to end of str if no length given
let length = if let Some(length) = opt_length {
length as usize
} else {
len_end
};

// length to slice
let length = length.unwrap_or(len_end);
if length == 0 {
return Some("");
}

if length == 0 {
return "";
}
// compute
let end_idx = iter_chars
.nth(length.saturating_sub(1))
.map(|(idx, _)| idx)
.unwrap_or(str_val.len());
let end_idx = iter_chars
.nth(length.saturating_sub(1))
.map(|(idx, _)| idx)
.unwrap_or(str_val.len());

&str_val[start_idx..end_idx]
} else {
""
}
});
Some(&str_val[offset_idx..end_idx])
} else {
Some("")
}
},
_ => None,
}
}

let new = Utf8Array::<i64>::from_trusted_len_values_iter(iter);
new.with_validity(array.validity().cloned())
pub(super) fn substring(
ca: &StringChunked,
start: &Int64Chunked,
length: &UInt64Chunked,
) -> StringChunked {
match (ca.len(), start.len(), length.len()) {
(1, 1, _) => {
// SAFETY: index `0` is in bound.
let str_val = unsafe { ca.get_unchecked(0) };
// SAFETY: index `0` is in bound.
let start = unsafe { start.get_unchecked(0) };
unary_elementwise(length, |length| substring_ternary(str_val, start, length))
.with_name(ca.name())
},
(_, 1, 1) => {
// SAFETY: index `0` is in bound.
let start = unsafe { start.get_unchecked(0) };
// SAFETY: index `0` is in bound.
let length = unsafe { length.get_unchecked(0) };
unary_elementwise(ca, |str_val| substring_ternary(str_val, start, length))
},
(1, _, 1) => {
// SAFETY: index `0` is in bound.
let str_val = unsafe { ca.get_unchecked(0) };
// SAFETY: index `0` is in bound.
let length = unsafe { length.get_unchecked(0) };
unary_elementwise(start, |start| substring_ternary(str_val, start, length))
.with_name(ca.name())
},
(1, len_b, len_c) if len_b == len_c => {
// fn infer<F: for<'a> FnMut(Option<i64>, Option<u64>) -> Option<&'a str>>(f: F) -> F where
// {
// f
// }
// SAFETY: index `0` is in bound.
let str_val = unsafe { ca.get_unchecked(0) };
binary_elementwise(start, length, |start, length| {
substring_ternary(str_val, start, length)
})
},
(len_a, 1, len_c) if len_a == len_c => {
fn infer<F: for<'a> FnMut(Option<&'a str>, Option<u64>) -> Option<&'a str>>(f: F) -> F where
{
f
}
// SAFETY: index `0` is in bound.
let start = unsafe { start.get_unchecked(0) };
binary_elementwise(
ca,
length,
infer(|str_val, length| substring_ternary(str_val, start, length)),
)
},
(len_a, len_b, 1) if len_a == len_b => {
fn infer<F: for<'a> FnMut(Option<&'a str>, Option<i64>) -> Option<&'a str>>(f: F) -> F where
{
f
}
// SAFETY: index `0` is in bound.
let length = unsafe { length.get_unchecked(0) };
binary_elementwise(
ca,
start,
infer(|str_val, start| substring_ternary(str_val, start, length)),
)
},
_ => ternary_elementwise(ca, start, length, substring_ternary),
}
}
24 changes: 10 additions & 14 deletions crates/polars-plan/src/dsl/function_expr/strings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ pub enum StringFunction {
length: usize,
fill_char: char,
},
Slice(i64, Option<u64>),
Slice,
#[cfg(feature = "string_encoding")]
HexEncode,
#[cfg(feature = "binary_encoding")]
Expand Down Expand Up @@ -160,14 +160,8 @@ impl StringFunction {
Base64Encode => mapper.with_same_dtype(),
#[cfg(feature = "binary_encoding")]
Base64Decode(_) => mapper.with_dtype(DataType::Binary),
Uppercase
| Lowercase
| StripChars
| StripCharsStart
| StripCharsEnd
| StripPrefix
| StripSuffix
| Slice(_, _) => mapper.with_same_dtype(),
Uppercase | Lowercase | StripChars | StripCharsStart | StripCharsEnd | StripPrefix
| StripSuffix | Slice => mapper.with_same_dtype(),
#[cfg(feature = "string_pad")]
PadStart { .. } | PadEnd { .. } | ZFill { .. } => mapper.with_same_dtype(),
#[cfg(feature = "dtype-struct")]
Expand Down Expand Up @@ -231,7 +225,7 @@ impl Display for StringFunction {
Base64Encode => "base64_encode",
#[cfg(feature = "binary_encoding")]
Base64Decode(_) => "base64_decode",
Slice(_, _) => "slice",
Slice => "slice",
StartsWith { .. } => "starts_with",
StripChars => "strip_chars",
StripCharsStart => "strip_chars_start",
Expand Down Expand Up @@ -340,7 +334,7 @@ impl From<StringFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
StripSuffix => map_as_slice!(strings::strip_suffix),
#[cfg(feature = "string_to_integer")]
ToInteger(base, strict) => map!(strings::to_integer, base, strict),
Slice(start, length) => map!(strings::str_slice, start, length),
Slice => map_as_slice!(strings::str_slice),
#[cfg(feature = "string_encoding")]
HexEncode => map!(strings::hex_encode),
#[cfg(feature = "binary_encoding")]
Expand Down Expand Up @@ -879,9 +873,11 @@ pub(super) fn to_integer(s: &Series, base: u32, strict: bool) -> PolarsResult<Se
let ca = s.str()?;
ca.to_integer(base, strict).map(|ok| ok.into_series())
}
pub(super) fn str_slice(s: &Series, start: i64, length: Option<u64>) -> PolarsResult<Series> {
let ca = s.str()?;
Ok(ca.str_slice(start, length).into_series())
pub(super) fn str_slice(s: &[Series]) -> PolarsResult<Series> {
let ca = s[0].str()?;
let offset = &s[1];
let length = &s[2];
Ok(ca.str_slice(offset, length)?.into_series())
}

#[cfg(feature = "string_encoding")]
Expand Down
12 changes: 7 additions & 5 deletions crates/polars-plan/src/dsl/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -516,11 +516,13 @@ impl StringNameSpace {
}

/// Slice the string values.
pub fn slice(self, start: i64, length: Option<u64>) -> Expr {
self.0
.map_private(FunctionExpr::StringExpr(StringFunction::Slice(
start, length,
)))
pub fn slice(self, offset: Expr, length: Expr) -> Expr {
self.0.map_many_private(
FunctionExpr::StringExpr(StringFunction::Slice),
&[offset, length],
false,
false,
)
}

pub fn explode(self) -> Expr {
Expand Down
16 changes: 8 additions & 8 deletions crates/polars-sql/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -855,8 +855,8 @@ impl SQLFunctionVisitor<'_> {
#[cfg(feature = "nightly")]
InitCap => self.visit_unary(|e| e.str().to_titlecase()),
Left => self.try_visit_binary(|e, length| {
Ok(e.str().slice(0, match length {
Expr::Literal(LiteralValue::Int64(n)) => Some(n as u64),
Ok(e.str().slice(lit(0), match length {
Expr::Literal(LiteralValue::Int64(n)) => lit(n as u64),
_ => {
polars_bail!(InvalidOperation: "Invalid 'length' for Left: {}", function.args[1]);
}
Expand Down Expand Up @@ -905,11 +905,11 @@ impl SQLFunctionVisitor<'_> {
Reverse => self.visit_unary(|e| e.str().reverse()),
Right => self.try_visit_binary(|e, length| {
Ok(e.str().slice( match length {
Expr::Literal(LiteralValue::Int64(n)) => -n,
Expr::Literal(LiteralValue::Int64(n)) => lit(-n),
_ => {
polars_bail!(InvalidOperation: "Invalid 'length' for Right: {}", function.args[1]);
}
}, None))
}, lit(Null)))
}),
RTrim => match function.args.len() {
1 => self.visit_unary(|e| e.str().strip_chars_end(lit(Null))),
Expand All @@ -925,19 +925,19 @@ impl SQLFunctionVisitor<'_> {
2 => self.try_visit_binary(|e, start| {
Ok(e.str().slice(
match start {
Expr::Literal(LiteralValue::Int64(n)) => n - 1 ,
Expr::Literal(LiteralValue::Int64(n)) => lit(n - 1) ,
_ => polars_bail!(InvalidOperation: "Invalid 'start' for Substring: {}", function.args[1]),
}, None))
}, lit(Null)))
}),
3 => self.try_visit_ternary(|e, start, length| {
Ok(e.str().slice(
match start {
Expr::Literal(LiteralValue::Int64(n)) => n - 1,
Expr::Literal(LiteralValue::Int64(n)) => lit(n - 1),
_ => {
polars_bail!(InvalidOperation: "Invalid 'start' for Substring: {}", function.args[1]);
}
}, match length {
Expr::Literal(LiteralValue::Int64(n)) => Some(n as u64),
Expr::Literal(LiteralValue::Int64(n)) => lit(n as u64),
_ => {
polars_bail!(InvalidOperation: "Invalid 'length' for Substring: {}", function.args[2]);
}
Expand Down
6 changes: 5 additions & 1 deletion py-polars/polars/expr/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -2027,7 +2027,9 @@ def reverse(self) -> Expr:
"""
return wrap_expr(self._pyexpr.str_reverse())

def slice(self, offset: int, length: int | None = None) -> Expr:
def slice(
self, offset: int | IntoExprColumn, length: int | IntoExprColumn | None = None
) -> Expr:
"""
Create subslices of the string values of a String Series.
Expand Down Expand Up @@ -2079,6 +2081,8 @@ def slice(self, offset: int, length: int | None = None) -> Expr:
│ dragonfruit ┆ onf │
└─────────────┴──────────┘
"""
offset = parse_as_expression(offset)
length = parse_as_expression(length)
return wrap_expr(self._pyexpr.str_slice(offset, length))

def explode(self) -> Expr:
Expand Down
4 changes: 3 additions & 1 deletion py-polars/polars/series/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -1516,7 +1516,9 @@ def reverse(self) -> Series:
]
"""

def slice(self, offset: int, length: int | None = None) -> Series:
def slice(
self, offset: int | IntoExprColumn, length: int | IntoExprColumn | None = None
) -> Series:
"""
Create subslices of the string values of a String Series.
Expand Down
Loading

0 comments on commit c866553

Please sign in to comment.