diff --git a/Cargo.lock b/Cargo.lock index 7bf9527f47d12..ffc83d80da154 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3976,6 +3976,7 @@ dependencies = [ "arrayvec", "rustc_macros", "rustc_serialize", + "smallvec", ] [[package]] diff --git a/compiler/rustc_borrowck/src/region_infer/values.rs b/compiler/rustc_borrowck/src/region_infer/values.rs index 100ac578f92dd..4a70535c63bea 100644 --- a/compiler/rustc_borrowck/src/region_infer/values.rs +++ b/compiler/rustc_borrowck/src/region_infer/values.rs @@ -1,5 +1,7 @@ use rustc_data_structures::fx::FxIndexSet; -use rustc_index::bit_set::{HybridBitSet, SparseBitMatrix}; +use rustc_index::bit_set::SparseBitMatrix; +use rustc_index::interval::IntervalSet; +use rustc_index::interval::SparseIntervalMatrix; use rustc_index::vec::Idx; use rustc_index::vec::IndexVec; use rustc_middle::mir::{BasicBlock, Body, Location}; @@ -110,11 +112,11 @@ crate enum RegionElement { PlaceholderRegion(ty::PlaceholderRegion), } -/// When we initially compute liveness, we use a bit matrix storing -/// points for each region-vid. +/// When we initially compute liveness, we use an interval matrix storing +/// liveness ranges for each region-vid. crate struct LivenessValues { elements: Rc, - points: SparseBitMatrix, + points: SparseIntervalMatrix, } impl LivenessValues { @@ -122,7 +124,7 @@ impl LivenessValues { /// Each of the regions in num_region_variables will be initialized with an /// empty set of points and no causal information. crate fn new(elements: Rc) -> Self { - Self { points: SparseBitMatrix::new(elements.num_points), elements } + Self { points: SparseIntervalMatrix::new(elements.num_points), elements } } /// Iterate through each region that has a value in this set. @@ -140,7 +142,7 @@ impl LivenessValues { /// Adds all the elements in the given bit array into the given /// region. Returns whether any of them are newly added. - crate fn add_elements(&mut self, row: N, locations: &HybridBitSet) -> bool { + crate fn add_elements(&mut self, row: N, locations: &IntervalSet) -> bool { debug!("LivenessValues::add_elements(row={:?}, locations={:?})", row, locations); self.points.union_row(row, locations) } @@ -153,7 +155,7 @@ impl LivenessValues { /// Returns `true` if the region `r` contains the given element. crate fn contains(&self, row: N, location: Location) -> bool { let index = self.elements.point_from_location(location); - self.points.contains(row, index) + self.points.row(row).map_or(false, |r| r.contains(index)) } /// Returns an iterator of all the elements contained by the region `r` @@ -221,7 +223,7 @@ impl PlaceholderIndices { crate struct RegionValues { elements: Rc, placeholder_indices: Rc, - points: SparseBitMatrix, + points: SparseIntervalMatrix, free_regions: SparseBitMatrix, /// Placeholders represent bound regions -- so something like `'a` @@ -241,7 +243,7 @@ impl RegionValues { let num_placeholders = placeholder_indices.len(); Self { elements: elements.clone(), - points: SparseBitMatrix::new(elements.num_points), + points: SparseIntervalMatrix::new(elements.num_points), placeholder_indices: placeholder_indices.clone(), free_regions: SparseBitMatrix::new(num_universal_regions), placeholders: SparseBitMatrix::new(num_placeholders), diff --git a/compiler/rustc_borrowck/src/type_check/liveness/trace.rs b/compiler/rustc_borrowck/src/type_check/liveness/trace.rs index 0969b9a508f1d..094af20f52efc 100644 --- a/compiler/rustc_borrowck/src/type_check/liveness/trace.rs +++ b/compiler/rustc_borrowck/src/type_check/liveness/trace.rs @@ -1,5 +1,6 @@ use rustc_data_structures::fx::{FxHashMap, FxHashSet}; use rustc_index::bit_set::HybridBitSet; +use rustc_index::interval::IntervalSet; use rustc_infer::infer::canonical::QueryRegionConstraints; use rustc_middle::mir::{BasicBlock, Body, ConstraintCategory, Local, Location}; use rustc_middle::ty::{Ty, TypeFoldable}; @@ -105,12 +106,12 @@ struct LivenessResults<'me, 'typeck, 'flow, 'tcx> { /// Points where the current variable is "use live" -- meaning /// that there is a future "full use" that may use its value. - use_live_at: HybridBitSet, + use_live_at: IntervalSet, /// Points where the current variable is "drop live" -- meaning /// that there is no future "full use" that may use its value, but /// there is a future drop. - drop_live_at: HybridBitSet, + drop_live_at: IntervalSet, /// Locations where drops may occur. drop_locations: Vec, @@ -125,8 +126,8 @@ impl<'me, 'typeck, 'flow, 'tcx> LivenessResults<'me, 'typeck, 'flow, 'tcx> { LivenessResults { cx, defs: HybridBitSet::new_empty(num_points), - use_live_at: HybridBitSet::new_empty(num_points), - drop_live_at: HybridBitSet::new_empty(num_points), + use_live_at: IntervalSet::new(num_points), + drop_live_at: IntervalSet::new(num_points), drop_locations: vec![], stack: vec![], } @@ -165,7 +166,7 @@ impl<'me, 'typeck, 'flow, 'tcx> LivenessResults<'me, 'typeck, 'flow, 'tcx> { drop_used: Vec<(Local, Location)>, live_locals: FxHashSet, ) { - let locations = HybridBitSet::new_empty(self.cx.elements.num_points()); + let locations = IntervalSet::new(self.cx.elements.num_points()); for (local, location) in drop_used { if !live_locals.contains(&local) { @@ -456,7 +457,7 @@ impl<'tcx> LivenessContext<'_, '_, '_, 'tcx> { fn add_use_live_facts_for( &mut self, value: impl TypeFoldable<'tcx>, - live_at: &HybridBitSet, + live_at: &IntervalSet, ) { debug!("add_use_live_facts_for(value={:?})", value); @@ -473,7 +474,7 @@ impl<'tcx> LivenessContext<'_, '_, '_, 'tcx> { dropped_local: Local, dropped_ty: Ty<'tcx>, drop_locations: &[Location], - live_at: &HybridBitSet, + live_at: &IntervalSet, ) { debug!( "add_drop_live_constraint(\ @@ -521,7 +522,7 @@ impl<'tcx> LivenessContext<'_, '_, '_, 'tcx> { elements: &RegionValueElements, typeck: &mut TypeChecker<'_, 'tcx>, value: impl TypeFoldable<'tcx>, - live_at: &HybridBitSet, + live_at: &IntervalSet, ) { debug!("make_all_regions_live(value={:?})", value); debug!( diff --git a/compiler/rustc_index/Cargo.toml b/compiler/rustc_index/Cargo.toml index b984a1321e0aa..89419bfce6f5b 100644 --- a/compiler/rustc_index/Cargo.toml +++ b/compiler/rustc_index/Cargo.toml @@ -10,3 +10,4 @@ doctest = false arrayvec = { version = "0.7", default-features = false } rustc_serialize = { path = "../rustc_serialize" } rustc_macros = { path = "../rustc_macros" } +smallvec = "1" diff --git a/compiler/rustc_index/src/interval.rs b/compiler/rustc_index/src/interval.rs new file mode 100644 index 0000000000000..6da95053b116d --- /dev/null +++ b/compiler/rustc_index/src/interval.rs @@ -0,0 +1,269 @@ +use std::iter::Step; +use std::marker::PhantomData; +use std::ops::Bound; +use std::ops::RangeBounds; + +use crate::vec::Idx; +use crate::vec::IndexVec; +use smallvec::SmallVec; + +#[cfg(test)] +mod tests; + +/// Stores a set of intervals on the indices. +#[derive(Debug, Clone)] +pub struct IntervalSet { + // Start, end + map: SmallVec<[(u32, u32); 4]>, + domain: usize, + _data: PhantomData, +} + +#[inline] +fn inclusive_start(range: impl RangeBounds) -> u32 { + match range.start_bound() { + Bound::Included(start) => start.index() as u32, + Bound::Excluded(start) => start.index() as u32 + 1, + Bound::Unbounded => 0, + } +} + +#[inline] +fn inclusive_end(domain: usize, range: impl RangeBounds) -> Option { + let end = match range.end_bound() { + Bound::Included(end) => end.index() as u32, + Bound::Excluded(end) => end.index().checked_sub(1)? as u32, + Bound::Unbounded => domain.checked_sub(1)? as u32, + }; + Some(end) +} + +impl IntervalSet { + pub fn new(domain: usize) -> IntervalSet { + IntervalSet { map: SmallVec::new(), domain, _data: PhantomData } + } + + pub fn clear(&mut self) { + self.map.clear(); + } + + pub fn iter(&self) -> impl Iterator + '_ + where + I: Step, + { + self.iter_intervals().flatten() + } + + /// Iterates through intervals stored in the set, in order. + pub fn iter_intervals(&self) -> impl Iterator> + '_ + where + I: Step, + { + self.map.iter().map(|&(start, end)| I::new(start as usize)..I::new(end as usize + 1)) + } + + /// Returns true if we increased the number of elements present. + pub fn insert(&mut self, point: I) -> bool { + self.insert_range(point..=point) + } + + /// Returns true if we increased the number of elements present. + pub fn insert_range(&mut self, range: impl RangeBounds + Clone) -> bool { + let start = inclusive_start(range.clone()); + let Some(mut end) = inclusive_end(self.domain, range) else { + // empty range + return false; + }; + if start > end { + return false; + } + + loop { + // This condition looks a bit weird, but actually makes sense. + // + // if r.0 == end + 1, then we're actually adjacent, so we want to + // continue to the next range. We're looking here for the first + // range which starts *non-adjacently* to our end. + let next = self.map.partition_point(|r| r.0 <= end + 1); + if let Some(last) = next.checked_sub(1) { + let (prev_start, prev_end) = &mut self.map[last]; + if *prev_end + 1 >= start { + // If the start for the inserted range is adjacent to the + // end of the previous, we can extend the previous range. + if start < *prev_start { + // Our range starts before the one we found. We'll need + // to *remove* it, and then try again. + // + // FIXME: This is not so efficient; we may need to + // recurse a bunch of times here. Instead, it's probably + // better to do something like drain_filter(...) on the + // map to be able to delete or modify all the ranges in + // start..=end and then potentially re-insert a new + // range. + end = std::cmp::max(end, *prev_end); + self.map.remove(last); + } else { + // We overlap with the previous range, increase it to + // include us. + // + // Make sure we're actually going to *increase* it though -- + // it may be that end is just inside the previously existing + // set. + return if end > *prev_end { + *prev_end = end; + true + } else { + false + }; + } + } else { + // Otherwise, we don't overlap, so just insert + self.map.insert(last + 1, (start, end)); + return true; + } + } else { + if self.map.is_empty() { + // Quite common in practice, and expensive to call memcpy + // with length zero. + self.map.push((start, end)); + } else { + self.map.insert(next, (start, end)); + } + return true; + } + } + } + + pub fn contains(&self, needle: I) -> bool { + let needle = needle.index() as u32; + let last = match self.map.partition_point(|r| r.0 <= needle).checked_sub(1) { + Some(idx) => idx, + None => { + // All ranges in the map start after the new range's end + return false; + } + }; + let (_, prev_end) = &self.map[last]; + needle <= *prev_end + } + + pub fn superset(&self, other: &IntervalSet) -> bool + where + I: Step, + { + // FIXME: Performance here is probably not great. We will be doing a lot + // of pointless tree traversals. + other.iter().all(|elem| self.contains(elem)) + } + + pub fn is_empty(&self) -> bool { + self.map.is_empty() + } + + /// Returns the maximum (last) element present in the set from `range`. + pub fn last_set_in(&self, range: impl RangeBounds + Clone) -> Option { + let start = inclusive_start(range.clone()); + let Some(end) = inclusive_end(self.domain, range) else { + // empty range + return None; + }; + if start > end { + return None; + } + let last = match self.map.partition_point(|r| r.0 <= end).checked_sub(1) { + Some(idx) => idx, + None => { + // All ranges in the map start after the new range's end + return None; + } + }; + let (_, prev_end) = &self.map[last]; + if start <= *prev_end { Some(I::new(std::cmp::min(*prev_end, end) as usize)) } else { None } + } + + pub fn insert_all(&mut self) { + self.clear(); + self.map.push((0, self.domain.try_into().unwrap())); + } + + pub fn union(&mut self, other: &IntervalSet) -> bool + where + I: Step, + { + assert_eq!(self.domain, other.domain); + let mut did_insert = false; + for range in other.iter_intervals() { + did_insert |= self.insert_range(range); + } + did_insert + } +} + +/// This data structure optimizes for cases where the stored bits in each row +/// are expected to be highly contiguous (long ranges of 1s or 0s), in contrast +/// to BitMatrix and SparseBitMatrix which are optimized for +/// "random"/non-contiguous bits and cheap(er) point queries at the expense of +/// memory usage. +#[derive(Clone)] +pub struct SparseIntervalMatrix +where + R: Idx, + C: Idx, +{ + rows: IndexVec>, + column_size: usize, +} + +impl SparseIntervalMatrix { + pub fn new(column_size: usize) -> SparseIntervalMatrix { + SparseIntervalMatrix { rows: IndexVec::new(), column_size } + } + + pub fn rows(&self) -> impl Iterator { + self.rows.indices() + } + + pub fn row(&self, row: R) -> Option<&IntervalSet> { + self.rows.get(row) + } + + fn ensure_row(&mut self, row: R) -> &mut IntervalSet { + self.rows.ensure_contains_elem(row, || IntervalSet::new(self.column_size)); + &mut self.rows[row] + } + + pub fn union_row(&mut self, row: R, from: &IntervalSet) -> bool + where + C: Step, + { + self.ensure_row(row).union(from) + } + + pub fn union_rows(&mut self, read: R, write: R) -> bool + where + C: Step, + { + if read == write || self.rows.get(read).is_none() { + return false; + } + self.ensure_row(write); + let (read_row, write_row) = self.rows.pick2_mut(read, write); + write_row.union(read_row) + } + + pub fn insert_all_into_row(&mut self, row: R) { + self.ensure_row(row).insert_all(); + } + + pub fn insert_range(&mut self, row: R, range: impl RangeBounds + Clone) { + self.ensure_row(row).insert_range(range); + } + + pub fn insert(&mut self, row: R, point: C) -> bool { + self.ensure_row(row).insert(point) + } + + pub fn contains(&self, row: R, point: C) -> bool { + self.row(row).map_or(false, |r| r.contains(point)) + } +} diff --git a/compiler/rustc_index/src/interval/tests.rs b/compiler/rustc_index/src/interval/tests.rs new file mode 100644 index 0000000000000..d90b449f32609 --- /dev/null +++ b/compiler/rustc_index/src/interval/tests.rs @@ -0,0 +1,199 @@ +use super::*; + +#[test] +fn insert_collapses() { + let mut set = IntervalSet::::new(3000); + set.insert_range(9831..=9837); + set.insert_range(43..=9830); + assert_eq!(set.iter_intervals().collect::>(), [43..9838]); +} + +#[test] +fn contains() { + let mut set = IntervalSet::new(300); + set.insert(0u32); + assert!(set.contains(0)); + set.insert_range(0..10); + assert!(set.contains(9)); + assert!(!set.contains(10)); + set.insert_range(10..11); + assert!(set.contains(10)); +} + +#[test] +fn insert() { + for i in 0..30usize { + let mut set = IntervalSet::new(300); + for j in i..30usize { + set.insert(j); + for k in i..j { + assert!(set.contains(k)); + } + } + } + + let mut set = IntervalSet::new(300); + set.insert_range(0..1u32); + assert!(set.contains(0), "{:?}", set.map); + assert!(!set.contains(1)); + set.insert_range(1..1); + assert!(set.contains(0)); + assert!(!set.contains(1)); + + let mut set = IntervalSet::new(300); + set.insert_range(4..5u32); + set.insert_range(5..10); + assert_eq!(set.iter().collect::>(), [4, 5, 6, 7, 8, 9]); + set.insert_range(3..7); + assert_eq!(set.iter().collect::>(), [3, 4, 5, 6, 7, 8, 9]); + + let mut set = IntervalSet::new(300); + set.insert_range(0..10u32); + set.insert_range(3..5); + assert_eq!(set.iter().collect::>(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); + + let mut set = IntervalSet::new(300); + set.insert_range(0..10u32); + set.insert_range(0..3); + assert_eq!(set.iter().collect::>(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); + + let mut set = IntervalSet::new(300); + set.insert_range(0..10u32); + set.insert_range(0..10); + assert_eq!(set.iter().collect::>(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); + + let mut set = IntervalSet::new(300); + set.insert_range(0..10u32); + set.insert_range(5..10); + assert_eq!(set.iter().collect::>(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); + + let mut set = IntervalSet::new(300); + set.insert_range(0..10u32); + set.insert_range(5..13); + assert_eq!(set.iter().collect::>(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]); +} + +#[test] +fn insert_range() { + #[track_caller] + fn check(range: R) + where + R: RangeBounds + Clone + IntoIterator + std::fmt::Debug, + { + let mut set = IntervalSet::new(300); + set.insert_range(range.clone()); + for i in set.iter() { + assert!(range.contains(&i)); + } + for i in range.clone() { + assert!(set.contains(i), "A: {} in {:?}, inserted {:?}", i, set, range); + } + set.insert_range(range.clone()); + for i in set.iter() { + assert!(range.contains(&i), "{} in {:?}", i, set); + } + for i in range.clone() { + assert!(set.contains(i), "B: {} in {:?}, inserted {:?}", i, set, range); + } + } + check(10..10); + check(10..100); + check(10..30); + check(0..5); + check(0..250); + check(200..250); + + check(10..=10); + check(10..=100); + check(10..=30); + check(0..=5); + check(0..=250); + check(200..=250); + + for i in 0..30 { + for j in i..30 { + check(i..j); + check(i..=j); + } + } +} + +#[test] +fn insert_range_dual() { + let mut set = IntervalSet::::new(300); + set.insert_range(0..3); + assert_eq!(set.iter().collect::>(), [0, 1, 2]); + set.insert_range(5..7); + assert_eq!(set.iter().collect::>(), [0, 1, 2, 5, 6]); + set.insert_range(3..4); + assert_eq!(set.iter().collect::>(), [0, 1, 2, 3, 5, 6]); + set.insert_range(3..5); + assert_eq!(set.iter().collect::>(), [0, 1, 2, 3, 4, 5, 6]); +} + +#[test] +fn last_set_before_adjacent() { + let mut set = IntervalSet::::new(300); + set.insert_range(0..3); + set.insert_range(3..5); + assert_eq!(set.last_set_in(0..3), Some(2)); + assert_eq!(set.last_set_in(0..5), Some(4)); + assert_eq!(set.last_set_in(3..5), Some(4)); + set.insert_range(2..5); + assert_eq!(set.last_set_in(0..3), Some(2)); + assert_eq!(set.last_set_in(0..5), Some(4)); + assert_eq!(set.last_set_in(3..5), Some(4)); +} + +#[test] +fn last_set_in() { + fn easy(set: &IntervalSet, needle: impl RangeBounds) -> Option { + let mut last_leq = None; + for e in set.iter() { + if needle.contains(&e) { + last_leq = Some(e); + } + } + last_leq + } + + #[track_caller] + fn cmp(set: &IntervalSet, needle: impl RangeBounds + Clone + std::fmt::Debug) { + assert_eq!( + set.last_set_in(needle.clone()), + easy(set, needle.clone()), + "{:?} in {:?}", + needle, + set + ); + } + let mut set = IntervalSet::new(300); + cmp(&set, 50..=50); + set.insert(64); + cmp(&set, 64..=64); + set.insert(64 - 1); + cmp(&set, 0..=64 - 1); + cmp(&set, 0..=5); + cmp(&set, 10..100); + set.insert(100); + cmp(&set, 100..110); + cmp(&set, 99..100); + cmp(&set, 99..=100); + + for i in 0..=30 { + for j in i..=30 { + for k in 0..30 { + let mut set = IntervalSet::new(100); + cmp(&set, ..j); + cmp(&set, i..); + cmp(&set, i..j); + cmp(&set, i..=j); + set.insert(k); + cmp(&set, ..j); + cmp(&set, i..); + cmp(&set, i..j); + cmp(&set, i..=j); + } + } + } +} diff --git a/compiler/rustc_index/src/lib.rs b/compiler/rustc_index/src/lib.rs index a9efd6bb8bc8d..359b1859c6889 100644 --- a/compiler/rustc_index/src/lib.rs +++ b/compiler/rustc_index/src/lib.rs @@ -7,6 +7,7 @@ #![feature(let_else)] pub mod bit_set; +pub mod interval; pub mod vec; // FIXME(#56935): Work around ICEs during cross-compilation.