Skip to content

Commit

Permalink
Add support for working with bytes and bytearray
Browse files Browse the repository at this point in the history
This makes it useful for searching raw binary data as well
  • Loading branch information
Isaac Garzon committed Jan 26, 2023
1 parent d10fd3f commit 4ca8c26
Showing 1 changed file with 121 additions and 34 deletions.
155 changes: 121 additions & 34 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
use aho_corasick::{AhoCorasick, AhoCorasickBuilder, Match, MatchKind};
use pyo3::{exceptions::PyValueError, prelude::*, types::PyUnicode};
use pyo3::{
exceptions::{PyTypeError, PyValueError},
prelude::*,
types::{PyByteArray, PyBytes, PyUnicode},
};

/// A Python wrapper for AhoCorasick.
#[pyclass(name = "AhoCorasick")]
struct PyAhoCorasick {
ac_impl: AhoCorasick,
patterns: Vec<Py<PyUnicode>>,
patterns: Vec<Py<PyAny>>,
is_unicode: bool,
}

impl PyAhoCorasick {
Expand All @@ -17,7 +22,7 @@ impl PyAhoCorasick {
}

/// Return matches for a given haystack.
fn get_matches(&self, py: Python<'_>, haystack: &str, overlapping: bool) -> Vec<Match> {
fn get_matches(&self, py: Python<'_>, haystack: &[u8], overlapping: bool) -> Vec<Match> {
let ac_impl = &self.ac_impl;
py.allow_threads(|| {
if overlapping {
Expand All @@ -35,7 +40,7 @@ impl PyAhoCorasick {
/// __new__() implementation.
#[new]
#[args(matchkind = "\"MATCHKIND_STANDARD\"")]
fn new(py: Python, patterns: Vec<Py<PyUnicode>>, matchkind: &str) -> PyResult<Self> {
fn new(py: Python, patterns: Vec<Py<PyAny>>, matchkind: &str) -> PyResult<Self> {
let matchkind = match matchkind {
"MATCHKIND_STANDARD" => MatchKind::Standard,
"MATCHKIND_LEFTMOST_FIRST" => MatchKind::LeftmostFirst,
Expand All @@ -46,9 +51,32 @@ impl PyAhoCorasick {
));
}
};
let mut rust_patterns: Vec<String> = vec![];

if !patterns.windows(2).all(|w| {
w[0].as_ref(py)
.is_instance(w[1].as_ref(py).get_type())
.unwrap_or(false)
}) {
return Err(PyTypeError::new_err(
"all of the patterns must be of the same type",
));
}

let mut rust_patterns: Vec<Vec<u8>> = vec![];
let mut is_unicode = false;
for s in patterns.iter() {
rust_patterns.push(s.as_ref(py).extract()?);
let pat = s.as_ref(py);
if pat.is_instance_of::<PyUnicode>()? {
is_unicode = true;
rust_patterns.push(s.as_ref(py).extract::<&str>()?.bytes().collect());
} else if !pat.is_instance_of::<PyBytes>()? && !pat.is_instance_of::<PyByteArray>()? {
return Err(PyTypeError::new_err(format!(
"expected str, bytes, or bytearray, got {}",
pat.get_type().name()?
)));
} else {
rust_patterns.push(s.as_ref(py).extract()?);
}
}
Ok(Self {
ac_impl: py.allow_threads(|| {
Expand All @@ -58,6 +86,7 @@ impl PyAhoCorasick {
.build(rust_patterns)
}),
patterns,
is_unicode,
})
}

Expand All @@ -66,47 +95,105 @@ impl PyAhoCorasick {
#[args(overlapping = "false")]
fn find_matches_as_indexes(
self_: PyRef<Self>,
haystack: &str,
haystack: Py<PyAny>,
overlapping: bool,
) -> PyResult<Vec<(usize, usize, usize)>> {
self_.check_overlapping(overlapping)?;
// Map UTF-8 byte index to Unicode code point index; the latter is what
// Python users expect.
let mut byte_to_code_point = vec![usize::MAX; haystack.len() + 1];
let mut max_codepoint = 0;
for (codepoint_off, (byte_off, _)) in haystack.char_indices().enumerate() {
byte_to_code_point[byte_off] = codepoint_off;
max_codepoint = codepoint_off;
}
// End index is exclusive (e.g. 0:3 is first 3 characters), so handle
// the case where pattern is at end of string.
if !haystack.is_empty() {
byte_to_code_point[haystack.len()] = max_codepoint + 1;
}

let py = self_.py();
let matches = self_.get_matches(py, haystack, overlapping);
Ok(matches
.into_iter()
.map(|m| {
(
m.pattern(),
byte_to_code_point[m.start()],
byte_to_code_point[m.end()],
)
})
.collect())
let haystack = haystack.as_ref(py);
if self_.is_unicode {
if !haystack.is_instance_of::<PyUnicode>()? {
return Err(PyTypeError::new_err(format!(
"expected str, got {}",
haystack.get_type().name()?
)));
}
} else if !haystack.is_instance_of::<PyBytes>()?
&& !haystack.is_instance_of::<PyByteArray>()?
{
return Err(PyTypeError::new_err(format!(
"expected bytes or bytearray, got {}",
haystack.get_type().name()?
)));
};

let haystack_bytes = if self_.is_unicode {
haystack.extract::<&str>()?.as_bytes()
} else {
haystack.extract()?
};

let matches = self_.get_matches(py, haystack_bytes, overlapping);
if self_.is_unicode {
// Map UTF-8 byte index to Unicode code point index; the latter is what
// Python users expect.
let mut byte_to_code_point = vec![usize::MAX; haystack_bytes.len() + 1];
let mut max_codepoint = 0;
for (codepoint_off, (byte_off, _)) in
haystack.extract::<&str>()?.char_indices().enumerate()
{
byte_to_code_point[byte_off] = codepoint_off;
max_codepoint = codepoint_off;
}
// End index is exclusive (e.g. 0:3 is first 3 characters), so handle
// the case where pattern is at end of string.
if !haystack_bytes.is_empty() {
byte_to_code_point[haystack_bytes.len()] = max_codepoint + 1;
}

Ok(matches
.into_iter()
.map(|m| {
(
m.pattern(),
byte_to_code_point[m.start()],
byte_to_code_point[m.end()],
)
})
.collect())
} else {
Ok(matches
.into_iter()
.map(|m| (m.pattern(), m.start(), m.end()))
.collect())
}
}

/// Return matches as list of patterns.
#[args(overlapping = "false")]
fn find_matches_as_strings(
self_: PyRef<Self>,
haystack: &str,
haystack: Py<PyAny>,
overlapping: bool,
) -> PyResult<Vec<Py<PyUnicode>>> {
) -> PyResult<Vec<Py<PyAny>>> {
self_.check_overlapping(overlapping)?;

let py = self_.py();
let matches = self_.get_matches(py, haystack, overlapping);
let haystack = haystack.as_ref(py);
if self_.is_unicode {
if !haystack.is_instance_of::<PyUnicode>()? {
return Err(PyTypeError::new_err(format!(
"expected str, got {}",
haystack.get_type().name()?
)));
}
} else if !haystack.is_instance_of::<PyBytes>()?
&& !haystack.is_instance_of::<PyByteArray>()?
{
return Err(PyTypeError::new_err(format!(
"expected bytes or bytearray, got {}",
haystack.get_type().name()?
)));
}

let haystack_bytes = if self_.is_unicode {
haystack.extract::<&str>()?.as_bytes()
} else {
haystack.extract()?
};

let matches = self_.get_matches(py, haystack_bytes, overlapping);
Ok(matches
.into_iter()
.map(|m| self_.patterns[m.pattern()].clone_ref(py))
Expand Down

0 comments on commit 4ca8c26

Please sign in to comment.