Skip to content

Commit

Permalink
syntax: optimize most of the IntervalSet routines
Browse files Browse the repository at this point in the history
This reduces or eliminates allocation when combining Unicode classes and
should make some things faster. It's unlikely for these optimizations to
matter much in practice, but they are likely to help in niche or
pathological cases where there are a lot of ops in a class.

Closes #1051
  • Loading branch information
Licheam authored and BurntSushi committed Oct 8, 2023
1 parent 122adbf commit c8fe16b
Showing 1 changed file with 185 additions and 97 deletions.
282 changes: 185 additions & 97 deletions regex-syntax/src/hir/interval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use crate::unicode;
//
// Some of the implementation complexity here is a result of me wanting to
// preserve the sequential representation without using additional memory.
// In many cases, we do use linear extra memory, but it is at most 2x and it
// In some cases, we do use linear extra memory, but it is at most 2x and it
// is amortized. If we relaxed the memory requirements, this implementation
// could become much simpler. The extra memory is honestly probably OK, but
// character classes (especially of the Unicode variety) can become quite
Expand Down Expand Up @@ -81,14 +81,45 @@ impl<I: Interval> IntervalSet<I> {

/// Add a new interval to this set.
pub fn push(&mut self, interval: I) {
// TODO: This could be faster. e.g., Push the interval such that
// it preserves canonicalization.
self.ranges.push(interval);
self.canonicalize();
// We don't know whether the new interval added here is considered
// case folded, so we conservatively assume that the entire set is
// no longer case folded if it was previously.
self.folded = false;

if self.ranges.is_empty() {
self.ranges.push(interval);
return;
}

// Find the first range that is not greater than the new interval.
// This is the first range that could possibly be unioned with the
// new interval.
let mut drain_end = self.ranges.len();
while drain_end > 0
&& self.ranges[drain_end - 1].lower() > interval.upper()
&& !self.ranges[drain_end - 1].is_contiguous(&interval)
{
drain_end -= 1;
}

// Try to union the new interval with old intervals backwards.
if drain_end > 0 && self.ranges[drain_end - 1].is_contiguous(&interval)
{
self.ranges[drain_end - 1] =
self.ranges[drain_end - 1].union(&interval).unwrap();
for i in (0..drain_end - 1).rev() {
if let Some(union) =
self.ranges[drain_end - 1].union(&self.ranges[i])
{
self.ranges[drain_end - 1] = union;
} else {
self.ranges.drain(i + 1..drain_end - 1);
break;
}
}
} else {
self.ranges.insert(drain_end, interval);
}
}

/// Return an iterator over all intervals in this set.
Expand Down Expand Up @@ -192,34 +223,13 @@ impl<I: Interval> IntervalSet<I> {
// Folks seem to suggest interval or segment trees, but I'd like to
// avoid the overhead (both runtime and conceptual) of that.
//
// The following is basically my Shitty First Draft. Therefore, in
// order to grok it, you probably need to read each line carefully.
// Simplifications are most welcome!
//
// Remember, we can assume the canonical format invariant here, which
// says that all ranges are sorted, not overlapping and not adjacent in
// each class.
let drain_end = self.ranges.len();
let (mut a, mut b) = (0, 0);
'LOOP: while a < drain_end && b < other.ranges.len() {
// Basically, the easy cases are when neither range overlaps with
// each other. If the `b` range is less than our current `a`
// range, then we can skip it and move on.
if other.ranges[b].upper() < self.ranges[a].lower() {
b += 1;
continue;
}
// ... similarly for the `a` range. If it's less than the smallest
// `b` range, then we can add it as-is.
if self.ranges[a].upper() < other.ranges[b].lower() {
let range = self.ranges[a];
self.ranges.push(range);
a += 1;
continue;
}
// Otherwise, we have overlapping ranges.
assert!(!self.ranges[a].is_intersection_empty(&other.ranges[b]));

let mut b = 0;
for a in 0..drain_end {
// This part is tricky and was non-obvious to me without looking
// at explicit examples (see the tests). The trickiness stems from
// two things: 1) subtracting a range from another range could
Expand All @@ -231,47 +241,34 @@ impl<I: Interval> IntervalSet<I> {
// For example, if our `a` range is `a-t` and our next three `b`
// ranges are `a-c`, `g-i`, `r-t` and `x-z`, then we need to apply
// subtraction three times before moving on to the next `a` range.
let mut range = self.ranges[a];
self.ranges.push(self.ranges[a]);
// Only when `b` is not above `a`, `b` might apply to current
// `a` range.
while b < other.ranges.len()
&& !range.is_intersection_empty(&other.ranges[b])
&& other.ranges[b].lower() <= self.ranges[a].upper()
{
let old_range = range;
range = match range.difference(&other.ranges[b]) {
(None, None) => {
// We lost the entire range, so move on to the next
// without adding this one.
a += 1;
continue 'LOOP;
match self.ranges.pop().unwrap().difference(&other.ranges[b]) {
(Some(range1), None) | (None, Some(range1)) => {
self.ranges.push(range1);
}
(Some(range1), None) | (None, Some(range1)) => range1,
(Some(range1), Some(range2)) => {
self.ranges.push(range1);
range2
self.ranges.push(range2);
}
};
// It's possible that the `b` range has more to contribute
// here. In particular, if it is greater than the original
// range, then it might impact the next `a` range *and* it
// has impacted the current `a` range as much as possible,
// so we can quit. We don't bump `b` so that the next `a`
// range can apply it.
if other.ranges[b].upper() > old_range.upper() {
break;
(None, None) => {}
}
// Otherwise, the next `b` range might apply to the current
// The next `b` range might apply to the current
// `a` range.
b += 1;
}
self.ranges.push(range);
a += 1;
}
while a < drain_end {
let range = self.ranges[a];
self.ranges.push(range);
a += 1;
// It's possible that the last `b` range has more to
// contribute to the next `a`. We don't bump the last
// `b` so that the next `a` range can apply it.
b = b.saturating_sub(1);
}

self.ranges.drain(..drain_end);
self.folded = self.folded && other.folded;
self.folded = self.ranges.is_empty() || (self.folded && other.folded);
}

/// Compute the symmetric difference of the two sets, in place.
Expand All @@ -282,11 +279,83 @@ impl<I: Interval> IntervalSet<I> {
/// set. That is, the set will contain all elements in either set,
/// but will not contain any elements that are in both sets.
pub fn symmetric_difference(&mut self, other: &IntervalSet<I>) {
// TODO(burntsushi): Fix this so that it amortizes allocation.
let mut intersection = self.clone();
intersection.intersect(other);
self.union(other);
self.difference(&intersection);
if self.ranges.is_empty() {
self.ranges.extend(&other.ranges);
self.folded = other.folded;
return;
}
if other.ranges.is_empty() {
return;
}

// There should be a way to do this in-place with constant memory,
// but I couldn't figure out a simple way to do it. So just append
// the symmetric difference to the end of this range, and then drain
// it before we're done.
let drain_end = self.ranges.len();
let mut b = 0;
let mut b_range = Some(other.ranges[b]);
for a in 0..drain_end {
self.ranges.push(self.ranges[a]);
while b_range
.map_or(false, |r| r.lower() <= self.ranges[a].upper())
{
let (range1, range2) = match self
.ranges
.pop()
.unwrap()
.symmetric_difference(&b_range.as_ref().unwrap())
{
(Some(range1), None) | (None, Some(range1)) => {
(Some(range1), None)
}
(Some(range1), Some(range2)) => {
(Some(range1), Some(range2))
}
(None, None) => (None, None),
};
if let Some(range) = range1 {
if self.ranges.len() > drain_end
&& self.ranges.last().unwrap().is_contiguous(&range)
{
self.ranges
.last_mut()
.map(|last| *last = last.union(&range).unwrap());
} else {
self.ranges.push(range);
}
}
if let Some(range) = range2 {
self.ranges.push(range);
}

b_range = if self.ranges.len() > drain_end
&& self.ranges.last().unwrap().upper()
> self.ranges[a].upper()
{
Some(*self.ranges.last().unwrap())
} else {
b += 1;
other.ranges.get(b).cloned()
};
}
}
while let Some(range) = b_range {
if self.ranges.len() > drain_end
&& self.ranges.last().unwrap().is_contiguous(&range)
{
self.ranges
.last_mut()
.map(|last| *last = last.union(&range).unwrap());
} else {
self.ranges.push(range);
}
b += 1;
b_range = other.ranges.get(b).cloned();
}

self.ranges.drain(..drain_end);
self.folded = self.ranges.is_empty() || (self.folded && other.folded);
}

/// Negate this interval set.
Expand All @@ -302,28 +371,44 @@ impl<I: Interval> IntervalSet<I> {
return;
}

// There should be a way to do this in-place with constant memory,
// but I couldn't figure out a simple way to do it. So just append
// the negation to the end of this range, and then drain it before
// we're done.
let drain_end = self.ranges.len();

// We do checked arithmetic below because of the canonical ordering
// invariant.
if self.ranges[0].lower() > I::Bound::min_value() {
let upper = self.ranges[0].lower().decrement();
self.ranges.push(I::create(I::Bound::min_value(), upper));
}
for i in 1..drain_end {
let lower = self.ranges[i - 1].upper().increment();
let upper = self.ranges[i].lower().decrement();
self.ranges.push(I::create(lower, upper));
}
if self.ranges[drain_end - 1].upper() < I::Bound::max_value() {
let lower = self.ranges[drain_end - 1].upper().increment();
self.ranges.push(I::create(lower, I::Bound::max_value()));
let mut pre_upper = self.ranges[0].upper();
self.ranges[0] = I::create(
I::Bound::min_value(),
self.ranges[0].lower().decrement(),
);
for i in 1..self.ranges.len() {
let lower = pre_upper.increment();
pre_upper = self.ranges[i].upper();
self.ranges[i] =
I::create(lower, self.ranges[i].lower().decrement());
}
if pre_upper < I::Bound::max_value() {
self.ranges.push(I::create(
pre_upper.increment(),
I::Bound::max_value(),
));
}
} else {
for i in 1..self.ranges.len() {
self.ranges[i - 1] = I::create(
self.ranges[i - 1].upper().increment(),
self.ranges[i].lower().decrement(),
);
}
if self.ranges.last().unwrap().upper() < I::Bound::max_value() {
self.ranges.last_mut().map(|range| {
*range = I::create(
range.upper().increment(),
I::Bound::max_value(),
)
});
} else {
self.ranges.pop();
}
}
self.ranges.drain(..drain_end);
// We don't need to update whether this set is folded or not, because
// it is conservatively preserved through negation. Namely, if a set
// is not folded, then it is possible that its negation is folded, for
Expand All @@ -337,6 +422,7 @@ impl<I: Interval> IntervalSet<I> {
// of case folded characters. Negating it in turn means that all
// equivalence classes in the set are negated, and any equivalence
// class that was previously not in the set is now entirely in the set.
self.folded = self.ranges.is_empty() || self.folded;
}

/// Converts this set into a canonical ordering.
Expand All @@ -347,24 +433,20 @@ impl<I: Interval> IntervalSet<I> {
self.ranges.sort();
assert!(!self.ranges.is_empty());

// Is there a way to do this in-place with constant memory? I couldn't
// figure out a way to do it. So just append the canonicalization to
// the end of this range, and then drain it before we're done.
let drain_end = self.ranges.len();
for oldi in 0..drain_end {
// If we've added at least one new range, then check if we can
// merge this range in the previously added range.
if self.ranges.len() > drain_end {
let (last, rest) = self.ranges.split_last_mut().unwrap();
if let Some(union) = last.union(&rest[oldi]) {
*last = union;
continue;
}
// We maintain the canonicalization results in-place at `0..newi`.
// `newi` will keep track of the end of the canonicalized ranges.
let mut newi = 0;
for oldi in 1..self.ranges.len() {
// The last new range gets merged with currnet old range when
// unionable. If not, we update `newi` and store it as a new range.
if let Some(union) = self.ranges[newi].union(&self.ranges[oldi]) {
self.ranges[newi] = union;
} else {
newi += 1;
self.ranges[newi] = self.ranges[oldi];
}
let range = self.ranges[oldi];
self.ranges.push(range);
}
self.ranges.drain(..drain_end);
self.ranges.truncate(newi + 1);
}

/// Returns true if and only if this class is in a canonical ordering.
Expand Down Expand Up @@ -486,7 +568,13 @@ pub trait Interval:
other: &Self,
) -> (Option<Self>, Option<Self>) {
let union = match self.union(other) {
None => return (Some(self.clone()), Some(other.clone())),
None => {
return if self.upper() < other.lower() {
(Some(self.clone()), Some(other.clone()))
} else {
(Some(other.clone()), Some(self.clone()))
}
}
Some(union) => union,
};
let intersection = match self.intersect(other) {
Expand Down

0 comments on commit c8fe16b

Please sign in to comment.