diff --git a/.gitignore b/.gitignore index a9d37c5..4a1e9db 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ target Cargo.lock +.idea/ diff --git a/Cargo.toml b/Cargo.toml index d2a231a..7431145 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,4 +8,4 @@ documentation = "https://cfallin.github.io/rust-smallset/smallset/" license = "MIT" [dependencies] -smallvec = "0.1" +smallvec = "1.4.2" diff --git a/src/lib.rs b/src/lib.rs index 2cbcf20..d36f9ba 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,12 +4,14 @@ // Copyright (c) 2016 Chris Fallin . Released under the MIT license. // +extern crate smallvec; + use std::fmt; use std::iter::{FromIterator, IntoIterator}; -use std::slice::Iter; -extern crate smallvec; use smallvec::{Array, SmallVec}; +use std::collections::HashSet; +use std::hash::Hash; /// A `SmallSet` is an unordered set of elements. It is designed to work best /// for very small sets (no more than ten or so elements). In order to support @@ -33,70 +35,456 @@ use smallvec::{Array, SmallVec}; /// s.insert(1); /// s.insert(2); /// s.insert(3); -/// assert!(s.len() == 3); +/// assert_eq!(s.len(), 3); /// assert!(s.contains(&1)); /// ``` + pub struct SmallSet where A::Item: PartialEq + Eq, { - elements: SmallVec, + inner: InnerSmallVec, } -impl SmallSet +impl Default for SmallSet +where + A::Item: PartialEq + Eq + Hash, +{ + fn default() -> Self { + SmallSet::new() + } +} + +pub enum InnerSmallVec +where + A::Item: PartialEq + Eq, +{ + Stack(SmallVec), + Heap(std::collections::HashSet), +} + +impl Default for InnerSmallVec where A::Item: PartialEq + Eq, +{ + fn default() -> Self { + InnerSmallVec::Stack(SmallVec::new()) + } +} + +impl Clone for InnerSmallVec +where + A::Item: PartialEq + Eq + Clone, +{ + fn clone(&self) -> Self { + match &self { + InnerSmallVec::Stack(elements) => InnerSmallVec::Stack(elements.clone()), + InnerSmallVec::Heap(elements) => InnerSmallVec::Heap(elements.clone()), + } + } +} + +impl PartialEq for SmallSet +where + A::Item: Eq + PartialEq + Hash, +{ + fn eq(&self, other: &Self) -> bool { + fn set_same(stack: &SmallVec, heap: &HashSet) -> bool + where + A::Item: Eq + PartialEq, + { + stack.len() == heap.len() && heap.iter().all(|x| stack.contains(x)) + } + + match (&self.inner, &other.inner) { + (InnerSmallVec::Stack(lhs), InnerSmallVec::Stack(rhs)) => lhs.eq(rhs), + (InnerSmallVec::Heap(lhs), InnerSmallVec::Heap(rhs)) => lhs.eq(rhs), + (InnerSmallVec::Stack(stack), InnerSmallVec::Heap(heap)) => set_same(stack, heap), + (InnerSmallVec::Heap(heap), InnerSmallVec::Stack(stack)) => set_same(stack, heap), + } + } +} + +impl SmallSet +where + A::Item: PartialEq + Eq + Hash, { /// Creates a new, empty `SmallSet`. pub fn new() -> SmallSet { SmallSet { - elements: SmallVec::new(), + inner: InnerSmallVec::Stack(SmallVec::new()), } } + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + /// Inserts `elem` into the set if not yet present. Returns `true` if the /// set did not have this element present, or `false` if it already had this /// element present. pub fn insert(&mut self, elem: A::Item) -> bool { - if !self.contains(&elem) { - self.elements.push(elem); - true - } else { - false + match &mut self.inner { + InnerSmallVec::Stack(ref mut elements) => { + if elements.contains(&elem) { + false + } else { + if elements.len() + 1 <= A::size() { + elements.push(elem); + } else { + let mut ee = HashSet::::with_capacity(elements.len() + 1); + while !elements.is_empty() { + ee.insert(elements.remove(0)); + } + ee.insert(elem); + self.inner = InnerSmallVec::Heap(ee); + } + true + } + } + InnerSmallVec::Heap(ref mut elements) => elements.insert(elem), } } /// Removes `elem` from the set. Returns `true` if the element was removed, /// or `false` if it was not found. pub fn remove(&mut self, elem: &A::Item) -> bool { - if let Some(pos) = self.elements.iter().position(|e| *e == *elem) { - self.elements.remove(pos); - true - } else { - false + match &mut self.inner { + InnerSmallVec::Stack(ref mut elements) => { + if let Some(pos) = elements.iter().position(|e| *e == *elem) { + elements.remove(pos); + true + } else { + false + } + } + InnerSmallVec::Heap(ref mut elements) => elements.remove(elem), } } /// Tests whether `elem` is present. Returns `true` if it is present, or /// `false` if not. pub fn contains(&self, elem: &A::Item) -> bool { - self.elements.iter().any(|e| *e == *elem) + match &self.inner { + InnerSmallVec::Stack(ref elements) => elements.iter().any(|e| *e == *elem), + InnerSmallVec::Heap(ref elements) => elements.contains(elem), + } } /// Returns an iterator over the set elements. Elements will be returned in /// an arbitrary (unsorted) order. - pub fn iter(&self) -> Iter { - self.elements.iter() + pub fn iter(&self) -> SmallIter { + match &self.inner { + InnerSmallVec::Stack(element) => SmallIter { + inner: InnerSmallIter::Stack(element.iter()), + }, + InnerSmallVec::Heap(element) => SmallIter { + inner: InnerSmallIter::Heap(element.iter()), + }, + } } /// Returns the current length of the set. pub fn len(&self) -> usize { - self.elements.len() + match &self.inner { + InnerSmallVec::Stack(elements) => elements.len(), + InnerSmallVec::Heap(elements) => elements.len(), + } } /// Clears the set. pub fn clear(&mut self) { - self.elements.clear(); + match &mut self.inner { + InnerSmallVec::Stack(ref mut elements) => elements.clear(), + InnerSmallVec::Heap(ref mut elements) => { + elements.clear(); + self.inner = Default::default(); + } + } + } + + // + pub fn get(&self, value: &A::Item) -> Option<&A::Item> { + match &self.inner { + InnerSmallVec::Stack(elements) => elements.iter().find(|x| (value).eq(&x)), + InnerSmallVec::Heap(elements) => elements.iter().find(|x| (value).eq(&x)), + } + } + + pub fn take(&mut self, value: &A::Item) -> Option { + match &mut self.inner { + InnerSmallVec::Stack(ref mut elements) => { + if let Some(pos) = elements.iter().position(|e| *e == *value) { + let result = elements.remove(pos); + Some(result) + } else { + None + } + } + InnerSmallVec::Heap(ref mut elements) => elements.take(value), + } + } + + // Adds a value to the set, replacing the existing value, if any, that is equal to the given one. Returns the replaced value. + pub fn replace(&mut self, value: A::Item) -> Option { + match &mut self.inner { + InnerSmallVec::Stack(ref mut elements) => { + if let Some(pos) = elements.iter().position(|e| *e == value) { + let result = elements.remove(pos); + elements.insert(pos, value); + Some(result) + } else { + None + } + } + InnerSmallVec::Heap(ref mut elements) => elements.replace(value), + } + } + + pub fn drain(&mut self) -> SmallDrain { + match &mut self.inner { + InnerSmallVec::Stack(ref mut elements) => { + // TODO: Clean up this garbage... + let mut ee = Vec::::with_capacity(elements.len() + 1); + while !elements.is_empty() { + ee.push(elements.remove(0)); + } + SmallDrain { data: ee, index: 0 } + } + InnerSmallVec::Heap(ref mut elements) => { + let drain = elements.drain().collect::>(); + SmallDrain { + data: drain, + index: 0, + } + } + } + } + + pub fn retain(&mut self, f: F) + where + F: FnMut(&mut A::Item) -> bool + for<'r> FnMut(&'r ::Item) -> bool, + { + match &mut self.inner { + InnerSmallVec::Stack(ref mut elements) => elements.retain(f), + InnerSmallVec::Heap(ref mut elements) => elements.retain(f), + } + } + + pub fn intersection<'a>(&'a self, other: &'a Self) -> SmallIntersection<'a, A::Item> { + match &self.inner { + InnerSmallVec::Stack(ref elements) => { + let result = elements + .iter() + .filter(|x| other.contains(x)) + .collect::>(); + SmallIntersection { + data: result, + index: 0, + } + } + + InnerSmallVec::Heap(ref elements) => { + let result = elements + .iter() + .filter(|x| other.contains(x)) + .collect::>(); + SmallIntersection { + data: result, + index: 0, + } + } + } + } + + pub fn union<'a>(&'a self, other: &'a Self) -> SmallUnion<'a, A::Item> { + match &self.inner { + InnerSmallVec::Stack(ref elements) => { + let mut lhs = elements.iter().collect::>(); + let mut rhs = other + .iter() + .filter(|x| !lhs.contains(x)) + .collect::>(); + lhs.append(&mut rhs); + SmallUnion { + data: lhs, + index: 0, + } + } + + InnerSmallVec::Heap(ref elements) => { + let mut lhs = elements.iter().collect::>(); + let mut rhs = other + .iter() + .filter(|x| !lhs.contains(x)) + .collect::>(); + lhs.append(&mut rhs); + SmallUnion { + data: rhs, + index: 0, + } + } + } + } + + pub fn difference<'a>(&'a self, other: &'a Self) -> SmallDifference<'a, A::Item> { + match &self.inner { + InnerSmallVec::Stack(ref elements) => { + let lhs = elements + .iter() + .filter(|x| !other.contains(x)) + .collect::>(); + SmallDifference { + data: lhs, + index: 0, + } + } + + InnerSmallVec::Heap(ref elements) => { + let lhs = elements + .iter() + .filter(|x| !other.contains(x)) + .collect::>(); + SmallDifference { + data: lhs, + index: 0, + } + } + } + } + + pub fn symmetric_difference<'a>( + &'a self, + other: &'a Self, + ) -> SmallSymmetricDifference<'a, A::Item> { + match &self.inner { + InnerSmallVec::Stack(ref elements) => { + let mut lhs = elements + .iter() + .filter(|x| !other.contains(x)) + .collect::>(); + let mut rhs = other + .iter() + .filter(|x| !elements.contains(x)) + .collect::>(); + lhs.append(&mut rhs); + SmallSymmetricDifference { + data: lhs, + index: 0, + } + } + + InnerSmallVec::Heap(ref elements) => { + let mut lhs = elements + .iter() + .filter(|x| other.contains(x)) + .collect::>(); + let mut rhs = other + .iter() + .filter(|x| elements.contains(x)) + .collect::>(); + lhs.append(&mut rhs); + SmallSymmetricDifference { + data: lhs, + index: 0, + } + } + } + } +} + +pub struct SmallDrain { + data: Vec, + index: usize, +} + +impl Iterator for SmallDrain { + type Item = T; + + fn next(&mut self) -> Option { + if self.index == self.data.len() { + None + } else { + let ptr = self.data.as_ptr(); + self.index += 1; + unsafe { Some(std::ptr::read(ptr.add(self.index - 1))) } + } + } +} + +pub struct SmallIntersection<'a, T> { + data: Vec<&'a T>, + index: usize, +} + +impl<'a, T> Iterator for SmallIntersection<'a, T> { + type Item = &'a T; + + fn next(&mut self) -> Option { + if self.index == self.data.len() { + None + } else { + let ptr = self.data.as_ptr(); + self.index += 1; + unsafe { Some(std::ptr::read(ptr.add(self.index - 1))) } + } + } +} + +pub struct SmallUnion<'a, T> { + data: Vec<&'a T>, + index: usize, +} + +impl<'a, T> Iterator for SmallUnion<'a, T> { + type Item = &'a T; + + fn next(&mut self) -> Option { + if self.index == self.data.len() { + None + } else { + let ptr = self.data.as_ptr(); + self.index += 1; + unsafe { Some(std::ptr::read(ptr.add(self.index - 1))) } + } + } +} + +pub struct SmallDifference<'a, T> { + data: Vec<&'a T>, + index: usize, +} + +impl<'a, T> Iterator for SmallDifference<'a, T> { + type Item = &'a T; + + fn next(&mut self) -> Option { + if self.index == self.data.len() { + None + } else { + let ptr = self.data.as_ptr(); + self.index += 1; + unsafe { Some(std::ptr::read(ptr.add(self.index - 1))) } + } + } +} + +pub struct SmallSymmetricDifference<'a, T> { + data: Vec<&'a T>, + index: usize, +} + +impl<'a, T> Iterator for SmallSymmetricDifference<'a, T> { + type Item = &'a T; + + fn next(&mut self) -> Option { + if self.index == self.data.len() { + None + } else { + let ptr = self.data.as_ptr(); + self.index += 1; + unsafe { Some(std::ptr::read(ptr.add(self.index - 1))) } + } } } @@ -106,7 +494,7 @@ where { fn clone(&self) -> SmallSet { SmallSet { - elements: self.elements.clone(), + inner: self.inner.clone(), } } } @@ -116,87 +504,53 @@ where A::Item: PartialEq + Eq + fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - self.elements.fmt(f) + match &self.inner { + InnerSmallVec::Stack(elements) => write!(f, "{:?}", elements.as_slice()), + InnerSmallVec::Heap(elements) => write!(f, "{:?}", elements), + } } } impl FromIterator for SmallSet where - A::Item: PartialEq + Eq, + A::Item: PartialEq + Eq + Hash, { fn from_iter(iter: T) -> Self where T: IntoIterator, { - SmallSet { - elements: SmallVec::from_iter(iter), - } - } -} - -#[cfg(test)] -mod test { - use super::*; - use std::fmt::Write; - - #[test] - fn test_basic_set() { - let mut s: SmallSet<[u32; 2]> = SmallSet::new(); - assert!(s.insert(1) == true); - assert!(s.insert(2) == true); - assert!(s.insert(2) == false); - assert!(s.insert(3) == true); - assert!(s.insert(2) == false); - assert!(s.insert(3) == false); - assert!(s.contains(&1)); - assert!(s.contains(&2)); - assert!(s.contains(&3)); - assert!(!s.contains(&4)); - assert!(s.len() == 3); - assert!(s.iter().map(|r| *r).collect::>() == vec![1, 2, 3]); - s.clear(); - assert!(!s.contains(&1)); - } - - #[test] - fn test_remove() { - let mut s: SmallSet<[u32; 2]> = SmallSet::new(); - assert!(s.insert(1) == true); - assert!(s.insert(2) == true); - assert!(s.len() == 2); - assert!(s.contains(&1)); - assert!(s.remove(&1) == true); - assert!(s.remove(&1) == false); - assert!(s.len() == 1); - assert!(!s.contains(&1)); - assert!(s.insert(1) == true); - assert!(s.iter().map(|r| *r).collect::>() == vec![2, 1]); - } - - #[test] - fn test_clone() { - let mut s: SmallSet<[u32; 2]> = SmallSet::new(); - s.insert(1); - s.insert(2); - let c = s.clone(); - assert!(c.contains(&1)); - assert!(c.contains(&2)); - assert!(!c.contains(&3)); - } - - #[test] - fn test_debug() { - let mut s: SmallSet<[u32; 2]> = SmallSet::new(); - s.insert(1); - s.insert(2); - let mut buf = String::new(); - write!(buf, "{:?}", s).unwrap(); - assert!(&buf == "[1, 2]"); - } - - #[test] - fn test_fromiter() { - let s: SmallSet<[usize; 4]> = vec![1, 2, 3, 4].into_iter().collect(); - assert!(s.len() == 4); + iter.into_iter().fold(SmallSet::new(), |mut acc, x| { + acc.insert(x); + acc + }) + } +} + +pub struct SmallIter<'a, A: Array> +where + A::Item: PartialEq + Eq + Hash + 'a, +{ + inner: InnerSmallIter<'a, A>, +} + +pub enum InnerSmallIter<'a, A: Array> +where + A::Item: PartialEq + Eq + Hash + 'a, +{ + Stack(std::slice::Iter<'a, A::Item>), + Heap(std::collections::hash_set::Iter<'a, A::Item>), +} + +impl<'a, A: Array> Iterator for SmallIter<'a, A> +where + A::Item: PartialEq + Eq + Hash + 'a, +{ + type Item = &'a A::Item; + + fn next(&mut self) -> Option { + match &mut self.inner { + InnerSmallIter::Stack(ref mut iter) => iter.next(), + InnerSmallIter::Heap(ref mut iter) => iter.next(), + } } } diff --git a/tests/test.rs b/tests/test.rs new file mode 100644 index 0000000..7d23d08 --- /dev/null +++ b/tests/test.rs @@ -0,0 +1,230 @@ +extern crate smallset; + +use smallset::SmallSet; +use std::fmt::Write; +use std::hash::{Hash, Hasher}; + +#[test] +fn test_basic_set() { + let mut s: SmallSet<[u32; 2]> = SmallSet::new(); + assert_eq!(s.insert(1), true); + assert_eq!(s.insert(2), true); + assert_eq!(s.insert(2), false); + assert_eq!(s.insert(3), true); + assert_eq!(s.insert(2), false); + assert_eq!(s.insert(3), false); + assert!(s.contains(&1)); + assert!(s.contains(&2)); + assert!(s.contains(&3)); + assert!(!s.contains(&4)); + assert_eq!(s.len(), 3); + assert!(s + .iter() + .map(|r| *r) + .collect::>() + .iter() + .all(|x| vec![1, 2, 3].contains(x))); + s.clear(); + assert!(!s.contains(&1)); +} + +#[test] +fn test_remove() { + let mut s: SmallSet<[u32; 2]> = SmallSet::new(); + assert_eq!(s.insert(1), true); + assert_eq!(s.insert(2), true); + assert_eq!(s.len(), 2); + assert!(s.contains(&1)); + assert_eq!(s.remove(&1), true); + assert_eq!(s.remove(&1), false); + assert_eq!(s.len(), 1); + assert!(!s.contains(&1)); + assert_eq!(s.insert(1), true); + assert!(s + .iter() + .map(|r| *r) + .collect::>() + .iter() + .all(|x| vec![1, 2, 3].contains(x))); +} + +#[test] +fn test_clone() { + let mut s: SmallSet<[u32; 2]> = SmallSet::new(); + s.insert(1); + s.insert(2); + let c = s.clone(); + assert!(c.contains(&1)); + assert!(c.contains(&2)); + assert!(!c.contains(&3)); +} + +#[test] +fn test_debug_small() { + let mut s: SmallSet<[u32; 2]> = SmallSet::new(); + s.insert(1); + s.insert(2); + let mut buf = String::new(); + write!(buf, "{:?}", s).unwrap(); + assert_eq!(&buf, "[1, 2]"); +} + +#[test] +fn test_from_iter() { + let s: SmallSet<[usize; 4]> = vec![1, 2, 3, 4].into_iter().collect(); + assert_eq!(s.len(), 4); +} + +#[test] +fn test_replace() { + struct RingOf7 { + pub value: u32, + } + + impl PartialEq for RingOf7 { + fn eq(&self, other: &Self) -> bool { + self.value % 7 == other.value % 7 + } + + fn ne(&self, other: &Self) -> bool { + self.value % 7 != other.value % 7 + } + } + + impl From for u32 { + fn from(value: RingOf7) -> Self { + value.value + } + } + + impl Hash for RingOf7 { + fn hash(&self, state: &mut H) { + self.value.hash(state) + } + } + + impl Eq for RingOf7 {} + + let mut lhs = SmallSet::<[RingOf7; 4]>::new(); + lhs.insert(RingOf7 { value: 1 }); + lhs.insert(RingOf7 { value: 2 }); + lhs.insert(RingOf7 { value: 3 }); + lhs.insert(RingOf7 { value: 4 }); + + lhs.replace(RingOf7 { value: 8 }); + lhs.replace(RingOf7 { value: 9 }); + lhs.replace(RingOf7 { value: 10 }); + lhs.replace(RingOf7 { value: 11 }); + + let expected = vec![8, 9, 10, 11]; + assert!(lhs + .iter() + .map(|x| x.value) + .collect::>() + .iter() + .zip(expected.iter()) + .all(|(lhs, rhs)| lhs == rhs)); +} + +#[test] +fn test_eq() { + let mut lhs = SmallSet::<[u32; 4]>::new(); + lhs.insert(1); + lhs.insert(2); + + let mut rhs = SmallSet::<[u32; 4]>::new(); + rhs.insert(1); + rhs.insert(2); + + assert_eq!(lhs, rhs); +} + +#[test] +fn test_intersection() { + let mut lhs = SmallSet::<[u32; 4]>::new(); + lhs.insert(1); + lhs.insert(3); + lhs.insert(5); + lhs.insert(4); + lhs.insert(8); + lhs.insert(10); + + let mut rhs = SmallSet::<[u32; 4]>::new(); + rhs.insert(4); + rhs.insert(8); + rhs.insert(10); + + assert!(lhs.intersection(&rhs).all(|x| x % 2 == 0)); +} + +#[test] +fn test_union() { + let mut lhs = SmallSet::<[u32; 4]>::new(); + lhs.insert(1); + lhs.insert(2); + lhs.insert(3); + lhs.insert(4); + + let mut rhs = SmallSet::<[u32; 4]>::new(); + rhs.insert(3); + rhs.insert(4); + rhs.insert(5); + rhs.insert(6); + + let union = lhs.union(&rhs).collect::>(); + let expected = vec![1, 2, 3, 4, 5, 6]; + assert_eq!(union.len(), expected.len()); + assert!(expected + .iter() + .collect::>() + .iter() + .all(|x| union.contains(x))); +} + +#[test] +fn test_difference() { + let mut lhs = SmallSet::<[u32; 4]>::new(); + lhs.insert(1); + lhs.insert(2); + lhs.insert(3); + lhs.insert(4); + + let mut rhs = SmallSet::<[u32; 4]>::new(); + rhs.insert(3); + rhs.insert(4); + rhs.insert(5); + rhs.insert(6); + + let union = lhs.difference(&rhs).collect::>(); + let expected = vec![1, 2]; + assert_eq!(union.len(), expected.len()); + assert!(expected + .iter() + .collect::>() + .iter() + .all(|x| union.contains(x))); +} + +#[test] +fn test_symmetric_difference() { + let mut lhs = SmallSet::<[u32; 4]>::new(); + lhs.insert(1); + lhs.insert(2); + lhs.insert(3); + lhs.insert(4); + + let mut rhs = SmallSet::<[u32; 4]>::new(); + rhs.insert(3); + rhs.insert(4); + rhs.insert(5); + rhs.insert(6); + + let symmetric_difference = lhs.symmetric_difference(&rhs).collect::>(); + let expected = vec![1, 2, 5, 6]; + assert_eq!(symmetric_difference.len(), expected.len()); + assert!(expected + .iter() + .collect::>() + .iter() + .all(|x| { symmetric_difference.contains(x) })); +}