diff --git a/src/lib.rs b/src/lib.rs index 4f06c071f..c6a62b652 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1065,10 +1065,9 @@ pub trait Itertools : Iterator { /// ); /// ``` #[inline] - fn merge_join_by(self, other: J, cmp_fn: F) -> MergeJoinBy + fn merge_join_by(self, other: J, cmp_fn: F) -> MergeJoinBy> where J: IntoIterator, F: FnMut(&Self::Item, &J::Item) -> T, - T: merge_join::OrderingOrBool, Self: Sized { merge_join_by(self, other, cmp_fn) diff --git a/src/merge_join.rs b/src/merge_join.rs index 9bb2fdc9f..360cb8dd6 100644 --- a/src/merge_join.rs +++ b/src/merge_join.rs @@ -2,6 +2,7 @@ use std::cmp::Ordering; use std::iter::Fuse; // use std::iter::FusedIterator; use std::fmt; +use std::marker::PhantomData; use either::Either; @@ -11,26 +12,19 @@ use crate::size_hint::{self, SizeHint}; #[cfg(doc)] use crate::Itertools; -pub trait MergePredicate { - type Out; - fn merge_pred(&mut self, left: &L, right: &R) -> Self::Out; -} +#[derive(Clone, Debug)] +pub struct MergeLte; -impl T> MergePredicate for F { - type Out = T; - fn merge_pred(&mut self, left: &L, right: &R) -> T { - self(left, right) - } +#[derive(Clone, Debug)] +pub struct MergeFuncLR { + f: F, + _t: PhantomData, } #[derive(Clone, Debug)] -pub struct MergeLte; - -impl MergePredicate for MergeLte { - type Out = bool; - fn merge_pred(&mut self, left: &T, right: &T) -> bool { - left <= right - } +pub struct MergeFuncT { + f: F, + _t: PhantomData, } /// An iterator adaptor that merges the two base iterators in ascending order. @@ -57,7 +51,11 @@ pub fn merge(i: I, j: J) -> Merge<::IntoIter, , I::Item: PartialOrd { - merge_by_new(i, j, MergeLte) + MergeJoinBy { + left: put_back(i.into_iter().fuse()), + right: put_back(j.into_iter().fuse()), + cmp_fn: MergeFuncT { f: MergeLte, _t: PhantomData }, + } } /// An iterator adaptor that merges the two base iterators in ascending order. @@ -66,18 +64,18 @@ pub fn merge(i: I, j: J) -> Merge<::IntoIter, = MergeJoinBy; +pub type MergeBy = MergeJoinBy>; /// Create a `MergeBy` iterator. pub fn merge_by_new(a: I, b: J, cmp: F) -> MergeBy where I: IntoIterator, J: IntoIterator, - F: MergePredicate, + F: FnMut(&I::Item, &I::Item) -> bool, { MergeJoinBy { left: put_back(a.into_iter().fuse()), right: put_back(b.into_iter().fuse()), - cmp_fn: cmp, + cmp_fn: MergeFuncT { f: cmp, _t: PhantomData }, } } @@ -91,16 +89,15 @@ pub fn merge_by_new(a: I, b: J, cmp: F) -> MergeBy(left: I, right: J, cmp_fn: F) - -> MergeJoinBy + -> MergeJoinBy> where I: IntoIterator, J: IntoIterator, F: FnMut(&I::Item, &J::Item) -> T, - T: OrderingOrBool, { MergeJoinBy { left: put_back(left.into_iter().fuse()), right: put_back(right.into_iter().fuse()), - cmp_fn, + cmp_fn: MergeFuncLR { f: cmp_fn, _t: PhantomData }, } } @@ -108,24 +105,26 @@ pub fn merge_join_by(left: I, right: J, cmp_fn: F) /// /// See [`.merge_join_by()`](crate::Itertools::merge_join_by) for more information. #[must_use = "iterator adaptors are lazy and do nothing unless consumed"] -pub struct MergeJoinBy { +pub struct MergeJoinBy { left: PutBack>, right: PutBack>, cmp_fn: F, } -pub trait OrderingOrBool { +pub trait MergePredicate { + type Out; type MergeResult; fn left(left: L) -> Self::MergeResult; fn right(right: R) -> Self::MergeResult; // "merge" never returns (Some(...), Some(...), ...) so Option> // is appealing but it is always followed by two put_backs, so we think the compiler is // smart enough to optimize it. Or we could move put_backs into "merge". - fn merge(self, left: L, right: R) -> (Option, Option, Self::MergeResult); + fn merge(&mut self, left: L, right: R) -> (Option, Option, Self::MergeResult); fn size_hint(left: SizeHint, right: SizeHint) -> SizeHint; } -impl OrderingOrBool for Ordering { +impl Ordering> MergePredicate for MergeFuncLR { + type Out = Ordering; type MergeResult = EitherOrBoth; fn left(left: L) -> Self::MergeResult { EitherOrBoth::Left(left) @@ -133,8 +132,8 @@ impl OrderingOrBool for Ordering { fn right(right: R) -> Self::MergeResult { EitherOrBoth::Right(right) } - fn merge(self, left: L, right: R) -> (Option, Option, Self::MergeResult) { - match self { + fn merge(&mut self, left: L, right: R) -> (Option, Option, Self::MergeResult) { + match (self.f)(&left, &right) { Ordering::Equal => (None, None, EitherOrBoth::Both(left, right)), Ordering::Less => (None, Some(right), EitherOrBoth::Left(left)), Ordering::Greater => (Some(left), None, EitherOrBoth::Right(right)), @@ -152,7 +151,8 @@ impl OrderingOrBool for Ordering { } } -impl OrderingOrBool for bool { +impl bool> MergePredicate for MergeFuncLR { + type Out = bool; type MergeResult = Either; fn left(left: L) -> Self::MergeResult { Either::Left(left) @@ -160,8 +160,8 @@ impl OrderingOrBool for bool { fn right(right: R) -> Self::MergeResult { Either::Right(right) } - fn merge(self, left: L, right: R) -> (Option, Option, Self::MergeResult) { - if self { + fn merge(&mut self, left: L, right: R) -> (Option, Option, Self::MergeResult) { + if (self.f)(&left, &right) { (None, Some(right), Either::Left(left)) } else { (Some(left), None, Either::Right(right)) @@ -173,7 +173,30 @@ impl OrderingOrBool for bool { } } -impl OrderingOrBool for bool { +impl bool> MergePredicate for MergeFuncT { + type Out = bool; + type MergeResult = T; + fn left(left: T) -> Self::MergeResult { + left + } + fn right(right: T) -> Self::MergeResult { + right + } + fn merge(&mut self, left: T, right: T) -> (Option, Option, Self::MergeResult) { + if (self.f)(&left, &right) { + (None, Some(right), left) + } else { + (Some(left), None, right) + } + } + fn size_hint(left: SizeHint, right: SizeHint) -> SizeHint { + // Not ExactSizeIterator because size may be larger than usize + size_hint::add(left, right) + } +} + +impl MergePredicate for MergeFuncT { + type Out = bool; type MergeResult = T; fn left(left: T) -> Self::MergeResult { left @@ -181,8 +204,8 @@ impl OrderingOrBool for bool { fn right(right: T) -> Self::MergeResult { right } - fn merge(self, left: T, right: T) -> (Option, Option, Self::MergeResult) { - if self { + fn merge(&mut self, left: T, right: T) -> (Option, Option, Self::MergeResult) { + if left <= right { (None, Some(right), left) } else { (Some(left), None, right) @@ -194,7 +217,7 @@ impl OrderingOrBool for bool { } } -impl Clone for MergeJoinBy +impl Clone for MergeJoinBy where I: Iterator, J: Iterator, PutBack>: Clone, @@ -204,7 +227,7 @@ impl Clone for MergeJoinBy clone_fields!(left, right, cmp_fn); } -impl fmt::Debug for MergeJoinBy +impl fmt::Debug for MergeJoinBy where I: Iterator + fmt::Debug, I::Item: fmt::Debug, J: Iterator + fmt::Debug, @@ -213,21 +236,20 @@ impl fmt::Debug for MergeJoinBy debug_fmt_fields!(MergeJoinBy, left, right); } -impl Iterator for MergeJoinBy +impl Iterator for MergeJoinBy where I: Iterator, J: Iterator, F: MergePredicate, - T: OrderingOrBool, { - type Item = T::MergeResult; + type Item = F::MergeResult; fn next(&mut self) -> Option { match (self.left.next(), self.right.next()) { (None, None) => None, - (Some(left), None) => Some(T::left(left)), - (None, Some(right)) => Some(T::right(right)), + (Some(left), None) => Some(F::left(left)), + (None, Some(right)) => Some(F::right(right)), (Some(left), Some(right)) => { - let (left, right, next) = self.cmp_fn.merge_pred(&left, &right).merge(left, right); + let (left, right, next) = self.cmp_fn.merge(left, right); if let Some(left) = left { self.left.put_back(left); } @@ -240,7 +262,7 @@ impl Iterator for MergeJoinBy } fn size_hint(&self) -> SizeHint { - T::size_hint(self.left.size_hint(), self.right.size_hint()) + F::size_hint(self.left.size_hint(), self.right.size_hint()) } fn count(mut self) -> usize { @@ -252,7 +274,7 @@ impl Iterator for MergeJoinBy (None, Some(_right)) => break count + 1 + self.right.into_parts().1.count(), (Some(left), Some(right)) => { count += 1; - let (left, right, _) = self.cmp_fn.merge_pred(&left, &right).merge(left, right); + let (left, right, _) = self.cmp_fn.merge(left, right); if let Some(left) = left { self.left.put_back(left); } @@ -270,17 +292,17 @@ impl Iterator for MergeJoinBy match (self.left.next(), self.right.next()) { (None, None) => break previous_element, (Some(left), None) => { - break Some(T::left( + break Some(F::left( self.left.into_parts().1.last().unwrap_or(left), )) } (None, Some(right)) => { - break Some(T::right( + break Some(F::right( self.right.into_parts().1.last().unwrap_or(right), )) } (Some(left), Some(right)) => { - let (left, right, elem) = self.cmp_fn.merge_pred(&left, &right).merge(left, right); + let (left, right, elem) = self.cmp_fn.merge(left, right); if let Some(left) = left { self.left.put_back(left); } @@ -301,10 +323,10 @@ impl Iterator for MergeJoinBy n -= 1; match (self.left.next(), self.right.next()) { (None, None) => break None, - (Some(_left), None) => break self.left.nth(n).map(T::left), - (None, Some(_right)) => break self.right.nth(n).map(T::right), + (Some(_left), None) => break self.left.nth(n).map(F::left), + (None, Some(_right)) => break self.right.nth(n).map(F::right), (Some(left), Some(right)) => { - let (left, right, _) = self.cmp_fn.merge_pred(&left, &right).merge(left, right); + let (left, right, _) = self.cmp_fn.merge(left, right); if let Some(left) = left { self.left.put_back(left); }