Skip to content

Commit

Permalink
feat(rust, python): allow expr in str.contains (#6443)
Browse files Browse the repository at this point in the history
Co-authored-by: Gabriel Robin <monsieurgabrielrobin@gmail.com>
Co-authored-by: gab23r <106454081+gab23r@users.noreply.github.com>
  • Loading branch information
3 people authored Jan 25, 2023
1 parent 87b96f2 commit b34dc65
Show file tree
Hide file tree
Showing 8 changed files with 131 additions and 29 deletions.
5 changes: 2 additions & 3 deletions polars/polars-lazy/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -366,9 +366,8 @@ impl From<StringFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
fn from(func: StringFunction) -> Self {
use StringFunction::*;
match func {
Contains { pat, literal } => {
map!(strings::contains, &pat, literal)
}
#[cfg(feature = "regex")]
Contains { literal, strict } => map_as_slice!(strings::contains, literal, strict),
EndsWith { .. } => map_as_slice!(strings::ends_with),
StartsWith { .. } => map_as_slice!(strings::starts_with),
Extract { pat, group_index } => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,9 @@ impl FunctionExpr {
StringExpr(s) => {
use StringFunction::*;
match s {
Contains { .. } | EndsWith | StartsWith => with_dtype(DataType::Boolean),
#[cfg(feature = "regex")]
Contains { .. } => with_dtype(DataType::Boolean),
EndsWith | StartsWith => with_dtype(DataType::Boolean),
Extract { .. } => same_type(),
ExtractAll => with_dtype(DataType::List(Box::new(DataType::Utf8))),
CountMatch(_) => with_dtype(DataType::UInt32),
Expand Down
64 changes: 56 additions & 8 deletions polars/polars-lazy/polars-plan/src/dsl/function_expr/strings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@ use super::*;
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, PartialEq, Debug, Eq, Hash)]
pub enum StringFunction {
#[cfg(feature = "regex")]
Contains {
pat: String,
literal: bool,
strict: bool,
},
StartsWith,
EndsWith,
Expand Down Expand Up @@ -58,6 +59,7 @@ impl Display for StringFunction {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
use self::*;
let s = match self {
#[cfg(feature = "regex")]
StringFunction::Contains { .. } => "contains",
StringFunction::StartsWith { .. } => "starts_with",
StringFunction::EndsWith { .. } => "ends_with",
Expand Down Expand Up @@ -99,13 +101,59 @@ pub(super) fn lowercase(s: &Series) -> PolarsResult<Series> {
Ok(ca.to_lowercase().into_series())
}

pub(super) fn contains(s: &Series, pat: &str, literal: bool) -> PolarsResult<Series> {
let ca = s.utf8()?;
if literal {
ca.contains_literal(pat).map(|ca| ca.into_series())
} else {
ca.contains(pat).map(|ca| ca.into_series())
}
#[cfg(feature = "regex")]
pub(super) fn contains(s: &[Series], literal: bool, strict: bool) -> PolarsResult<Series> {
let ca = &s[0].utf8()?;
let pat = &s[1].utf8()?;

let mut out: BooleanChunked = match pat.len() {
1 => match pat.get(0) {
Some(pat) => {
if literal {
ca.contains_literal(pat)?
} else {
ca.contains(pat)?
}
}
None => BooleanChunked::full(ca.name(), false, ca.len()),
},
_ => {
if literal {
ca.into_iter()
.zip(pat.into_iter())
.map(|(opt_src, opt_val)| match (opt_src, opt_val) {
(Some(src), Some(pat)) => src.contains(pat),
_ => false,
})
.collect_trusted()
} else if strict {
ca.into_iter()
.zip(pat.into_iter())
.map(|(opt_src, opt_val)| match (opt_src, opt_val) {
(Some(src), Some(pat)) => {
let re = Regex::new(pat)?;
Ok(re.is_match(src))
}
_ => Ok(false),
})
.collect::<PolarsResult<_>>()?
} else {
ca.into_iter()
.zip(pat.into_iter())
.map(|(opt_src, opt_val)| match (opt_src, opt_val) {
(Some(src), Some(pat)) => {
let re = Regex::new(pat).ok()?;
Some(re.is_match(src))
}
_ => Some(false),
})
.collect_trusted()
}
}
};

out.rename(ca.name());
Ok(out.into_series())
}

pub(super) fn ends_with(s: &[Series]) -> PolarsResult<Series> {
Expand Down
29 changes: 18 additions & 11 deletions polars/polars-lazy/polars-plan/src/dsl/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,28 @@ pub struct StringNameSpace(pub(crate) Expr);

impl StringNameSpace {
/// Check if a string value contains a literal substring.
pub fn contains_literal<S: AsRef<str>>(self, pat: S) -> Expr {
let pat = pat.as_ref().into();
self.0
.map_private(StringFunction::Contains { pat, literal: true }.into())
#[cfg(feature = "regex")]
pub fn contains_literal(self, pat: Expr) -> Expr {
self.0.map_many_private(
FunctionExpr::StringExpr(StringFunction::Contains {
literal: true,
strict: false,
}),
&[pat],
true,
)
}

/// Check if a string value contains a Regex substring.
pub fn contains<S: AsRef<str>>(self, pat: S) -> Expr {
let pat = pat.as_ref().into();
self.0.map_private(
StringFunction::Contains {
pat,
#[cfg(feature = "regex")]
pub fn contains(self, pat: Expr, strict: bool) -> Expr {
self.0.map_many_private(
FunctionExpr::StringExpr(StringFunction::Contains {
literal: false,
}
.into(),
strict,
}),
&[pat],
true,
)
}

Expand Down
10 changes: 8 additions & 2 deletions py-polars/polars/internals/expr/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,9 @@ def rjust(self, width: int, fillchar: str = " ") -> pli.Expr:
"""
return pli.wrap_expr(self._pyexpr.str_rjust(width, fillchar))

def contains(self, pattern: str, literal: bool = False) -> pli.Expr:
def contains(
self, pattern: str | pli.Expr, literal: bool = False, strict: bool = True
) -> pli.Expr:
"""
Check if string contains a substring that matches a regex.
Expand All @@ -503,6 +505,9 @@ def contains(self, pattern: str, literal: bool = False) -> pli.Expr:
A valid regex pattern.
literal
Treat pattern as a literal string.
strict
Raise an error if the underlying pattern is not a valid regex expression,
otherwise mask out with a null value.
Examples
--------
Expand Down Expand Up @@ -532,7 +537,8 @@ def contains(self, pattern: str, literal: bool = False) -> pli.Expr:
ends_with : Check if string values end with a substring.
"""
return pli.wrap_expr(self._pyexpr.str_contains(pattern, literal))
pattern = pli.expr_to_lit_or_expr(pattern, str_to_lit=True)._pyexpr
return pli.wrap_expr(self._pyexpr.str_contains(pattern, literal, strict))

def ends_with(self, sub: str | pli.Expr) -> pli.Expr:
"""
Expand Down
7 changes: 6 additions & 1 deletion py-polars/polars/internals/series/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,9 @@ def concat(self, delimiter: str = "-") -> pli.Series:
"""

def contains(self, pattern: str, literal: bool = False) -> pli.Series:
def contains(
self, pattern: str, literal: bool = False, strict: bool = True
) -> pli.Series:
"""
Check if strings in Series contain a substring that matches a regex.
Expand All @@ -184,6 +186,9 @@ def contains(self, pattern: str, literal: bool = False) -> pli.Series:
A valid regex pattern.
literal
Treat pattern as a literal string.
strict
Raise an error if the underlying pattern is not a valid regex expression,
otherwise mask out with a null value.
Returns
-------
Expand Down
6 changes: 3 additions & 3 deletions py-polars/src/lazy/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -702,10 +702,10 @@ impl PyExpr {
self.clone().inner.str().rjust(width, fillchar).into()
}

pub fn str_contains(&self, pat: String, literal: Option<bool>) -> PyExpr {
pub fn str_contains(&self, pat: PyExpr, literal: Option<bool>, strict: bool) -> PyExpr {
match literal {
Some(true) => self.inner.clone().str().contains_literal(pat).into(),
_ => self.inner.clone().str().contains(pat).into(),
Some(true) => self.inner.clone().str().contains_literal(pat.inner).into(),
_ => self.inner.clone().str().contains(pat.inner, strict).into(),
}
}

Expand Down
35 changes: 35 additions & 0 deletions py-polars/tests/unit/test_strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,41 @@ def test_contains() -> None:
)


def test_contains_expr() -> None:
df = pl.DataFrame(
{
"text": [
"some text",
"(with) special\n .* chars",
"**etc...?$",
None,
"b",
"invalid_regex",
],
"pattern": [r"[me]", r".*", r"^\(", "a", None, "*"],
}
)

assert df.select(
[
pl.col("text")
.str.contains(pl.col("pattern"), literal=False, strict=False)
.alias("contains"),
pl.col("text")
.str.contains(pl.col("pattern"), literal=True)
.alias("contains_lit"),
]
).to_dict(False) == {
"contains": [True, True, False, False, False, None],
"contains_lit": [False, True, False, False, False, False],
}

with pytest.raises(pl.ComputeError):
df.select(
pl.col("text").str.contains(pl.col("pattern"), literal=False, strict=True)
)


def test_null_comparisons() -> None:
s = pl.Series("s", [None, "str", "a"])
assert (s.shift() == s).null_count() == 0
Expand Down

0 comments on commit b34dc65

Please sign in to comment.