diff --git a/src/bitmap/container.rs b/src/bitmap/container.rs index 89b97156..bf4287c5 100644 --- a/src/bitmap/container.rs +++ b/src/bitmap/container.rs @@ -1,4 +1,5 @@ -use std::{fmt, ops::Range}; +use std::fmt; +use std::ops::RangeInclusive; use super::store::{self, Store}; use super::util; @@ -38,13 +39,7 @@ impl Container { } } - pub fn insert_range(&mut self, range: Range) -> u64 { - // If the range is larger than the array limit, skip populating the - // array to then have to convert it to a bitmap anyway. - if matches!(self.store, Store::Array(_)) && range.end - range.start > ARRAY_LIMIT as u16 { - self.store = self.store.to_bitmap() - } - + pub fn insert_range(&mut self, range: RangeInclusive) -> u64 { let inserted = self.store.insert_range(range); self.len += inserted; self.ensure_correct_store(); @@ -68,12 +63,8 @@ impl Container { } } - pub fn remove_range(&mut self, start: u32, end: u32) -> u64 { - debug_assert!(start <= end); - if start == end { - return 0; - } - let result = self.store.remove_range(start, end); + pub fn remove_range(&mut self, range: RangeInclusive) -> u64 { + let result = self.store.remove_range(range); self.len -= result; self.ensure_correct_store(); result diff --git a/src/bitmap/inherent.rs b/src/bitmap/inherent.rs index 2c1d26f2..f88732ae 100644 --- a/src/bitmap/inherent.rs +++ b/src/bitmap/inherent.rs @@ -1,4 +1,4 @@ -use std::ops::{Bound, RangeBounds}; +use std::ops::RangeBounds; use crate::RoaringBitmap; @@ -44,7 +44,21 @@ impl RoaringBitmap { container.insert(index) } - /// Inserts a range of values from the set. + /// Search for the specific container by the given key. + /// Create a new container if not exist. + /// + /// Return the index of the target container. + fn find_container_by_key(&mut self, key: u16) -> usize { + match self.containers.binary_search_by_key(&key, |c| c.key) { + Ok(loc) => loc, + Err(loc) => { + self.containers.insert(loc, Container::new(key)); + loc + } + } + } + + /// Inserts a range of values. /// Returns the number of inserted values. /// /// # Examples @@ -58,118 +72,54 @@ impl RoaringBitmap { /// assert!(rb.contains(3)); /// assert!(!rb.contains(4)); /// ``` - pub fn insert_range>(&mut self, range: R) -> u64 { - use Bound::{Included, Excluded, Unbounded}; - use util::split; - - let range = match (range.start_bound(), range.end_bound()) { - (Included(start), Included(end)) => todo!(), - (Included(start), Excluded(end)) => todo!(), - (Included(start), Unbounded) => todo!(), - (Excluded(start), Included(end)) => todo!(), - (Excluded(start), Excluded(end)) => todo!(), - (Excluded(start), Unbounded) => todo!(), - (Unbounded, Included(end)) => todo!(), - (Unbounded, Excluded(end)) => todo!(), - (Unbounded, Unbounded) => split(0)..=split(u32::max_value()), - }; + pub fn insert_range(&mut self, range: R) -> u64 + where + R: RangeBounds, + { + let (start, end); + if let Some(range) = util::convert_range_to_inclusive(range) { + start = *range.start(); + end = *range.end(); + } else { + return 0; + } + + let (start_container_key, start_index) = util::split(start); + let (end_container_key, end_index) = util::split(end); + + // Find the container index for start_container_key + let first_index = self.find_container_by_key(start_container_key); + + // If the end range value is in the same container, just call into + // the one container. + if start_container_key == end_container_key { + return self.containers[first_index].insert_range(start_index..=end_index); + } + + // For the first container, insert start_index..=u16::MAX, with + // subsequent containers inserting 0..MAX. + // + // The last container (end_container_key) is handled explicitly outside + // the loop. + let mut low = start_index; + let mut inserted = 0; + + for i in start_container_key..end_container_key { + let index = self.find_container_by_key(i); + + // Insert the range subset for this container + inserted += self.containers[index].insert_range(low..=u16::MAX); + + // After the first container, always fill the containers. + low = 0; + } + + // Handle the last container + let last_index = self.find_container_by_key(end_container_key); - // let start = match range.start_bound() { - // Bound::Included(value) => util::split(*value), - // Bound::Excluded(value) => util::split(*value), - // Bound::Unbounded => util::split(0), - // }; - - // let end = match range.end_bound() { - // Bound::Included(value) => util::split(*value), - // Bound::Excluded(value) => util::split(*value), - // Bound::Unbounded => util::split(u32::max_value()), - // }; - - // dbg!(start, end); - - // ... - - todo!() - - // let (start_container_key, start_index) = match range.start_bound() { - // Bound::Included(value) => util::split(*value), - // Bound::Excluded(value) if *value == u32::max_value() => return 0, - // Bound::Excluded(value) => util::split(value + 1), - // Bound::Unbounded => util::split(0), - // }; - - // let (end_container_key, end_index) = match range.end_bound() { - // Bound::Included(value) => util::split(*value), - // Bound::Excluded(value) if *value == 0 => return 0, - // Bound::Excluded(value) => util::split(value - 1), - // Bound::Unbounded => util::split(u32::max_value()), - // }; - - // dbg!(range.start_bound(), range.end_bound()); - // dbg!(start_container_key, start_index, end_container_key, end_index); - - // // Find the container index for start_container_key - // let start_i = match self - // .containers - // .binary_search_by_key(&start_container_key, |c| c.key) - // { - // Ok(loc) => loc, - // Err(loc) => { - // self.containers - // .insert(loc, Container::new(start_container_key)); - // loc - // } - // }; - - // // If the end range value is in the same container, just call into - // // the one container. - // if start_container_key == end_container_key { - // dbg!(start_container_key, end_container_key); - // return self.containers[start_i].insert_range(start_index..end_index); - // } - - // // For the first container, insert start_index..u16::MAX, with - // // subsequent containers inserting 0..MAX. - // // - // // The last container (end_container_key) is handled explicitly outside - // // the loop. - // let mut low = start_index; - // let mut inserted = 0; - - // // Walk through the containers until the container for end_container_key - // let end_i = usize::from(end_container_key - start_container_key); - // for i in start_i..end_i { - // // Fetch (or upsert) the container for i - // let c = match self.containers.get_mut(i) { - // Some(c) => c, - // None => { - // // For each i, the container key is start_container + i in - // // the upper u8 of the u16. - // let key = start_container_key + ((1 << 8) * i) as u16; - // self.containers.insert(i, Container::new(key)); - // &mut self.containers[i] - // } - // }; - - // // Insert the range subset for this container - // inserted += c.insert_range(low..u16::MAX); - - // // After the first container, always fill the containers. - // low = 0; - // } - - // // Handle the last container - // let c = match self.containers.get_mut(end_i) { - // Some(c) => c, - // None => { - // self.containers.insert(end_i, Container::new(start_container_key)); - // &mut self.containers[end_i] - // } - // }; - // c.insert_range(0..end_index); - - // inserted + inserted += self.containers[last_index].insert_range(0..=end_index); + + inserted } /// Adds a value to the set. @@ -240,7 +190,7 @@ impl RoaringBitmap { } } - /// Removes a range of values from the set. + /// Removes a range of values. /// Returns the number of removed values. /// /// # Examples @@ -253,62 +203,42 @@ impl RoaringBitmap { /// rb.insert(3); /// assert_eq!(rb.remove_range(2..4), 2); /// ``` - pub fn remove_range>(&mut self, range: R) -> u64 { - let start = match range.start_bound() { - Bound::Included(value) => *value, - Bound::Excluded(value) => match value.checked_add(1) { - Some(value) => value, - None => return 0, - }, - Bound::Unbounded => 0, + pub fn remove_range(&mut self, range: R) -> u64 + where + R: RangeBounds, + { + let (start, end) = match util::convert_range_to_inclusive(range) { + Some(range) => (*range.start(), *range.end()), + None => return 0, }; - let end = match range.end_bound() { - Bound::Included(value) => match value.checked_add(1) { - Some(value) => value, - None => return 0, - }, - Bound::Excluded(value) => *value, - Bound::Unbounded => u32::max_value(), - }; + let (start_container_key, start_index) = util::split(start); + let (end_container_key, end_index) = util::split(end); - if end.saturating_sub(start) == 0 { - return 0; - } - // inclusive bounds for start and end - let (start_hi, start_lo) = util::split(start); - let (end_hi, end_lo) = util::split(end - 1); let mut index = 0; - let mut result = 0; + let mut removed = 0; while index < self.containers.len() { let key = self.containers[index].key; - if key >= start_hi && key <= end_hi { - let a = if key == start_hi { - u32::from(start_lo) + if key >= start_container_key && key <= end_container_key { + let a = if key == start_container_key { + start_index } else { 0 }; - let b = if key == end_hi { - u32::from(end_lo) + 1 // make it exclusive + let b = if key == end_container_key { + end_index } else { - u32::from(u16::max_value()) + 1 + u16::max_value() }; - // remove container? - if a == 0 && b == u32::from(u16::max_value()) + 1 { - result += self.containers[index].len; + removed += self.containers[index].remove_range(a..=b); + if self.containers[index].len == 0 { self.containers.remove(index); continue; - } else { - result += self.containers[index].remove_range(a, b); - if self.containers[index].len == 0 { - self.containers.remove(index); - continue; - } } } index += 1; } - result + removed } /// Returns `true` if this set contains the specified integer. @@ -447,9 +377,9 @@ mod tests { let mut b = RoaringBitmap::new(); let inserted = b.insert_range(r.clone()); if r.end > r.start { - assert_eq!(inserted as u32, r.end - r.start); + assert_eq!(inserted, r.end as u64 - r.start as u64); } else { - assert_eq!(inserted as u32, 0); + assert_eq!(inserted, 0); } // Assert all values in the range are present @@ -461,16 +391,16 @@ mod tests { for i in checks { let bitmap_has = b.contains(i); let range_has = r.contains(&i); - assert!( - bitmap_has == range_has, + assert_eq!( + bitmap_has, range_has, "value {} in bitmap={} and range={}", - i, bitmap_has, range_has, + i, bitmap_has, range_has ); } } #[test] - fn test_insert_range_same_container() { + fn test_insert_remove_range_same_container() { let mut b = RoaringBitmap::new(); let inserted = b.insert_range(1..5); assert_eq!(inserted, 4); @@ -478,48 +408,96 @@ mod tests { for i in 1..5 { assert!(b.contains(i)); } + + let removed = b.remove_range(2..10); + assert_eq!(removed, 3); + assert!(b.contains(1)); + for i in 2..5 { + assert!(!b.contains(i)); + } } #[test] - fn test_insert_range_pre_populated() { + fn test_insert_remove_range_pre_populated() { let mut b = RoaringBitmap::new(); let inserted = b.insert_range(1..20_000); assert_eq!(inserted, 19_999); + let removed = b.remove_range(10_000..21_000); + assert_eq!(removed, 10_000); + let inserted = b.insert_range(1..20_000); - assert_eq!(inserted, 0); + assert_eq!(inserted, 10_000); } #[test] fn test_insert_max_u32() { let mut b = RoaringBitmap::new(); let inserted = b.insert(u32::MAX); - // We are allowed to add u32::MAX + // We are allowed to add u32::MAX assert!(inserted); } #[test] - fn test_insert_range_zero_inclusive() { + fn test_insert_remove_across_container() { let mut b = RoaringBitmap::new(); - let inserted = b.insert_range(0..=0); - // `insert_range(value..=value)` appears equivalent to `insert(value)` - assert_eq!(inserted, 1); - assert!(b.contains(0), "does not contain {}", 0); + let inserted = b.insert_range(u16::MAX as u32..=u16::MAX as u32 + 1); + assert_eq!(inserted, 2); + + assert_eq!(b.containers.len(), 2); + + let removed = b.remove_range(u16::MAX as u32 + 1..=u16::MAX as u32 + 1); + assert_eq!(removed, 1); + + assert_eq!(b.containers.len(), 1); } #[test] - fn test_insert_range_max_u32_inclusive() { + fn test_insert_remove_single_element() { let mut b = RoaringBitmap::new(); - let inserted = b.insert_range(u32::MAX..=u32::MAX); - // But not equivalent for u32::MAX - assert_eq!(inserted, 1); // Fails - left: 0, right: 1 - assert!(b.contains(u32::MAX), "does not contain {}", u32::MAX); + let inserted = b.insert_range(u16::MAX as u32 + 1..=u16::MAX as u32 + 1); + assert_eq!(inserted, 1); + + assert_eq!(b.containers[0].len, 1); + assert_eq!(b.containers.len(), 1); + + let removed = b.remove_range(u16::MAX as u32 + 1..=u16::MAX as u32 + 1); + assert_eq!(removed, 1); + + assert_eq!(b.containers.len(), 0); } #[test] - fn test_insert_all_u32() { - let mut b = RoaringBitmap::new(); - let inserted = b.insert_range(0..u32::MAX); // Largest possible range seemingly allowed - assert_eq!(inserted, u32::MAX as u64); // Still not bigger than u32::MAX + fn test_insert_remove_range_multi_container() { + let mut bitmap = RoaringBitmap::new(); + assert_eq!( + bitmap.insert_range(0..((1_u32 << 16) + 1)), + (1_u64 << 16) + 1 + ); + assert_eq!(bitmap.containers.len(), 2); + assert_eq!(bitmap.containers[0].key, 0); + assert_eq!(bitmap.containers[1].key, 1); + assert_eq!(bitmap.insert_range(0..((1_u32 << 16) + 1)), 0); + + assert!(bitmap.insert((1_u32 << 16) * 4)); + assert_eq!(bitmap.containers.len(), 3); + assert_eq!(bitmap.containers[2].key, 4); + + assert_eq!( + bitmap.remove_range(((1_u32 << 16) * 3)..=((1_u32 << 16) * 4)), + 1 + ); + assert_eq!(bitmap.containers.len(), 2); + } + + #[test] + fn insert_range_single() { + let mut bitmap = RoaringBitmap::new(); + assert_eq!( + bitmap.insert_range((1_u32 << 16)..(2_u32 << 16)), + 1_u64 << 16 + ); + assert_eq!(bitmap.containers.len(), 1); + assert_eq!(bitmap.containers[0].key, 1); } } diff --git a/src/bitmap/store.rs b/src/bitmap/store.rs index dc46ae14..fd783104 100644 --- a/src/bitmap/store.rs +++ b/src/bitmap/store.rs @@ -1,11 +1,12 @@ +use std::{slice, vec}; +use std::borrow::Borrow; use std::cmp::Ordering::{Equal, Greater, Less}; -use std::slice; -use std::vec; -use std::{borrow::Borrow, ops::Range}; +use std::ops::RangeInclusive; + +use self::Store::{Array, Bitmap}; const BITMAP_LENGTH: usize = 1024; -use self::Store::{Array, Bitmap}; pub enum Store { Array(Vec), Bitmap(Box<[u64; BITMAP_LENGTH]>), @@ -43,40 +44,55 @@ impl Store { } } - pub fn insert_range(&mut self, range: Range) -> u64 { + pub fn insert_range(&mut self, range: RangeInclusive) -> u64 { // A Range is defined as being of size 0 if start >= end. if range.is_empty() { return 0; } + let start = *range.start(); + let end = *range.end(); + match *self { Array(ref mut vec) => { - // Figure out the starting/ending position in the vec - let pos_start = vec.binary_search(&range.start).unwrap_or_else(|x| x); - let pos_end = vec.binary_search(&range.end).unwrap_or_else(|x| x); + // Figure out the starting/ending position in the vec. + let pos_start = vec.binary_search(&start).unwrap_or_else(|x| x); + let pos_end = vec + .binary_search_by(|p| { + // binary search the right most position when equals + match p.cmp(&end) { + Greater => Greater, + _ => Less, + } + }) + .unwrap_or_else(|x| x); // Overwrite the range in the middle - there's no need to take // into account any existing elements between start and end, as // they're all being added to the set. - let dropped = vec.splice(pos_start..pos_end, range.clone()); + let dropped = vec.splice(pos_start..pos_end, start..=end); - u64::from(range.end - range.start) - dropped.len() as u64 + end as u64 - start as u64 + 1 - dropped.len() as u64 } Bitmap(ref mut bits) => { - let (start_key, start_bit) = (key(range.start), bit(range.start)); - let (end_key, end_bit) = (key(range.end), bit(range.end)); + let (start_key, start_bit) = (key(start), bit(start)); + let (end_key, end_bit) = (key(end), bit(end)); + // MSB > start_bit > end_bit > LSB if start_key == end_key { // Set the end_bit -> LSB to 1 - let mut mask = (1 << end_bit) - 1; - // Set start_bit -> LSB to 0 + let mut mask = if end_bit == 63 { + u64::MAX + } else { + (1 << (end_bit + 1)) - 1 + }; + // Set MSB -> start_bit to 1 mask &= !((1 << start_bit) - 1); - // Leaving end_bit -> start_bit set to 1 let existed = (bits[start_key] & mask).count_ones(); bits[start_key] |= mask; - return u64::from(range.end - range.start) - u64::from(existed); + return u64::from(end - start + 1) - u64::from(existed); } // Mask off the left-most bits (MSB -> start_bit) @@ -95,11 +111,15 @@ impl Store { } // Set the end bits in the last chunk (MSB -> end_bit) - let mask = (1 << end_bit) - 1; + let mask = if end_bit == 63 { + u64::MAX + } else { + (1 << (end_bit + 1)) - 1 + }; existed += (bits[end_key] & mask).count_ones(); bits[end_key] |= mask; - u64::from(range.end - range.start) - u64::from(existed) + end as u64 - start as u64 + 1 - existed as u64 } } } @@ -146,28 +166,36 @@ impl Store { } } - pub fn remove_range(&mut self, start: u32, end: u32) -> u64 { - debug_assert!(start < end, "caller must ensure start < end"); + pub fn remove_range(&mut self, range: RangeInclusive) -> u64 { + if range.is_empty() { + return 0; + } + + let start = *range.start(); + let end = *range.end(); + match *self { Array(ref mut vec) => { - let a = vec.binary_search(&(start as u16)).unwrap_or_else(|e| e); - let b = if end > u32::from(u16::max_value()) { - vec.len() - } else { - vec.binary_search(&(end as u16)).unwrap_or_else(|e| e) - }; - vec.drain(a..b); - (b - a) as u64 + // Figure out the starting/ending position in the vec. + let pos_start = vec.binary_search(&start).unwrap_or_else(|x| x); + let pos_end = vec + .binary_search_by(|p| { + // binary search the right most position when equals + match p.cmp(&end) { + Greater => Greater, + _ => Less, + } + }) + .unwrap_or_else(|x| x); + vec.drain(pos_start..pos_end); + (pos_end - pos_start) as u64 } Bitmap(ref mut bits) => { - let start_key = key(start as u16) as usize; - let start_bit = bit(start as u16) as u32; - // end_key is inclusive - let end_key = key((end - 1) as u16) as usize; - let end_bit = bit(end as u16) as u32; + let (start_key, start_bit) = (key(start), bit(start)); + let (end_key, end_bit) = (key(end), bit(end)); if start_key == end_key { - let mask = (!0u64 << start_bit) & (!0u64).wrapping_shr(64 - end_bit); + let mask = (!0u64 << start_bit) & (!0u64 >> (63 - end_bit)); let removed = (bits[start_key] & mask).count_ones(); bits[start_key] &= !mask; return u64::from(removed); @@ -189,8 +217,8 @@ impl Store { *word = 0; } // end key bits - removed += (bits[end_key] & (!0u64).wrapping_shr(64 - end_bit)).count_ones(); - bits[end_key] &= !(!0u64).wrapping_shr(64 - end_bit); + removed += (bits[end_key] & (!0u64 >> (63 - end_bit))).count_ones(); + bits[end_key] &= !(!0u64 >> (63 - end_bit)); u64::from(removed) } } @@ -612,7 +640,7 @@ mod tests { let mut store = Store::Array(vec![1, 2, 8, 9]); // Insert a range with start > end. - let new = store.insert_range(6..1); + let new = store.insert_range(6..=1); assert_eq!(new, 0); assert_eq!(as_vec(store), vec![1, 2, 8, 9]); @@ -622,7 +650,7 @@ mod tests { fn test_array_insert_range() { let mut store = Store::Array(vec![1, 2, 8, 9]); - let new = store.insert_range(4..6); + let new = store.insert_range(4..=5); assert_eq!(new, 2); assert_eq!(as_vec(store), vec![1, 2, 4, 5, 8, 9]); @@ -632,7 +660,7 @@ mod tests { fn test_array_insert_range_left_overlap() { let mut store = Store::Array(vec![1, 2, 8, 9]); - let new = store.insert_range(2..6); + let new = store.insert_range(2..=5); assert_eq!(new, 3); assert_eq!(as_vec(store), vec![1, 2, 3, 4, 5, 8, 9]); @@ -642,7 +670,7 @@ mod tests { fn test_array_insert_range_right_overlap() { let mut store = Store::Array(vec![1, 2, 8, 9]); - let new = store.insert_range(4..9); + let new = store.insert_range(4..=8); assert_eq!(new, 4); assert_eq!(as_vec(store), vec![1, 2, 4, 5, 6, 7, 8, 9]); @@ -652,7 +680,7 @@ mod tests { fn test_array_insert_range_full_overlap() { let mut store = Store::Array(vec![1, 2, 8, 9]); - let new = store.insert_range(1..10); + let new = store.insert_range(1..=9); assert_eq!(new, 5); assert_eq!(as_vec(store), vec![1, 2, 3, 4, 5, 6, 7, 8, 9]); @@ -665,7 +693,7 @@ mod tests { let mut store = store.to_bitmap(); // Insert a range with start > end. - let new = store.insert_range(6..1); + let new = store.insert_range(6..=1); assert_eq!(new, 0); assert_eq!(as_vec(store), vec![1, 2, 8, 9]); @@ -676,7 +704,7 @@ mod tests { let store = Store::Array(vec![1, 2, 3, 62, 63]); let mut store = store.to_bitmap(); - let new = store.insert_range(1..63); + let new = store.insert_range(1..=62); assert_eq!(new, 58); assert_eq!(as_vec(store), (1..64).collect::>()); @@ -687,7 +715,7 @@ mod tests { let store = Store::Array(vec![1, 2, 130]); let mut store = store.to_bitmap(); - let new = store.insert_range(4..129); + let new = store.insert_range(4..=128); assert_eq!(new, 125); let mut want = vec![1, 2]; @@ -702,7 +730,7 @@ mod tests { let store = Store::Array(vec![1, 2, 130]); let mut store = store.to_bitmap(); - let new = store.insert_range(1..129); + let new = store.insert_range(1..=128); assert_eq!(new, 126); let mut want = Vec::new(); @@ -717,7 +745,7 @@ mod tests { let store = Store::Array(vec![1, 2, 130]); let mut store = store.to_bitmap(); - let new = store.insert_range(4..133); + let new = store.insert_range(4..=132); assert_eq!(new, 128); let mut want = vec![1, 2]; @@ -731,7 +759,7 @@ mod tests { let store = Store::Array(vec![1, 2, 130]); let mut store = store.to_bitmap(); - let new = store.insert_range(1..135); + let new = store.insert_range(1..=134); assert_eq!(new, 131); let mut want = Vec::new(); diff --git a/src/bitmap/util.rs b/src/bitmap/util.rs index 43d0bc70..84a2ac24 100644 --- a/src/bitmap/util.rs +++ b/src/bitmap/util.rs @@ -1,3 +1,5 @@ +use std::ops::{Bound, RangeBounds, RangeInclusive}; + /// Returns the container key and the index /// in this container for a given integer. #[inline] @@ -12,9 +14,32 @@ pub fn join(high: u16, low: u16) -> u32 { (u32::from(high) << 16) + u32::from(low) } +/// Convert a `RangeBounds` object to `RangeInclusive`, +pub fn convert_range_to_inclusive(range: R) -> Option> +where + R: RangeBounds, +{ + let start: u32 = match range.start_bound() { + Bound::Included(&i) => i, + Bound::Excluded(&u32::MAX) => return None, + Bound::Excluded(&i) => i + 1, + Bound::Unbounded => 0, + }; + let end: u32 = match range.end_bound() { + Bound::Included(&i) => i, + Bound::Excluded(&0) => return None, + Bound::Excluded(&i) => i - 1, + Bound::Unbounded => u32::MAX, + }; + if end < start { + return None; + } + Some(start..=end) +} + #[cfg(test)] mod test { - use super::{join, split}; + use super::{join, split, convert_range_to_inclusive}; #[test] fn test_split_u32() { @@ -39,4 +64,13 @@ mod test { assert_eq!(0xFFFF_FFFEu32, join(0xFFFFu16, 0xFFFEu16)); assert_eq!(0xFFFF_FFFFu32, join(0xFFFFu16, 0xFFFFu16)); } + + #[test] + fn test_convert_range_to_inclusive() { + assert_eq!(Some(1..=5), convert_range_to_inclusive(1..6)); + assert_eq!(Some(1..=u32::MAX), convert_range_to_inclusive(1..)); + assert_eq!(Some(0..=u32::MAX), convert_range_to_inclusive(..)); + assert_eq!(None, convert_range_to_inclusive(5..5)); + assert_eq!(Some(16..=16), convert_range_to_inclusive(16..=16)) + } } diff --git a/src/treemap/inherent.rs b/src/treemap/inherent.rs index a3c349c7..4f337f05 100644 --- a/src/treemap/inherent.rs +++ b/src/treemap/inherent.rs @@ -1,5 +1,5 @@ use std::collections::btree_map::{BTreeMap, Entry}; -use std::ops::{Range, RangeInclusive}; +use std::ops::RangeBounds; use crate::RoaringBitmap; use crate::RoaringTreemap; @@ -96,7 +96,7 @@ impl RoaringTreemap { } } - /// Removes a range of values from the set. + /// Removes a range of values. /// Returns the number of removed values. /// /// # Examples @@ -109,33 +109,36 @@ impl RoaringTreemap { /// rb.insert(3); /// assert_eq!(rb.remove_range(2..4), 2); /// ``` - pub fn remove_range(&mut self, range: Range) -> u64 { - if range.start == range.end { - return 0; - } + pub fn remove_range(&mut self, range: R) -> u64 + where + R: RangeBounds, + { + let (start, end) = match util::convert_range_to_inclusive(range) { + Some(range) => (*range.start(), *range.end()), + None => return 0, + }; + + let (start_container_key, start_index) = util::split(start); + let (end_container_key, end_index) = util::split(end); + let mut keys_to_remove = Vec::new(); let mut removed = 0; - // inclusive bounds for start and end - let (start_hi, start_lo) = util::split(range.start); - let (end_hi, end_lo) = util::split(range.end - 1); + for (&key, rb) in &mut self.map { - if key >= start_hi && key <= end_hi { - let start = if key == start_hi { start_lo } else { 0 }; - let end = if key == end_hi { - end_lo + if key >= start_container_key && key <= end_container_key { + let a = if key == start_container_key { + start_index + } else { + 0 + }; + let b = if key == end_container_key { + end_index } else { u32::max_value() }; - let range = RangeInclusive::new(start, end); - - if key != start_hi && key != end_lo { - removed += rb.len(); + removed += rb.remove_range(a..=b); + if rb.is_empty() { keys_to_remove.push(key); - } else { - removed += rb.remove_range(range); - if rb.is_empty() { - keys_to_remove.push(key); - } } } } diff --git a/src/treemap/util.rs b/src/treemap/util.rs index c193cbac..40374d54 100644 --- a/src/treemap/util.rs +++ b/src/treemap/util.rs @@ -1,3 +1,5 @@ +use std::ops::{Bound, RangeBounds, RangeInclusive}; + #[inline] pub fn split(value: u64) -> (u32, u32) { ((value >> 32) as u32, value as u32) @@ -8,9 +10,32 @@ pub fn join(high: u32, low: u32) -> u64 { (u64::from(high) << 32) | u64::from(low) } +/// Convert a `RangeBounds` object to `RangeInclusive`, +pub fn convert_range_to_inclusive(range: R) -> Option> +where + R: RangeBounds, +{ + let start: u64 = match range.start_bound() { + Bound::Included(&i) => i, + Bound::Excluded(&u64::MAX) => return None, + Bound::Excluded(&i) => i + 1, + Bound::Unbounded => 0, + }; + let end: u64 = match range.end_bound() { + Bound::Included(&i) => i, + Bound::Excluded(&0) => return None, + Bound::Excluded(&i) => i - 1, + Bound::Unbounded => u64::MAX, + }; + if end < start { + return None; + } + Some(start..=end) +} + #[cfg(test)] mod test { - use super::{join, split}; + use super::{join, split, convert_range_to_inclusive}; #[test] fn test_split_u64() { @@ -83,4 +108,13 @@ mod test { join(0xFFFF_FFFFu32, 0xFFFF_FFFFu32) ); } + + #[test] + fn test_convert_range_to_inclusive() { + assert_eq!(Some(1..=5), convert_range_to_inclusive(1..6)); + assert_eq!(Some(1..=u64::MAX), convert_range_to_inclusive(1..)); + assert_eq!(Some(0..=u64::MAX), convert_range_to_inclusive(..)); + assert_eq!(None, convert_range_to_inclusive(5..5)); + assert_eq!(Some(16..=16), convert_range_to_inclusive(16..=16)) + } }