diff --git a/cli/src/torii/routing.rs b/cli/src/torii/routing.rs index 9d2d8481504..e0e63653db5 100644 --- a/cli/src/torii/routing.rs +++ b/cli/src/torii/routing.rs @@ -116,7 +116,7 @@ pub(crate) async fn handle_instructions( #[allow(clippy::map_err_ignore)] queue .push(transaction, &sumeragi.wsv_mutex_access()) - .map_err(|(tx, err)| { + .map_err(|queue::Failure { tx, err }| { iroha_logger::warn!( tx_hash=%tx.hash(), ?err, "Failed to push into queue" diff --git a/config/iroha_test_config.json b/config/iroha_test_config.json index 0343401c7bc..33aa7860c3c 100644 --- a/config/iroha_test_config.json +++ b/config/iroha_test_config.json @@ -57,6 +57,7 @@ "QUEUE": { "MAXIMUM_TRANSACTIONS_IN_BLOCK": 8192, "MAXIMUM_TRANSACTIONS_IN_QUEUE": 65536, + "MAXIMUM_TRANSACTIONS_IN_SIGNATURE_BUFFER": 65536, "TRANSACTION_TIME_TO_LIVE_MS": 86400000, "FUTURE_THRESHOLD_MS": 1000 }, diff --git a/config/src/queue.rs b/config/src/queue.rs index 37a0954a0e1..0ab6d51f2a6 100644 --- a/config/src/queue.rs +++ b/config/src/queue.rs @@ -5,6 +5,7 @@ use serde::{Deserialize, Serialize}; const DEFAULT_MAXIMUM_TRANSACTIONS_IN_BLOCK: u32 = 2_u32.pow(9); const DEFAULT_MAXIMUM_TRANSACTIONS_IN_QUEUE: u32 = 2_u32.pow(16); +const DEFAULT_MAXIMUM_TRANSACTIONS_IN_SIGNATURE_BUFFER: u32 = 2_u32.pow(16); // 24 hours const DEFAULT_TRANSACTION_TIME_TO_LIVE_MS: u64 = 24 * 60 * 60 * 1000; const DEFAULT_FUTURE_THRESHOLD_MS: u64 = 1000; @@ -18,6 +19,8 @@ pub struct Configuration { pub maximum_transactions_in_block: u32, /// The upper limit of the number of transactions waiting in the queue. pub maximum_transactions_in_queue: u32, + /// The upper limit of the number of transactions waiting for more signatures. + pub maximum_transactions_in_signature_buffer: u32, /// The transaction will be dropped after this time if it is still in the queue. pub transaction_time_to_live_ms: u64, /// The threshold to determine if a transaction has been tampered to have a future timestamp. @@ -29,6 +32,9 @@ impl Default for ConfigurationProxy { Self { maximum_transactions_in_block: Some(DEFAULT_MAXIMUM_TRANSACTIONS_IN_BLOCK), maximum_transactions_in_queue: Some(DEFAULT_MAXIMUM_TRANSACTIONS_IN_QUEUE), + maximum_transactions_in_signature_buffer: Some( + DEFAULT_MAXIMUM_TRANSACTIONS_IN_SIGNATURE_BUFFER, + ), transaction_time_to_live_ms: Some(DEFAULT_TRANSACTION_TIME_TO_LIVE_MS), future_threshold_ms: Some(DEFAULT_FUTURE_THRESHOLD_MS), } @@ -46,11 +52,12 @@ pub mod tests { ( maximum_transactions_in_block in prop::option::of(Just(DEFAULT_MAXIMUM_TRANSACTIONS_IN_BLOCK)), maximum_transactions_in_queue in prop::option::of(Just(DEFAULT_MAXIMUM_TRANSACTIONS_IN_QUEUE)), + maximum_transactions_in_signature_buffer in prop::option::of(Just(DEFAULT_MAXIMUM_TRANSACTIONS_IN_SIGNATURE_BUFFER)), transaction_time_to_live_ms in prop::option::of(Just(DEFAULT_TRANSACTION_TIME_TO_LIVE_MS)), future_threshold_ms in prop::option::of(Just(DEFAULT_FUTURE_THRESHOLD_MS)), ) -> ConfigurationProxy { - ConfigurationProxy { maximum_transactions_in_block, maximum_transactions_in_queue, transaction_time_to_live_ms, future_threshold_ms } + ConfigurationProxy { maximum_transactions_in_block, maximum_transactions_in_queue, maximum_transactions_in_signature_buffer, transaction_time_to_live_ms, future_threshold_ms } } } } diff --git a/configs/peer/config.json b/configs/peer/config.json index 617118bc2fd..6e020043b81 100644 --- a/configs/peer/config.json +++ b/configs/peer/config.json @@ -38,6 +38,7 @@ "QUEUE": { "MAXIMUM_TRANSACTIONS_IN_BLOCK": 512, "MAXIMUM_TRANSACTIONS_IN_QUEUE": 65536, + "MAXIMUM_TRANSACTIONS_IN_SIGNATURE_BUFFER": 65536, "TRANSACTION_TIME_TO_LIVE_MS": 86400000, "FUTURE_THRESHOLD_MS": 1000 }, diff --git a/core/src/queue.rs b/core/src/queue.rs index 8d203afc051..93da5f98409 100644 --- a/core/src/queue.rs +++ b/core/src/queue.rs @@ -3,8 +3,7 @@ clippy::module_name_repetitions, clippy::std_instead_of_core, clippy::std_instead_of_alloc, - clippy::arithmetic_side_effects, - clippy::expect_used + clippy::arithmetic_side_effects )] use core::time::Duration; @@ -16,7 +15,7 @@ use eyre::{Report, Result}; use iroha_config::queue::Configuration; use iroha_crypto::HashOf; use iroha_data_model::transaction::prelude::*; -use iroha_primitives::must_use::MustUse; +use iroha_primitives::{must_use::MustUse, riffle_iter::RiffleIter}; use rand::seq::IteratorRandom; use thiserror::Error; @@ -27,8 +26,12 @@ use crate::prelude::*; /// Multiple producers, single consumer #[derive(Debug)] pub struct Queue { - /// The queue proper + /// The queue for transactions that passed signature check queue: ArrayQueue>, + /// The queue for transactions that didn't pass signature check and are waiting for additional signatures + /// + /// Second queue is needed to prevent situation when multisig transactions prevent ordinary transactions from being added into the queue + signature_buffer: ArrayQueue>, /// [`VersionedAcceptedTransaction`]s addressed by `Hash`. txs: DashMap, VersionedAcceptedTransaction>, /// The maximum number of transactions in the block @@ -67,13 +70,26 @@ pub enum Error { }, } +/// Failure that can pop up when pushing transaction into the queue +#[derive(Debug)] +pub struct Failure { + /// Transaction failed to be pushed into the queue + pub tx: VersionedAcceptedTransaction, + /// Push failure reason + pub err: Error, +} + impl Queue { /// Makes queue from configuration pub fn from_configuration(cfg: &Configuration) -> Self { Self { queue: ArrayQueue::new(cfg.maximum_transactions_in_queue as usize), + signature_buffer: ArrayQueue::new( + cfg.maximum_transactions_in_signature_buffer as usize, + ), txs: DashMap::new(), - max_txs: cfg.maximum_transactions_in_queue as usize, + max_txs: (cfg.maximum_transactions_in_queue + + cfg.maximum_transactions_in_signature_buffer) as usize, txs_in_block: cfg.maximum_transactions_in_block as usize, tx_time_to_live: Duration::from_millis(cfg.transaction_time_to_live_ms), future_threshold: Duration::from_millis(cfg.future_threshold_ms), @@ -114,7 +130,9 @@ impl Queue { tx: &VersionedAcceptedTransaction, wsv: &WorldStateView, ) -> Result, Error> { - if tx.is_expired(self.tx_time_to_live) { + if tx.is_in_future(self.future_threshold) { + Err(Error::InFuture) + } else if tx.is_expired(self.tx_time_to_live) { Err(Error::Expired) } else if tx.is_in_blockchain(wsv) { Err(Error::InBlockchain) @@ -127,85 +145,76 @@ impl Queue { } } - /// Pushes transaction into queue. + /// Push transaction into queue. /// /// # Errors /// See [`enum@Error`] - #[allow( - clippy::unwrap_in_result, - clippy::expect_used, - clippy::missing_panics_doc - )] pub fn push( &self, tx: VersionedAcceptedTransaction, wsv: &WorldStateView, - ) -> Result<(), (VersionedAcceptedTransaction, Error)> { - if tx.is_in_future(self.future_threshold) { - Err((tx, Error::InFuture)) - } else if let Err(e) = self.check_tx(&tx, wsv) { - Err((tx, e)) - } else if self.txs.len() >= self.max_txs { - Err((tx, Error::Full)) - } else { - let hash = tx.hash(); - let entry = match self.txs.entry(hash) { - Entry::Occupied(mut old_tx) => { - // MST case - old_tx - .get_mut() - .as_mut_v1() - .signatures - .extend(tx.as_v1().signatures.clone()); - return Ok(()); + ) -> Result<(), Failure> { + match self.check_tx(&tx, wsv) { + Err(err) => Err(Failure { tx, err }), + Ok(MustUse(signature_check)) => { + // Get `txs_len` before entry to avoid deadlock + let txs_len = self.txs.len(); + let hash = tx.hash(); + let entry = match self.txs.entry(hash) { + Entry::Occupied(mut old_tx) => { + // MST case + old_tx + .get_mut() + .as_mut_v1() + .signatures + .extend(tx.as_v1().signatures.clone()); + return Ok(()); + } + Entry::Vacant(entry) => entry, + }; + if txs_len >= self.max_txs { + return Err(Failure { + tx, + err: Error::Full, + }); } - Entry::Vacant(entry) => entry, - }; - // Reason for such insertion order is to avoid situation - // when poped from the `queue` hash does not yet has corresponding (hash, tx) record in `txs` - entry.insert(tx); - self.queue.push(hash).map_err(|err_hash| { - let (_, err_tx) = self - .txs - .remove(&err_hash) - .expect("Inserted just before match"); - (err_tx, Error::Full) - }) + // Insert entry first so that the `tx` popped from `queue` will always have a `(hash, tx)` record in `txs`. + entry.insert(tx); + let queue_to_push = if signature_check { + &self.queue + } else { + &self.signature_buffer + }; + queue_to_push.push(hash).map_err(|err_hash| { + let (_, err_tx) = self + .txs + .remove(&err_hash) + .expect("Inserted just before match"); + Failure { + tx: err_tx, + err: Error::Full, + } + }) + } } } - /// Pop single transaction. - /// - /// Records unsigned transaction in `seen`. - #[allow( - clippy::expect_used, - clippy::unwrap_in_result, - clippy::cognitive_complexity - )] - fn pop( + /// Pop single transaction from the signature buffer. Record all visited and not removed transactions in `seen`. + fn pop_from_signature_buffer( &self, seen: &mut Vec>, wsv: &WorldStateView, ) -> Option { loop { - let hash = self.queue.pop()?; + let hash = self.signature_buffer.pop()?; let entry = match self.txs.entry(hash) { Entry::Occupied(entry) => entry, - // As practice shows this code is not `unreachable!()`. - // When transactions are submitted quickly it can be reached. + // FIXME: Reachable under high load. Investigate, see if it's a problem. Entry::Vacant(_) => continue, }; - if self.check_tx(entry.get(), wsv).is_err() { - entry.remove_entry(); - continue; - } match self.check_tx(entry.get(), wsv) { - Err(_) => { - entry.remove_entry(); - continue; - } Ok(MustUse(signature_check)) => { // Transactions are not removed from the queue until expired or committed seen.push(hash); @@ -213,10 +222,39 @@ impl Queue { return Some(entry.get().clone()); } } + Err(_) => { + entry.remove_entry(); + } } } } + /// Pop single transaction from the queue. Record all visited and not removed transactions in `seen`. + fn pop_from_queue( + &self, + seen: &mut Vec>, + wsv: &WorldStateView, + ) -> Option { + loop { + let hash = self.queue.pop()?; + let entry = match self.txs.entry(hash) { + Entry::Occupied(entry) => entry, + // As practice shows this code is not `unreachable!()`. + // When transactions are submitted quickly it can be reached. + Entry::Vacant(_) => continue, + }; + + if !self.is_pending(entry.get(), wsv) { + entry.remove_entry(); + continue; + } + + // Transactions are not removed from the queue until expired or committed + seen.push(hash); + return Some(entry.get().clone()); + } + } + /// Return the number of transactions in the queue. pub fn tx_len(&self) -> usize { self.txs.len() @@ -247,21 +285,33 @@ impl Queue { return; } - let mut seen = Vec::new(); + let mut seen_queue = Vec::new(); + let mut seen_waiting_buffer = Vec::new(); + + let txs_from_queue = core::iter::from_fn(|| self.pop_from_queue(&mut seen_queue, wsv)); + let txs_from_waiting_buffer = + core::iter::from_fn(|| self.pop_from_signature_buffer(&mut seen_waiting_buffer, wsv)); let transactions_hashes: HashSet> = transactions .iter() .map(VersionedAcceptedTransaction::hash) .collect(); - let out = std::iter::from_fn(|| self.pop(&mut seen, wsv)) + let txs = txs_from_queue + .riffle(txs_from_waiting_buffer) .filter(|tx| !transactions_hashes.contains(&tx.hash())) .take(self.txs_in_block - transactions.len()); - transactions.extend(out); - - #[allow(clippy::expect_used)] - seen.into_iter() - .try_for_each(|hash| self.queue.push(hash)) - .expect("As we never exceed the number of transactions pending"); + transactions.extend(txs); + + [ + (seen_queue, &self.queue), + (seen_waiting_buffer, &self.signature_buffer), + ] + .into_iter() + .for_each(|(seen, queue)| { + seen.into_iter() + .try_for_each(|hash| queue.push(hash)) + .expect("Exceeded the number of transactions pending") + }) } } @@ -372,10 +422,185 @@ mod tests { assert!(matches!( queue.push(accepted_tx("alice@wonderland", 100_000, key_pair), &wsv), - Err((_, Error::Full)) + Err(Failure { + err: Error::Full, + .. + }) )); } + #[test] + fn push_tx_when_signature_buffer_is_full() { + let max_txs_in_waiting_buffer = 10; + + let alice_key_pairs = [KeyPair::generate().unwrap(), KeyPair::generate().unwrap()]; + let bob_key_pair = KeyPair::generate().unwrap(); + let kura = Kura::blank_kura_for_testing(); + let wsv = { + let domain_id = DomainId::from_str("wonderland").expect("Valid"); + let mut domain = Domain::new(domain_id.clone()).build(); + let alice_id = AccountId::from_str("alice@wonderland").expect("Valid"); + let bob_id = AccountId::from_str("bob@wonderland").expect("Valid"); + let mut alice = Account::new( + alice_id, + alice_key_pairs.iter().map(KeyPair::public_key).cloned(), + ) + .build(); + alice.set_signature_check_condition(SignatureCheckCondition( + ContainsAll::new( + EvaluatesTo::new_unchecked( + ContextValue::new( + Name::from_str(TRANSACTION_SIGNATORIES_VALUE) + .expect("TRANSACTION_SIGNATORIES_VALUE should be valid."), + ) + .into(), + ), + EvaluatesTo::new_unchecked( + ContextValue::new( + Name::from_str(ACCOUNT_SIGNATORIES_VALUE) + .expect("ACCOUNT_SIGNATORIES_VALUE should be valid."), + ) + .into(), + ), + ) + .into(), + )); + let bob = Account::new(bob_id, [bob_key_pair.public_key().clone()]).build(); + assert!(domain.add_account(alice).is_none()); + assert!(domain.add_account(bob).is_none()); + Arc::new(WorldStateView::new( + World::with([domain], PeersIds::new()), + kura.clone(), + )) + }; + + let queue = Queue::from_configuration(&Configuration { + maximum_transactions_in_block: 2, + transaction_time_to_live_ms: 100_000, + maximum_transactions_in_signature_buffer: max_txs_in_waiting_buffer, + ..ConfigurationProxy::default() + .build() + .expect("Default queue config should always build") + }); + + // Fill waiting buffer with multisig transactions + for _ in 0..max_txs_in_waiting_buffer { + queue + .push( + accepted_tx("alice@wonderland", 100_000, alice_key_pairs[0].clone()), + &wsv, + ) + .expect("Failed to push tx into queue"); + thread::sleep(Duration::from_millis(10)); + } + + // Check that signature buffer is full + assert!(matches!( + queue.push( + accepted_tx("alice@wonderland", 100_000, alice_key_pairs[0].clone()), + &wsv + ), + Err(Failure { + err: Error::Full, + .. + }) + )); + + // Check that ordinary transactions can still be pushed into the queue + assert!(queue + .push( + accepted_tx("bob@wonderland", 100_000, bob_key_pair.clone()), + &wsv, + ) + .is_ok()) + } + + #[test] + fn push_multisig_tx_when_queue_is_full() { + let max_txs_in_queue = 10; + + let alice_key_pairs = [KeyPair::generate().unwrap(), KeyPair::generate().unwrap()]; + let bob_key_pair = KeyPair::generate().unwrap(); + let kura = Kura::blank_kura_for_testing(); + let wsv = { + let domain_id = DomainId::from_str("wonderland").expect("Valid"); + let mut domain = Domain::new(domain_id.clone()).build(); + let alice_id = AccountId::from_str("alice@wonderland").expect("Valid"); + let bob_id = AccountId::from_str("bob@wonderland").expect("Valid"); + let mut alice = Account::new( + alice_id, + alice_key_pairs.iter().map(KeyPair::public_key).cloned(), + ) + .build(); + alice.set_signature_check_condition(SignatureCheckCondition( + ContainsAll::new( + EvaluatesTo::new_unchecked( + ContextValue::new( + Name::from_str(TRANSACTION_SIGNATORIES_VALUE) + .expect("TRANSACTION_SIGNATORIES_VALUE should be valid."), + ) + .into(), + ), + EvaluatesTo::new_unchecked( + ContextValue::new( + Name::from_str(ACCOUNT_SIGNATORIES_VALUE) + .expect("ACCOUNT_SIGNATORIES_VALUE should be valid."), + ) + .into(), + ), + ) + .into(), + )); + let bob = Account::new(bob_id, [bob_key_pair.public_key().clone()]).build(); + assert!(domain.add_account(alice).is_none()); + assert!(domain.add_account(bob).is_none()); + Arc::new(WorldStateView::new( + World::with([domain], PeersIds::new()), + kura.clone(), + )) + }; + + let queue = Queue::from_configuration(&Configuration { + maximum_transactions_in_block: 2, + transaction_time_to_live_ms: 100_000, + maximum_transactions_in_queue: max_txs_in_queue, + ..ConfigurationProxy::default() + .build() + .expect("Default queue config should always build") + }); + + // Fill queue with ordinary transactions + for _ in 0..max_txs_in_queue { + queue + .push( + accepted_tx("bob@wonderland", 100_000, bob_key_pair.clone()), + &wsv, + ) + .expect("Failed to push tx into queue"); + thread::sleep(Duration::from_millis(10)); + } + + // Check that queue is full + assert!(matches!( + queue.push( + accepted_tx("bob@wonderland", 100_000, bob_key_pair.clone()), + &wsv + ), + Err(Failure { + err: Error::Full, + .. + }) + )); + + // Check that multisig transactions can still be pushed into the queue + assert!(queue + .push( + accepted_tx("alice@wonderland", 100_000, alice_key_pairs[0].clone()), + &wsv, + ) + .is_ok()) + } + #[test] fn push_tx_signature_condition_failure() { let max_txs_in_queue = 10; @@ -410,7 +635,10 @@ mod tests { assert!(matches!( queue.push(accepted_tx("alice@wonderland", 100_000, key_pair), &wsv), - Err((_, Error::SignatureCondition { .. })) + Err(Failure { + err: Error::SignatureCondition { .. }, + .. + }) )); } @@ -571,7 +799,10 @@ mod tests { }); assert!(matches!( queue.push(tx, &wsv), - Err((_, Error::InBlockchain)) + Err(Failure { + err: Error::InBlockchain, + .. + }) )); assert_eq!(queue.txs.len(), 0); } @@ -711,8 +942,10 @@ mod tests { let tx = accepted_tx("alice@wonderland", 100_000, alice_key.clone()); match queue_arc_clone.push(tx, &wsv_clone) { Ok(()) => (), - Err((_, Error::Full)) => (), - Err((_, err)) => panic!("{}", err), + Err(Failure { + err: Error::Full, .. + }) => (), + Err(Failure { err, .. }) => panic!("{err}"), } } }) @@ -768,7 +1001,13 @@ mod tests { assert!(queue.push(tx.clone(), &wsv).is_ok()); // tamper timestamp tx.as_mut_v1().payload.creation_time += 2 * future_threshold_ms; - assert!(matches!(queue.push(tx, &wsv), Err((_, Error::InFuture)))); + assert!(matches!( + queue.push(tx, &wsv), + Err(Failure { + err: Error::InFuture, + .. + }) + )); assert_eq!(queue.txs.len(), 1); } } diff --git a/core/src/sumeragi/main_loop.rs b/core/src/sumeragi/main_loop.rs index f72659cff44..6b62b366c31 100644 --- a/core/src/sumeragi/main_loop.rs +++ b/core/src/sumeragi/main_loop.rs @@ -484,10 +484,13 @@ fn enqueue_transaction( match VersionedAcceptedTransaction::from_transaction(tx, &sumeragi.transaction_limits) { Ok(tx) => match sumeragi.queue.push(tx, wsv) { Ok(_) => {} - Err((tx, crate::queue::Error::InBlockchain)) => { + Err(crate::queue::Failure { + tx, + err: crate::queue::Error::InBlockchain, + }) => { debug!(tx_hash = %tx.hash(), "Transaction already in blockchain, ignoring...") } - Err((tx, err)) => { + Err(crate::queue::Failure { tx, err }) => { error!(%addr, ?err, tx_hash = %tx.hash(), "Failed to enqueue transaction.") } }, diff --git a/core/src/tx.rs b/core/src/tx.rs index 61bec95258e..0cfb8726233 100644 --- a/core/src/tx.rs +++ b/core/src/tx.rs @@ -290,6 +290,9 @@ impl VersionedAcceptedTransaction { /// Checks that the signatures of this transaction satisfy the signature condition specified in the account. /// + /// Note that `check_signature_condition` does not verify signatures. + /// Signature verification is done when transaction transit from `SignedTransaction` to `AcceptedTransaction` state. + /// /// # Errors /// Can fail if signature condition account fails or if account is not found pub fn check_signature_condition(&self, wsv: &WorldStateView) -> Result> { @@ -347,9 +350,11 @@ impl AcceptedTransaction { /// Checks that the signatures of this transaction satisfy the signature condition specified in the account. /// + /// Note that `check_signature_condition` does not verify signatures. + /// Signature verification is done when transaction transit from `SignedTransaction` to `AcceptedTransaction` state. + /// /// # Errors /// - Account not found - /// - Signature verification fails pub fn check_signature_condition(&self, wsv: &WorldStateView) -> Result> { let account_id = &self.payload.account_id; diff --git a/docs/source/references/config.md b/docs/source/references/config.md index e354cdee0f6..46799869ebf 100644 --- a/docs/source/references/config.md +++ b/docs/source/references/config.md @@ -69,6 +69,7 @@ The following is the default configuration used by Iroha. "QUEUE": { "MAXIMUM_TRANSACTIONS_IN_BLOCK": 512, "MAXIMUM_TRANSACTIONS_IN_QUEUE": 65536, + "MAXIMUM_TRANSACTIONS_IN_SIGNATURE_BUFFER": 65536, "TRANSACTION_TIME_TO_LIVE_MS": 86400000, "FUTURE_THRESHOLD_MS": 1000 }, @@ -430,6 +431,7 @@ Has type `Option`[^1]. Can be configured via environm "FUTURE_THRESHOLD_MS": 1000, "MAXIMUM_TRANSACTIONS_IN_BLOCK": 512, "MAXIMUM_TRANSACTIONS_IN_QUEUE": 65536, + "MAXIMUM_TRANSACTIONS_IN_SIGNATURE_BUFFER": 65536, "TRANSACTION_TIME_TO_LIVE_MS": 86400000 } ``` @@ -464,6 +466,16 @@ Has type `Option`[^1]. Can be configured via environment variable `QUEUE_MA 65536 ``` +### `queue.maximum_transactions_in_signature_buffer` + +The upper limit of the number of transactions waiting for more signatures. + +Has type `Option`[^1]. Can be configured via environment variable `QUEUE_MAXIMUM_TRANSACTIONS_IN_SIGNATURE_BUFFER` + +```json +65536 +``` + ### `queue.transaction_time_to_live_ms` The transaction will be dropped after this time if it is still in the queue. diff --git a/primitives/src/lib.rs b/primitives/src/lib.rs index 461ada2918a..3700788dab8 100644 --- a/primitives/src/lib.rs +++ b/primitives/src/lib.rs @@ -17,6 +17,7 @@ pub mod atomic; pub mod conststr; pub mod fixed; pub mod must_use; +pub mod riffle_iter; pub mod small; use fixed::prelude::*; diff --git a/primitives/src/riffle_iter.rs b/primitives/src/riffle_iter.rs new file mode 100644 index 00000000000..864f22a78c9 --- /dev/null +++ b/primitives/src/riffle_iter.rs @@ -0,0 +1,142 @@ +//! Contains riffle iterator and related trait + +/// Iterator which combine two iterators into the single one. +/// Name is inspired by riffle shuffle of cards deck. +/// +/// ```ignore +/// [(a0,a1,a2,..),(b0,b1,b2,..)] -> (a0,b0,a1,b1,a2,b2,..) +/// ``` +pub struct RiffleIterator { + left_iter: A, + right_iter: B, + state: RiffleState, +} + +enum RiffleState { + CurrentLeft, + CurrentRight, + LeftExhausted, + RightExhausted, + BothExhausted, +} + +/// Trait to create [`RiffleIterator`] from two iterators. +pub trait RiffleIter: Iterator + Sized { + /// Create `RoundRobinIterator` from two iterators. + fn riffle>( + self, + iter: I, + ) -> RiffleIterator::IntoIter> { + RiffleIterator { + left_iter: self, + right_iter: iter.into_iter(), + state: RiffleState::CurrentLeft, + } + } +} + +impl RiffleIter for T {} + +impl Iterator for RiffleIterator +where + A: Iterator, + B: Iterator, +{ + type Item = T; + + fn next(&mut self) -> Option { + use RiffleState::*; + loop { + match self.state { + BothExhausted => break None, + LeftExhausted => { + let item = self.right_iter.next(); + if item.is_none() { + self.state = BothExhausted; + } + break item; + } + RightExhausted => { + let item = self.left_iter.next(); + if item.is_none() { + self.state = BothExhausted; + } + break item; + } + CurrentLeft => { + let item = self.left_iter.next(); + if item.is_none() { + self.state = LeftExhausted; + continue; + } + self.state = CurrentRight; + break item; + } + CurrentRight => { + let item = self.right_iter.next(); + if item.is_none() { + self.state = RightExhausted; + continue; + } + self.state = CurrentLeft; + break item; + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn riffle_iter_a_eq_b_size() { + let a = vec![0, 2, 4, 6, 8]; + let b = vec![1, 3, 5, 7, 9]; + assert_eq!( + vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + a.into_iter().riffle(b).collect::>() + ); + } + + #[test] + fn riffle_iter_a_gt_b_size() { + let a = vec![0, 2, 4, 6, 8]; + let b = vec![1, 3, 5]; + assert_eq!( + vec![0, 1, 2, 3, 4, 5, 6, 8], + a.into_iter().riffle(b).collect::>() + ); + } + + #[test] + fn riffle_iter_a_lt_b_size() { + let a = vec![0, 2, 4]; + let b = vec![1, 3, 5, 7, 9]; + assert_eq!( + vec![0, 1, 2, 3, 4, 5, 7, 9], + a.into_iter().riffle(b).collect::>() + ); + } + + #[test] + fn riffle_iter_a_empty() { + let a = vec![0, 2, 4, 6, 8]; + let b = Vec::new(); + assert_eq!( + vec![0, 2, 4, 6, 8], + a.into_iter().riffle(b).collect::>() + ); + } + + #[test] + fn riffle_iter_b_empty() { + let a = Vec::new(); + let b = vec![1, 3, 5, 7, 9]; + assert_eq!( + vec![1, 3, 5, 7, 9], + a.into_iter().riffle(b).collect::>() + ); + } +}