Skip to content

Commit

Permalink
auto merge of rust-lang#17558 : kaseyc/rust/fix_bitvset_union, r=aturon
Browse files Browse the repository at this point in the history
Updates the other_op function shared by the union/intersect/difference/symmetric_difference -with functions to fix an issue where certain elements would not be present in the result. To fix this, when other op is called, we resize self's nbits to account for any new elements that may be added to the set.

Example:
```rust
	let mut a = BitvSet::new();
	let mut b = BitvSet::new();
	a.insert(0);
	b.insert(5);
	a.union_with(&b);
	println!("{}", a); //Prints "{0}" instead of "{0, 5}"
```
  • Loading branch information
bors committed Oct 9, 2014
2 parents eb04229 + dd4fa90 commit 79d056f
Showing 1 changed file with 131 additions and 57 deletions.
188 changes: 131 additions & 57 deletions src/libcollections/bitv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -235,11 +235,20 @@ impl Bitv {
/// }
/// ```
pub fn with_capacity(nbits: uint, init: bool) -> Bitv {
Bitv {
let mut bitv = Bitv {
storage: Vec::from_elem((nbits + uint::BITS - 1) / uint::BITS,
if init { !0u } else { 0u }),
nbits: nbits
};

// Zero out any unused bits in the highest word if necessary
let used_bits = bitv.nbits % uint::BITS;
if init && used_bits != 0 {
let largest_used_word = (bitv.nbits + uint::BITS - 1) / uint::BITS - 1;
*bitv.storage.get_mut(largest_used_word) &= (1 << used_bits) - 1;
}

bitv
}

/// Retrieves the value at index `i`.
Expand Down Expand Up @@ -629,9 +638,9 @@ impl Bitv {
/// ```
pub fn reserve(&mut self, size: uint) {
let old_size = self.storage.len();
let size = (size + uint::BITS - 1) / uint::BITS;
if old_size < size {
self.storage.grow(size - old_size, 0);
let new_size = (size + uint::BITS - 1) / uint::BITS;
if old_size < new_size {
self.storage.grow(new_size - old_size, 0);
}
}

Expand Down Expand Up @@ -686,8 +695,15 @@ impl Bitv {
}
// Allocate new words, if needed
if new_nwords > self.storage.len() {
let to_add = new_nwords - self.storage.len();
self.storage.grow(to_add, full_value);
let to_add = new_nwords - self.storage.len();
self.storage.grow(to_add, full_value);

// Zero out and unused bits in the new tail word
if value {
let tail_word = new_nwords - 1;
let used_bits = new_nbits % uint::BITS;
*self.storage.get_mut(tail_word) &= (1 << used_bits) - 1;
}
}
// Adjust internal bit count
self.nbits = new_nbits;
Expand Down Expand Up @@ -970,9 +986,8 @@ impl<'a> RandomAccessIterator<bool> for Bits<'a> {
/// }
///
/// // Can convert back to a `Bitv`
/// let bv: Bitv = s.unwrap();
/// assert!(bv.eq_vec([true, true, false, true,
/// false, false, false, false]));
/// let bv: Bitv = s.into_bitv();
/// assert!(bv.get(3));
/// ```
#[deriving(Clone)]
pub struct BitvSet(Bitv);
Expand All @@ -993,7 +1008,8 @@ impl FromIterator<bool> for BitvSet {
impl Extendable<bool> for BitvSet {
#[inline]
fn extend<I: Iterator<bool>>(&mut self, iterator: I) {
self.get_mut_ref().extend(iterator);
let &BitvSet(ref mut self_bitv) = self;
self_bitv.extend(iterator);
}
}

Expand Down Expand Up @@ -1049,7 +1065,8 @@ impl BitvSet {
/// ```
#[inline]
pub fn with_capacity(nbits: uint) -> BitvSet {
BitvSet(Bitv::with_capacity(nbits, false))
let bitv = Bitv::with_capacity(nbits, false);
BitvSet::from_bitv(bitv)
}

/// Creates a new bit vector set from the given bit vector.
Expand All @@ -1068,7 +1085,9 @@ impl BitvSet {
/// }
/// ```
#[inline]
pub fn from_bitv(bitv: Bitv) -> BitvSet {
pub fn from_bitv(mut bitv: Bitv) -> BitvSet {
// Mark every bit as valid
bitv.nbits = bitv.capacity();
BitvSet(bitv)
}

Expand Down Expand Up @@ -1102,7 +1121,10 @@ impl BitvSet {
/// ```
pub fn reserve(&mut self, size: uint) {
let &BitvSet(ref mut bitv) = self;
bitv.reserve(size)
bitv.reserve(size);
if bitv.nbits < size {
bitv.nbits = bitv.capacity();
}
}

/// Consumes this set to return the underlying bit vector.
Expand All @@ -1116,11 +1138,12 @@ impl BitvSet {
/// s.insert(0);
/// s.insert(3);
///
/// let bv = s.unwrap();
/// assert!(bv.eq_vec([true, false, false, true]));
/// let bv = s.into_bitv();
/// assert!(bv.get(0));
/// assert!(bv.get(3));
/// ```
#[inline]
pub fn unwrap(self) -> Bitv {
pub fn into_bitv(self) -> Bitv {
let BitvSet(bitv) = self;
bitv
}
Expand All @@ -1144,38 +1167,15 @@ impl BitvSet {
bitv
}

/// Returns a mutable reference to the underlying bit vector.
///
/// # Example
///
/// ```
/// use std::collections::BitvSet;
///
/// let mut s = BitvSet::new();
/// s.insert(0);
/// assert_eq!(s.contains(&0), true);
/// {
/// // Will free the set during bv's lifetime
/// let bv = s.get_mut_ref();
/// bv.set(0, false);
/// }
/// assert_eq!(s.contains(&0), false);
/// ```
#[inline]
pub fn get_mut_ref<'a>(&'a mut self) -> &'a mut Bitv {
let &BitvSet(ref mut bitv) = self;
bitv
}

#[inline]
fn other_op(&mut self, other: &BitvSet, f: |uint, uint| -> uint) {
// Expand the vector if necessary
self.reserve(other.capacity());

// Unwrap Bitvs
let &BitvSet(ref mut self_bitv) = self;
let &BitvSet(ref other_bitv) = other;

// Expand the vector if necessary
self_bitv.reserve(other_bitv.capacity());

// virtually pad other with 0's for equal lengths
let mut other_words = {
let (_, result) = match_words(self_bitv, other_bitv);
Expand Down Expand Up @@ -1376,9 +1376,10 @@ impl BitvSet {
///
/// let mut a = BitvSet::from_bitv(bitv::from_bytes([a]));
/// let b = BitvSet::from_bitv(bitv::from_bytes([b]));
/// let res = BitvSet::from_bitv(bitv::from_bytes([res]));
///
/// a.union_with(&b);
/// assert_eq!(a.unwrap(), bitv::from_bytes([res]));
/// assert_eq!(a, res);
/// ```
#[inline]
pub fn union_with(&mut self, other: &BitvSet) {
Expand All @@ -1399,9 +1400,10 @@ impl BitvSet {
///
/// let mut a = BitvSet::from_bitv(bitv::from_bytes([a]));
/// let b = BitvSet::from_bitv(bitv::from_bytes([b]));
/// let res = BitvSet::from_bitv(bitv::from_bytes([res]));
///
/// a.intersect_with(&b);
/// assert_eq!(a.unwrap(), bitv::from_bytes([res]));
/// assert_eq!(a, res);
/// ```
#[inline]
pub fn intersect_with(&mut self, other: &BitvSet) {
Expand All @@ -1424,15 +1426,17 @@ impl BitvSet {
///
/// let mut bva = BitvSet::from_bitv(bitv::from_bytes([a]));
/// let bvb = BitvSet::from_bitv(bitv::from_bytes([b]));
/// let bva_b = BitvSet::from_bitv(bitv::from_bytes([a_b]));
/// let bvb_a = BitvSet::from_bitv(bitv::from_bytes([b_a]));
///
/// bva.difference_with(&bvb);
/// assert_eq!(bva.unwrap(), bitv::from_bytes([a_b]));
/// assert_eq!(bva, bva_b);
///
/// let bva = BitvSet::from_bitv(bitv::from_bytes([a]));
/// let mut bvb = BitvSet::from_bitv(bitv::from_bytes([b]));
///
/// bvb.difference_with(&bva);
/// assert_eq!(bvb.unwrap(), bitv::from_bytes([b_a]));
/// assert_eq!(bvb, bvb_a);
/// ```
#[inline]
pub fn difference_with(&mut self, other: &BitvSet) {
Expand All @@ -1454,9 +1458,10 @@ impl BitvSet {
///
/// let mut a = BitvSet::from_bitv(bitv::from_bytes([a]));
/// let b = BitvSet::from_bitv(bitv::from_bytes([b]));
/// let res = BitvSet::from_bitv(bitv::from_bytes([res]));
///
/// a.symmetric_difference_with(&b);
/// assert_eq!(a.unwrap(), bitv::from_bytes([res]));
/// assert_eq!(a, res);
/// ```
#[inline]
pub fn symmetric_difference_with(&mut self, other: &BitvSet) {
Expand Down Expand Up @@ -1538,20 +1543,14 @@ impl MutableSet<uint> for BitvSet {
if self.contains(&value) {
return false;
}

// Ensure we have enough space to hold the new element
if value >= self.capacity() {
let new_cap = cmp::max(value + 1, self.capacity() * 2);
self.reserve(new_cap);
}

let &BitvSet(ref mut bitv) = self;
if value >= bitv.nbits {
// If we are increasing nbits, make sure we mask out any previously-unconsidered bits
let old_rem = bitv.nbits % uint::BITS;
if old_rem != 0 {
let old_last_word = (bitv.nbits + uint::BITS - 1) / uint::BITS - 1;
*bitv.storage.get_mut(old_last_word) &= (1 << old_rem) - 1;
}
bitv.nbits = value + 1;
}
bitv.set(value, true);
return true;
}
Expand Down Expand Up @@ -2225,14 +2224,15 @@ mod tests {
assert!(a.insert(160));
assert!(a.insert(19));
assert!(a.insert(24));
assert!(a.insert(200));

assert!(b.insert(1));
assert!(b.insert(5));
assert!(b.insert(9));
assert!(b.insert(13));
assert!(b.insert(19));

let expected = [1, 3, 5, 9, 11, 13, 19, 24, 160];
let expected = [1, 3, 5, 9, 11, 13, 19, 24, 160, 200];
let actual = a.union(&b).collect::<Vec<uint>>();
assert_eq!(actual.as_slice(), expected.as_slice());
}
Expand Down Expand Up @@ -2281,6 +2281,27 @@ mod tests {
assert!(c.is_disjoint(&b))
}

#[test]
fn test_bitv_set_union_with() {
//a should grow to include larger elements
let mut a = BitvSet::new();
a.insert(0);
let mut b = BitvSet::new();
b.insert(5);
let expected = BitvSet::from_bitv(from_bytes([0b10000100]));
a.union_with(&b);
assert_eq!(a, expected);

// Standard
let mut a = BitvSet::from_bitv(from_bytes([0b10100010]));
let mut b = BitvSet::from_bitv(from_bytes([0b01100010]));
let c = a.clone();
a.union_with(&b);
b.union_with(&c);
assert_eq!(a.len(), 4);
assert_eq!(b.len(), 4);
}

#[test]
fn test_bitv_set_intersect_with() {
// Explicitly 0'ed bits
Expand Down Expand Up @@ -2311,6 +2332,59 @@ mod tests {
assert_eq!(b.len(), 2);
}

#[test]
fn test_bitv_set_difference_with() {
// Explicitly 0'ed bits
let mut a = BitvSet::from_bitv(from_bytes([0b00000000]));
let b = BitvSet::from_bitv(from_bytes([0b10100010]));
a.difference_with(&b);
assert!(a.is_empty());

// Uninitialized bits should behave like 0's
let mut a = BitvSet::new();
let b = BitvSet::from_bitv(from_bytes([0b11111111]));
a.difference_with(&b);
assert!(a.is_empty());

// Standard
let mut a = BitvSet::from_bitv(from_bytes([0b10100010]));
let mut b = BitvSet::from_bitv(from_bytes([0b01100010]));
let c = a.clone();
a.difference_with(&b);
b.difference_with(&c);
assert_eq!(a.len(), 1);
assert_eq!(b.len(), 1);
}

#[test]
fn test_bitv_set_symmetric_difference_with() {
//a should grow to include larger elements
let mut a = BitvSet::new();
a.insert(0);
a.insert(1);
let mut b = BitvSet::new();
b.insert(1);
b.insert(5);
let expected = BitvSet::from_bitv(from_bytes([0b10000100]));
a.symmetric_difference_with(&b);
assert_eq!(a, expected);

let mut a = BitvSet::from_bitv(from_bytes([0b10100010]));
let b = BitvSet::new();
let c = a.clone();
a.symmetric_difference_with(&b);
assert_eq!(a, c);

// Standard
let mut a = BitvSet::from_bitv(from_bytes([0b11100010]));
let mut b = BitvSet::from_bitv(from_bytes([0b01101010]));
let c = a.clone();
a.symmetric_difference_with(&b);
b.symmetric_difference_with(&c);
assert_eq!(a.len(), 2);
assert_eq!(b.len(), 2);
}

#[test]
fn test_bitv_set_eq() {
let a = BitvSet::from_bitv(from_bytes([0b10100010]));
Expand Down

0 comments on commit 79d056f

Please sign in to comment.