From 6b6b3109caff30c3995fdbb090e333bb94f98b9d Mon Sep 17 00:00:00 2001 From: Kris Nuttycombe Date: Thu, 5 Sep 2024 14:02:49 -0600 Subject: [PATCH] zcash_client_sqlite: Generalize the test framework to enable it to be moved to `zcash_client_backend` --- zcash_client_sqlite/Cargo.toml | 1 + zcash_client_sqlite/src/lib.rs | 144 ++++--- zcash_client_sqlite/src/testing.rs | 369 +++++++++++------- zcash_client_sqlite/src/testing/db.rs | 116 ++++++ zcash_client_sqlite/src/testing/pool.rs | 255 ++++++------ zcash_client_sqlite/src/wallet.rs | 36 +- zcash_client_sqlite/src/wallet/init.rs | 18 +- zcash_client_sqlite/src/wallet/orchard.rs | 60 +-- zcash_client_sqlite/src/wallet/sapling.rs | 55 +-- zcash_client_sqlite/src/wallet/scanning.rs | 117 +++--- zcash_client_sqlite/src/wallet/transparent.rs | 28 +- 11 files changed, 742 insertions(+), 457 deletions(-) create mode 100644 zcash_client_sqlite/src/testing/db.rs diff --git a/zcash_client_sqlite/Cargo.toml b/zcash_client_sqlite/Cargo.toml index 25f2288c51..c795b31928 100644 --- a/zcash_client_sqlite/Cargo.toml +++ b/zcash_client_sqlite/Cargo.toml @@ -79,6 +79,7 @@ document-features.workspace = true maybe-rayon.workspace = true [dev-dependencies] +ambassador.workspace = true assert_matches.workspace = true bls12_381.workspace = true incrementalmerkletree = { workspace = true, features = ["test-dependencies"] } diff --git a/zcash_client_sqlite/src/lib.rs b/zcash_client_sqlite/src/lib.rs index 52f136cbbd..0e0695c17b 100644 --- a/zcash_client_sqlite/src/lib.rs +++ b/zcash_client_sqlite/src/lib.rs @@ -264,28 +264,36 @@ impl, P: consensus::Parameters> InputSource for &self, account: AccountId, target_value: NonNegativeAmount, - _sources: &[ShieldedProtocol], + sources: &[ShieldedProtocol], anchor_height: BlockHeight, exclude: &[Self::NoteRef], ) -> Result, Self::Error> { Ok(SpendableNotes::new( - wallet::sapling::select_spendable_sapling_notes( - self.conn.borrow(), - &self.params, - account, - target_value, - anchor_height, - exclude, - )?, + if sources.contains(&ShieldedProtocol::Sapling) { + wallet::sapling::select_spendable_sapling_notes( + self.conn.borrow(), + &self.params, + account, + target_value, + anchor_height, + exclude, + )? + } else { + vec![] + }, #[cfg(feature = "orchard")] - wallet::orchard::select_spendable_orchard_notes( - self.conn.borrow(), - &self.params, - account, - target_value, - anchor_height, - exclude, - )?, + if sources.contains(&ShieldedProtocol::Orchard) { + wallet::orchard::select_spendable_orchard_notes( + self.conn.borrow(), + &self.params, + account, + target_value, + anchor_height, + exclude, + )? + } else { + vec![] + }, )) } @@ -1687,10 +1695,11 @@ mod tests { }; use zcash_keys::keys::{UnifiedFullViewingKey, UnifiedSpendingKey}; use zcash_primitives::block::BlockHash; + use zcash_protocol::consensus; use crate::{ error::SqliteClientError, - testing::{TestBuilder, TestState}, + testing::{db::TestDbFactory, TestBuilder, TestState}, AccountId, DEFAULT_UA_REQUEST, }; @@ -1703,13 +1712,14 @@ mod tests { #[test] fn validate_seed() { let st = TestBuilder::new() + .with_data_store_factory(TestDbFactory) .with_account_from_sapling_activation(BlockHash([0; 32])) .build(); let account = st.test_account().unwrap(); assert!({ st.wallet() - .validate_seed(account.account_id(), st.test_seed().unwrap()) + .validate_seed(account.id(), st.test_seed().unwrap()) .unwrap() }); @@ -1724,7 +1734,7 @@ mod tests { // check that passing an invalid seed results in a failure assert!({ !st.wallet() - .validate_seed(account.account_id(), &SecretVec::new(vec![1u8; 32])) + .validate_seed(account.id(), &SecretVec::new(vec![1u8; 32])) .unwrap() }); } @@ -1732,33 +1742,29 @@ mod tests { #[test] pub(crate) fn get_next_available_address() { let mut st = TestBuilder::new() + .with_data_store_factory(TestDbFactory) .with_account_from_sapling_activation(BlockHash([0; 32])) .build(); let account = st.test_account().cloned().unwrap(); - let current_addr = st - .wallet() - .get_current_address(account.account_id()) - .unwrap(); + let current_addr = st.wallet().get_current_address(account.id()).unwrap(); assert!(current_addr.is_some()); let addr2 = st .wallet_mut() - .get_next_available_address(account.account_id(), DEFAULT_UA_REQUEST) + .get_next_available_address(account.id(), DEFAULT_UA_REQUEST) .unwrap(); assert!(addr2.is_some()); assert_ne!(current_addr, addr2); - let addr2_cur = st - .wallet() - .get_current_address(account.account_id()) - .unwrap(); + let addr2_cur = st.wallet().get_current_address(account.id()).unwrap(); assert_eq!(addr2, addr2_cur); } #[test] pub(crate) fn import_account_hd_0() { let st = TestBuilder::new() + .with_data_store_factory(TestDbFactory) .with_account_from_sapling_activation(BlockHash([0; 32])) .set_account_index(zip32::AccountId::ZERO) .build(); @@ -1769,10 +1775,12 @@ mod tests { #[test] pub(crate) fn import_account_hd_1_then_2() { - let mut st = TestBuilder::new().build(); + let mut st = TestBuilder::new() + .with_data_store_factory(TestDbFactory) + .build(); let birthday = AccountBirthday::from_parts( - ChainState::empty(st.wallet().params.sapling.unwrap() - 1, BlockHash([0; 32])), + ChainState::empty(st.network().sapling.unwrap() - 1, BlockHash([0; 32])), None, ); @@ -1797,15 +1805,19 @@ mod tests { AccountSource::Derived { seed_fingerprint: _, account_index } if account_index == zip32_index_2); } - fn check_collisions( - st: &mut TestState, + fn check_collisions( + st: &mut TestState, ufvk: &UnifiedFullViewingKey, birthday: &AccountBirthday, - existing_id: AccountId, - ) { + _existing_id: AccountId, + ) where + DbT::Account: core::fmt::Debug, + { assert_matches!( - st.wallet_mut().import_account_ufvk(ufvk, birthday, AccountPurpose::Spending), - Err(SqliteClientError::AccountCollision(id)) if id == existing_id); + st.wallet_mut() + .import_account_ufvk(ufvk, birthday, AccountPurpose::Spending), + Err(_) + ); // Remove the transparent component so that we don't have a match on the full UFVK. // That should still produce an AccountCollision error. @@ -1820,8 +1832,13 @@ mod tests { ) .unwrap(); assert_matches!( - st.wallet_mut().import_account_ufvk(&subset_ufvk, birthday, AccountPurpose::Spending), - Err(SqliteClientError::AccountCollision(id)) if id == existing_id); + st.wallet_mut().import_account_ufvk( + &subset_ufvk, + birthday, + AccountPurpose::Spending + ), + Err(_) + ); } // Remove the Orchard component so that we don't have a match on the full UFVK. @@ -1837,17 +1854,24 @@ mod tests { ) .unwrap(); assert_matches!( - st.wallet_mut().import_account_ufvk(&subset_ufvk, birthday, AccountPurpose::Spending), - Err(SqliteClientError::AccountCollision(id)) if id == existing_id); + st.wallet_mut().import_account_ufvk( + &subset_ufvk, + birthday, + AccountPurpose::Spending + ), + Err(_) + ); } } #[test] pub(crate) fn import_account_hd_1_then_conflicts() { - let mut st = TestBuilder::new().build(); + let mut st = TestBuilder::new() + .with_data_store_factory(TestDbFactory) + .build(); let birthday = AccountBirthday::from_parts( - ChainState::empty(st.wallet().params.sapling.unwrap() - 1, BlockHash([0; 32])), + ChainState::empty(st.network().sapling.unwrap() - 1, BlockHash([0; 32])), None, ); @@ -1869,18 +1893,19 @@ mod tests { #[test] pub(crate) fn import_account_ufvk_then_conflicts() { - let mut st = TestBuilder::new().build(); + let mut st = TestBuilder::new() + .with_data_store_factory(TestDbFactory) + .build(); let birthday = AccountBirthday::from_parts( - ChainState::empty(st.wallet().params.sapling.unwrap() - 1, BlockHash([0; 32])), + ChainState::empty(st.network().sapling.unwrap() - 1, BlockHash([0; 32])), None, ); let seed = Secret::new(vec![0u8; 32]); let zip32_index_0 = zip32::AccountId::ZERO; - let usk = - UnifiedSpendingKey::from_seed(&st.wallet().params, seed.expose_secret(), zip32_index_0) - .unwrap(); + let usk = UnifiedSpendingKey::from_seed(st.network(), seed.expose_secret(), zip32_index_0) + .unwrap(); let ufvk = usk.to_unified_full_viewing_key(); let account = st @@ -1888,8 +1913,8 @@ mod tests { .import_account_ufvk(&ufvk, &birthday, AccountPurpose::Spending) .unwrap(); assert_eq!( - ufvk.encode(&st.wallet().params), - account.ufvk().unwrap().encode(&st.wallet().params) + ufvk.encode(st.network()), + account.ufvk().unwrap().encode(st.network()) ); assert_matches!( @@ -1908,10 +1933,12 @@ mod tests { #[test] pub(crate) fn create_account_then_conflicts() { - let mut st = TestBuilder::new().build(); + let mut st = TestBuilder::new() + .with_data_store_factory(TestDbFactory) + .build(); let birthday = AccountBirthday::from_parts( - ChainState::empty(st.wallet().params.sapling.unwrap() - 1, BlockHash([0; 32])), + ChainState::empty(st.network().sapling.unwrap() - 1, BlockHash([0; 32])), None, ); @@ -1933,6 +1960,7 @@ mod tests { fn transparent_receivers() { // Add an account to the wallet. let st = TestBuilder::new() + .with_data_store_factory(TestDbFactory) .with_block_cache() .with_account_from_sapling_activation(BlockHash([0; 32])) .build(); @@ -1940,10 +1968,7 @@ mod tests { let ufvk = account.usk().to_unified_full_viewing_key(); let (taddr, _) = account.usk().default_transparent_address(); - let receivers = st - .wallet() - .get_transparent_receivers(account.account_id()) - .unwrap(); + let receivers = st.wallet().get_transparent_receivers(account.id()).unwrap(); // The receiver for the default UA should be in the set. assert!(receivers.contains_key( @@ -1964,7 +1989,10 @@ mod tests { use zcash_primitives::consensus::NetworkConstants; use zcash_primitives::zip32; - let mut st = TestBuilder::new().with_fs_block_cache().build(); + let mut st = TestBuilder::new() + .with_data_store_factory(TestDbFactory) + .with_fs_block_cache() + .build(); // The BlockMeta DB starts off empty. assert_eq!(st.cache().get_max_cached_height().unwrap(), None); @@ -1972,7 +2000,7 @@ mod tests { // Generate some fake CompactBlocks. let seed = [0u8; 32]; let hd_account_index = zip32::AccountId::ZERO; - let extsk = sapling::spending_key(&seed, st.wallet().params.coin_type(), hd_account_index); + let extsk = sapling::spending_key(&seed, st.network().coin_type(), hd_account_index); let dfvk = extsk.to_diversifiable_full_viewing_key(); let (h1, meta1, _) = st.generate_next_block( &dfvk, diff --git a/zcash_client_sqlite/src/testing.rs b/zcash_client_sqlite/src/testing.rs index 2c64592442..2d864b9224 100644 --- a/zcash_client_sqlite/src/testing.rs +++ b/zcash_client_sqlite/src/testing.rs @@ -15,6 +15,7 @@ use rusqlite::{params, Connection}; use secrecy::{Secret, SecretVec}; use shardtree::error::ShardTreeError; +use subtle::ConditionallySelectable; use tempfile::NamedTempFile; #[cfg(feature = "unstable")] @@ -26,7 +27,7 @@ use sapling::{ zip32::DiversifiableFullViewingKey, Note, Nullifier, }; -use zcash_client_backend::data_api::Account as AccountTrait; +use zcash_client_backend::data_api::{Account, InputSource}; #[allow(deprecated)] use zcash_client_backend::{ address::Address, @@ -74,8 +75,7 @@ use crate::{ chain::init::init_cache_database, error::SqliteClientError, wallet::{ - commitment_tree, get_wallet_summary, init::init_wallet_db, sapling::tests::test_prover, - Account, SubtreeScanProgress, + commitment_tree, get_wallet_summary, sapling::tests::test_prover, SubtreeScanProgress, }, AccountId, ReceivedNoteId, WalletDb, }; @@ -102,6 +102,7 @@ use crate::{ FsBlockDb, }; +pub(crate) mod db; pub(crate) mod pool; pub(crate) struct InitialChainState { @@ -111,17 +112,29 @@ pub(crate) struct InitialChainState { pub(crate) prior_orchard_roots: Vec>, } +pub(crate) trait DataStoreFactory { + type Error: core::fmt::Debug; + type AccountId: ConditionallySelectable + Default + Send + 'static; + type DataStore: InputSource + + WalletRead + + WalletWrite + + WalletCommitmentTrees; + + fn new_data_store(&self, network: LocalNetwork) -> Result; +} + /// A builder for a `zcash_client_sqlite` test. -pub(crate) struct TestBuilder { +pub(crate) struct TestBuilder { rng: ChaChaRng, network: LocalNetwork, cache: Cache, + ds_factory: DataStoreFactory, initial_chain_state: Option, account_birthday: Option, account_index: Option, } -impl TestBuilder<()> { +impl TestBuilder<(), ()> { pub const DEFAULT_NETWORK: LocalNetwork = LocalNetwork { overwinter: Some(BlockHeight::from_u32(1)), sapling: Some(BlockHeight::from_u32(100_000)), @@ -134,7 +147,7 @@ impl TestBuilder<()> { z_future: None, }; - /// Constructs a new test. + /// Constructs a new test environment builder. pub(crate) fn new() -> Self { TestBuilder { rng: ChaChaRng::seed_from_u64(0), @@ -142,18 +155,22 @@ impl TestBuilder<()> { // We pick 100,000 to be large enough to handle any hard-coded test offsets. network: Self::DEFAULT_NETWORK, cache: (), + ds_factory: (), initial_chain_state: None, account_birthday: None, account_index: None, } } +} +impl TestBuilder<(), A> { /// Adds a [`BlockDb`] cache to the test. - pub(crate) fn with_block_cache(self) -> TestBuilder { + pub(crate) fn with_block_cache(self) -> TestBuilder { TestBuilder { rng: self.rng, network: self.network, cache: BlockCache::new(), + ds_factory: self.ds_factory, initial_chain_state: self.initial_chain_state, account_birthday: self.account_birthday, account_index: self.account_index, @@ -162,11 +179,12 @@ impl TestBuilder<()> { /// Adds a [`FsBlockDb`] cache to the test. #[cfg(feature = "unstable")] - pub(crate) fn with_fs_block_cache(self) -> TestBuilder { + pub(crate) fn with_fs_block_cache(self) -> TestBuilder { TestBuilder { rng: self.rng, network: self.network, cache: FsBlockCache::new(), + ds_factory: self.ds_factory, initial_chain_state: self.initial_chain_state, account_birthday: self.account_birthday, account_index: self.account_index, @@ -174,7 +192,24 @@ impl TestBuilder<()> { } } -impl TestBuilder { +impl TestBuilder { + pub(crate) fn with_data_store_factory( + self, + ds_factory: DsFactory, + ) -> TestBuilder { + TestBuilder { + rng: self.rng, + network: self.network, + cache: self.cache, + ds_factory, + initial_chain_state: self.initial_chain_state, + account_birthday: self.account_birthday, + account_index: self.account_index, + } + } +} + +impl TestBuilder { pub(crate) fn with_initial_chain_state( mut self, chain_state: impl FnOnce(&mut ChaChaRng, &LocalNetwork) -> InitialChainState, @@ -222,20 +257,19 @@ impl TestBuilder { self.account_index = Some(index); self } +} +impl TestBuilder { /// Builds the state for this test. - pub(crate) fn build(self) -> TestState { - let data_file = NamedTempFile::new().unwrap(); - let mut db_data = WalletDb::for_path(data_file.path(), self.network).unwrap(); - init_wallet_db(&mut db_data, None).unwrap(); - + pub(crate) fn build(self) -> TestState { let mut cached_blocks = BTreeMap::new(); + let mut wallet_data = self.ds_factory.new_data_store(self.network).unwrap(); if let Some(initial_state) = &self.initial_chain_state { - db_data + wallet_data .put_sapling_subtree_roots(0, &initial_state.prior_sapling_roots) .unwrap(); - db_data + wallet_data .with_sapling_tree_mut(|t| { t.insert_frontier( initial_state.chain_state.final_sapling_tree().clone(), @@ -249,10 +283,10 @@ impl TestBuilder { #[cfg(feature = "orchard")] { - db_data + wallet_data .put_orchard_subtree_roots(0, &initial_state.prior_orchard_roots) .unwrap(); - db_data + wallet_data .with_orchard_tree_mut(|t| { t.insert_frontier( initial_state.chain_state.final_orchard_tree().clone(), @@ -285,10 +319,15 @@ impl TestBuilder { let test_account = self.account_birthday.map(|birthday| { let seed = Secret::new(vec![0u8; 32]); let (account, usk) = match self.account_index { - Some(index) => db_data.import_account_hd(&seed, index, &birthday).unwrap(), + Some(index) => wallet_data + .import_account_hd(&seed, index, &birthday) + .unwrap(), None => { - let result = db_data.create_account(&seed, &birthday).unwrap(); - (db_data.get_account(result.0).unwrap().unwrap(), result.1) + let result = wallet_data.create_account(&seed, &birthday).unwrap(); + ( + wallet_data.get_account(result.0).unwrap().unwrap(), + result.1, + ) } }; ( @@ -307,8 +346,8 @@ impl TestBuilder { latest_block_height: self .initial_chain_state .map(|s| s.chain_state.block_height()), - _data_file: data_file, - db_data, + wallet_data, + network: self.network, test_account, rng: self.rng, } @@ -395,21 +434,17 @@ impl CachedBlock { } #[derive(Clone)] -pub(crate) struct TestAccount { - account: Account, +pub(crate) struct TestAccount { + account: A, usk: UnifiedSpendingKey, birthday: AccountBirthday, } -impl TestAccount { - pub(crate) fn account(&self) -> &Account { +impl TestAccount { + pub(crate) fn account(&self) -> &A { &self.account } - pub(crate) fn account_id(&self) -> AccountId { - self.account.id() - } - pub(crate) fn usk(&self) -> &UnifiedSpendingKey { &self.usk } @@ -419,19 +454,119 @@ impl TestAccount { } } +impl Account for TestAccount { + type AccountId = A::AccountId; + + fn id(&self) -> Self::AccountId { + self.account.id() + } + + fn source(&self) -> data_api::AccountSource { + self.account.source() + } + + fn ufvk(&self) -> Option<&zcash_keys::keys::UnifiedFullViewingKey> { + self.account.ufvk() + } + + fn uivk(&self) -> zcash_keys::keys::UnifiedIncomingViewingKey { + self.account.uivk() + } +} + +pub(crate) trait Reset: WalletRead + Sized { + type Handle; + + fn reset(st: &mut TestState) -> Self::Handle; +} + /// The state for a `zcash_client_sqlite` test. -pub(crate) struct TestState { +pub(crate) struct TestState { cache: Cache, cached_blocks: BTreeMap, latest_block_height: Option, - _data_file: NamedTempFile, - db_data: WalletDb, - test_account: Option<(SecretVec, TestAccount)>, + wallet_data: DataStore, + network: Network, + test_account: Option<(SecretVec, TestAccount)>, rng: ChaChaRng, } -impl TestState +impl TestState { + /// Exposes an immutable reference to the test's `DataStore`. + pub(crate) fn wallet(&self) -> &DataStore { + &self.wallet_data + } + + /// Exposes a mutable reference to the test's `DataStore`. + pub(crate) fn wallet_mut(&mut self) -> &mut DataStore { + &mut self.wallet_data + } + + /// Exposes the test framework's source of randomness. + pub(crate) fn rng_mut(&mut self) -> &mut ChaChaRng { + &mut self.rng + } + + /// Exposes the network in use. + pub(crate) fn network(&self) -> &Network { + &self.network + } +} + +impl + TestState +{ + /// Convenience method for obtaining the Sapling activation height for the network under test. + pub(crate) fn sapling_activation_height(&self) -> BlockHeight { + self.network + .activation_height(NetworkUpgrade::Sapling) + .expect("Sapling activation height must be known.") + } + + /// Convenience method for obtaining the NU5 activation height for the network under test. + #[allow(dead_code)] + pub(crate) fn nu5_activation_height(&self) -> BlockHeight { + self.network + .activation_height(NetworkUpgrade::Nu5) + .expect("NU5 activation height must be known.") + } + + /// Exposes the test seed, if enabled via [`TestBuilder::with_test_account`]. + pub(crate) fn test_seed(&self) -> Option<&SecretVec> { + self.test_account.as_ref().map(|(seed, _)| seed) + } +} + +impl TestState where + Network: consensus::Parameters, + DataStore: WalletRead, +{ + /// Exposes the test account, if enabled via [`TestBuilder::with_test_account`]. + pub(crate) fn test_account(&self) -> Option<&TestAccount<::Account>> { + self.test_account.as_ref().map(|(_, acct)| acct) + } + + /// Exposes the test account's Sapling DFVK, if enabled via [`TestBuilder::with_test_account`]. + pub(crate) fn test_account_sapling(&self) -> Option<&DiversifiableFullViewingKey> { + let (_, acct) = self.test_account.as_ref()?; + let ufvk = acct.ufvk()?; + ufvk.sapling() + } + + /// Exposes the test account's Sapling DFVK, if enabled via [`TestBuilder::with_test_account`]. + #[cfg(feature = "orchard")] + pub(crate) fn test_account_orchard(&self) -> Option<&orchard::keys::FullViewingKey> { + let (_, acct) = self.test_account.as_ref()?; + let ufvk = acct.ufvk()?; + ufvk.orchard() + } +} + +impl TestState +where + Network: consensus::Parameters, + DataStore: WalletWrite, ::Error: fmt::Debug, { /// Exposes an immutable reference to the test's [`BlockSource`]. @@ -461,7 +596,6 @@ where ); self.cache.insert(&compact_block) } - /// Creates a fake block at the expected next height containing a single output of the /// given value, and inserts it into the cache. pub(crate) fn generate_next_block( @@ -613,7 +747,7 @@ where } let (cb, nfs) = fake_compact_block( - &self.network(), + &self.network, height, prev_hash, outputs, @@ -645,7 +779,7 @@ where let height = prior_cached_block.height() + 1; let cb = fake_compact_block_spending( - &self.network(), + &self.network, height, prior_cached_block.chain_state.block_hash(), note, @@ -717,7 +851,16 @@ where (height, res) } +} +impl TestState +where + Cache: TestCache, + ::Error: fmt::Debug, + ParamsT: consensus::Parameters + Send + 'static, + DbT: WalletWrite, + ::AccountId: ConditionallySelectable + Default + Send + 'static, +{ /// Invokes [`scan_cached_blocks`] with the given arguments, expecting success. pub(crate) fn scan_cached_blocks( &mut self, @@ -736,10 +879,7 @@ where limit: usize, ) -> Result< ScanSummary, - data_api::chain::error::Error< - SqliteClientError, - ::Error, - >, + data_api::chain::error::Error::Error>, > { let prior_cached_block = self .latest_cached_block_below_height(from_height) @@ -747,30 +887,28 @@ where .unwrap_or_else(|| CachedBlock::none(from_height - 1)); let result = scan_cached_blocks( - &self.network(), + &self.network, self.cache.block_source(), - &mut self.db_data, + &mut self.wallet_data, from_height, &prior_cached_block.chain_state, limit, ); result } +} +impl TestState { /// Resets the wallet using a new wallet database but with the same cache of blocks, /// and returns the old wallet database file. /// /// This does not recreate accounts, nor does it rescan the cached blocks. /// The resulting wallet has no test account. /// Before using any `generate_*` method on the reset state, call `reset_latest_cached_block()`. - pub(crate) fn reset(&mut self) -> NamedTempFile { - let network = self.network(); + pub(crate) fn reset(&mut self) -> DbT::Handle { self.latest_block_height = None; - let tf = std::mem::replace(&mut self._data_file, NamedTempFile::new().unwrap()); - self.db_data = WalletDb::for_path(self._data_file.path(), network).unwrap(); self.test_account = None; - init_wallet_db(&mut self.db_data, None).unwrap(); - tf + DbT::reset(self) } // /// Reset the latest cached block to the most recent one in the cache database. @@ -792,69 +930,7 @@ where // } } -impl TestState { - /// Exposes an immutable reference to the test's [`WalletDb`]. - pub(crate) fn wallet(&self) -> &WalletDb { - &self.db_data - } - - /// Exposes a mutable reference to the test's [`WalletDb`]. - pub(crate) fn wallet_mut(&mut self) -> &mut WalletDb { - &mut self.db_data - } - - /// Exposes the test framework's source of randomness. - pub(crate) fn rng_mut(&mut self) -> &mut ChaChaRng { - &mut self.rng - } - - /// Exposes the network in use. - pub(crate) fn network(&self) -> LocalNetwork { - self.db_data.params - } - - /// Convenience method for obtaining the Sapling activation height for the network under test. - pub(crate) fn sapling_activation_height(&self) -> BlockHeight { - self.db_data - .params - .activation_height(NetworkUpgrade::Sapling) - .expect("Sapling activation height must be known.") - } - - /// Convenience method for obtaining the NU5 activation height for the network under test. - #[allow(dead_code)] - pub(crate) fn nu5_activation_height(&self) -> BlockHeight { - self.db_data - .params - .activation_height(NetworkUpgrade::Nu5) - .expect("NU5 activation height must be known.") - } - - /// Exposes the test seed, if enabled via [`TestBuilder::with_test_account`]. - pub(crate) fn test_seed(&self) -> Option<&SecretVec> { - self.test_account.as_ref().map(|(seed, _)| seed) - } - - /// Exposes the test account, if enabled via [`TestBuilder::with_test_account`]. - pub(crate) fn test_account(&self) -> Option<&TestAccount> { - self.test_account.as_ref().map(|(_, acct)| acct) - } - - /// Exposes the test account's Sapling DFVK, if enabled via [`TestBuilder::with_test_account`]. - pub(crate) fn test_account_sapling(&self) -> Option { - self.test_account - .as_ref() - .and_then(|(_, acct)| acct.usk.to_unified_full_viewing_key().sapling().cloned()) - } - - /// Exposes the test account's Sapling DFVK, if enabled via [`TestBuilder::with_test_account`]. - #[cfg(feature = "orchard")] - pub(crate) fn test_account_orchard(&self) -> Option { - self.test_account - .as_ref() - .and_then(|(_, acct)| acct.usk.to_unified_full_viewing_key().orchard().cloned()) - } - +impl TestState { /// Insert shard roots for both trees. pub(crate) fn put_subtree_roots( &mut self, @@ -896,11 +972,10 @@ impl TestState { Zip317FeeError, >, > { - let params = self.network(); let prover = test_prover(); create_spend_to_address( - &mut self.db_data, - ¶ms, + self.wallet_data.db_mut(), + &self.network, &prover, &prover, usk, @@ -936,11 +1011,10 @@ impl TestState { InputsT: InputSelector>, { #![allow(deprecated)] - let params = self.network(); let prover = test_prover(); spend( - &mut self.db_data, - ¶ms, + self.wallet_data.db_mut(), + &self.network, &prover, &prover, input_selector, @@ -971,10 +1045,9 @@ impl TestState { where InputsT: InputSelector>, { - let params = self.network(); propose_transfer::<_, _, _, Infallible>( - &mut self.db_data, - ¶ms, + self.wallet_data.db_mut(), + &self.network, spend_from_account, input_selector, request, @@ -1004,10 +1077,9 @@ impl TestState { Zip317FeeError, >, > { - let params = self.network(); let result = propose_standard_transfer_to_address::<_, _, CommitmentTreeErrT>( - &mut self.db_data, - ¶ms, + self.wallet_data.db_mut(), + &self.network, fee_rule, spend_from_account, min_confirmations, @@ -1019,7 +1091,7 @@ impl TestState { ); if let Ok(proposal) = &result { - check_proposal_serialization_roundtrip(self.wallet(), proposal); + check_proposal_serialization_roundtrip(self.wallet_data.db(), proposal); } result @@ -1047,10 +1119,9 @@ impl TestState { where InputsT: ShieldingSelector>, { - let params = self.network(); propose_shielding::<_, _, _, Infallible>( - &mut self.db_data, - ¶ms, + self.wallet_data.db_mut(), + &self.network, input_selector, shielding_threshold, from_addrs, @@ -1076,11 +1147,10 @@ impl TestState { where FeeRuleT: FeeRule, { - let params = self.network(); let prover = test_prover(); create_proposed_transactions( - &mut self.db_data, - ¶ms, + self.wallet_data.db_mut(), + &self.network, &prover, &prover, usk, @@ -1111,11 +1181,10 @@ impl TestState { where InputsT: ShieldingSelector>, { - let params = self.network(); let prover = test_prover(); shield_transparent_funds( - &mut self.db_data, - ¶ms, + self.wallet_data.db_mut(), + &self.network, &prover, &prover, input_selector, @@ -1177,8 +1246,8 @@ impl TestState { min_confirmations: u32, ) -> Option> { get_wallet_summary( - &self.wallet().conn.unchecked_transaction().unwrap(), - &self.wallet().params, + &self.wallet().conn().unchecked_transaction().unwrap(), + &self.network, min_confirmations, &SubtreeScanProgress, ) @@ -1199,7 +1268,7 @@ impl TestState { pub(crate) fn get_tx_history( &self, ) -> Result>, SqliteClientError> { - let mut stmt = self.wallet().conn.prepare_cached( + let mut stmt = self.wallet().conn().prepare_cached( "SELECT * FROM v_transactions ORDER BY mined_height DESC, tx_index DESC", @@ -1239,7 +1308,7 @@ impl TestState { pub(crate) fn get_checkpoint_history( &self, ) -> Result)>, SqliteClientError> { - let mut stmt = self.wallet().conn.prepare_cached( + let mut stmt = self.wallet().conn().prepare_cached( "SELECT checkpoint_id, 2 AS pool, position FROM sapling_tree_checkpoints UNION SELECT checkpoint_id, 3 AS pool, position FROM orchard_tree_checkpoints @@ -1276,7 +1345,10 @@ impl TestState { pub(crate) fn dump_table(&self, name: &'static str) { assert!(name.chars().all(|c| c.is_ascii_alphabetic() || c == '_')); unsafe { - run_sqlite3(self._data_file.path(), &format!(r#".dump "{name}""#)); + run_sqlite3( + self.wallet_data.data_file().path(), + &format!(r#".dump "{name}""#), + ); } } @@ -1290,7 +1362,7 @@ impl TestState { #[allow(dead_code)] #[cfg(feature = "unstable")] pub(crate) unsafe fn run_sqlite3(&self, command: &str) { - run_sqlite3(self._data_file.path(), command) + run_sqlite3(self.wallet_data.data_file().path(), command) } } @@ -2145,14 +2217,11 @@ impl TestCache for FsBlockCache { } } -pub(crate) fn input_selector( +pub(crate) fn input_selector( fee_rule: StandardFeeRule, change_memo: Option<&str>, fallback_change_pool: ShieldedProtocol, -) -> GreedyInputSelector< - WalletDb, - standard::SingleOutputChangeStrategy, -> { +) -> GreedyInputSelector, standard::SingleOutputChangeStrategy> { let change_memo = change_memo.map(|m| MemoBytes::from(m.parse::().unwrap())); let change_strategy = standard::SingleOutputChangeStrategy::new(fee_rule, change_memo, fallback_change_pool); @@ -2161,11 +2230,11 @@ pub(crate) fn input_selector( // Checks that a protobuf proposal serialized from the provided proposal value correctly parses to // the same proposal value. -fn check_proposal_serialization_roundtrip( - db_data: &WalletDb, +fn check_proposal_serialization_roundtrip( + wallet_data: &WalletDb, proposal: &Proposal, ) { let proposal_proto = proposal::Proposal::from_standard_proposal(proposal); - let deserialized_proposal = proposal_proto.try_into_standard_proposal(db_data); + let deserialized_proposal = proposal_proto.try_into_standard_proposal(wallet_data); assert_matches!(deserialized_proposal, Ok(r) if &r == proposal); } diff --git a/zcash_client_sqlite/src/testing/db.rs b/zcash_client_sqlite/src/testing/db.rs new file mode 100644 index 0000000000..682ed3b9d0 --- /dev/null +++ b/zcash_client_sqlite/src/testing/db.rs @@ -0,0 +1,116 @@ +use ambassador::Delegate; +use core::ops::Range; +use rusqlite::Connection; +use std::collections::HashMap; + +use std::num::NonZeroU32; + +use tempfile::NamedTempFile; + +use rusqlite::{self}; +use secrecy::SecretVec; +use shardtree::{error::ShardTreeError, ShardTree}; +use zip32::fingerprint::SeedFingerprint; + +use zcash_client_backend::{ + data_api::{ + chain::{ChainState, CommitmentTreeRoot}, + scanning::ScanRange, + *, + }, + keys::UnifiedFullViewingKey, + wallet::{Note, NoteId, ReceivedNote, WalletTransparentOutput}, + ShieldedProtocol, +}; +use zcash_keys::{ + address::UnifiedAddress, + keys::{UnifiedAddressRequest, UnifiedSpendingKey}, +}; +use zcash_primitives::{ + block::BlockHash, + legacy::TransparentAddress, + transaction::{ + components::{amount::NonNegativeAmount, OutPoint}, + Transaction, TxId, + }, +}; +use zcash_protocol::{consensus::BlockHeight, local_consensus::LocalNetwork, memo::Memo}; + +use crate::AccountId; +use crate::{wallet::init::init_wallet_db, TransparentAddressMetadata, WalletDb}; + +use super::{DataStoreFactory, Reset, TestState}; + +#[derive(Delegate)] +#[delegate(InputSource, target = "wallet_db")] +#[delegate(WalletRead, target = "wallet_db")] +#[delegate(WalletWrite, target = "wallet_db")] +#[delegate(WalletCommitmentTrees, target = "wallet_db")] +pub(crate) struct TestDb { + wallet_db: WalletDb, + data_file: NamedTempFile, +} + +impl TestDb { + pub(crate) fn from_parts( + wallet_db: WalletDb, + data_file: NamedTempFile, + ) -> Self { + Self { + wallet_db, + data_file, + } + } + + pub(crate) fn db(&self) -> &WalletDb { + &self.wallet_db + } + + pub(crate) fn db_mut(&mut self) -> &mut WalletDb { + &mut self.wallet_db + } + + pub(crate) fn conn(&self) -> &Connection { + &self.wallet_db.conn + } + + pub(crate) fn conn_mut(&mut self) -> &mut Connection { + &mut self.wallet_db.conn + } + + pub(crate) fn data_file(&self) -> &NamedTempFile { + &self.data_file + } + + pub(crate) fn take_data_file(self) -> NamedTempFile { + self.data_file + } +} + +pub(crate) struct TestDbFactory; + +impl DataStoreFactory for TestDbFactory { + type Error = (); + type AccountId = AccountId; + type DataStore = TestDb; + + fn new_data_store(&self, network: LocalNetwork) -> Result { + let data_file = NamedTempFile::new().unwrap(); + let mut db_data = WalletDb::for_path(data_file.path(), network).unwrap(); + init_wallet_db(&mut db_data, None).unwrap(); + Ok(TestDb::from_parts(db_data, data_file)) + } +} + +impl Reset for TestDb { + type Handle = NamedTempFile; + + fn reset(st: &mut TestState) -> NamedTempFile { + let network = *st.network(); + let old_db = std::mem::replace( + &mut st.wallet_data, + TestDbFactory.new_data_store(network).unwrap(), + ); + old_db.take_data_file() + } +} diff --git a/zcash_client_sqlite/src/testing/pool.rs b/zcash_client_sqlite/src/testing/pool.rs index 242601993a..dec2a07beb 100644 --- a/zcash_client_sqlite/src/testing/pool.rs +++ b/zcash_client_sqlite/src/testing/pool.rs @@ -37,7 +37,8 @@ use zcash_client_backend::{ decrypt_and_store_transaction, input_selection::{GreedyInputSelector, GreedyInputSelectorError}, }, - AccountBirthday, DecryptedTransaction, Ratio, WalletRead, WalletSummary, WalletWrite, + Account as _, AccountBirthday, DecryptedTransaction, InputSource, Ratio, + WalletCommitmentTrees, WalletRead, WalletSummary, WalletWrite, }, decrypt_transaction, fees::{fixed, standard, DustOutputPolicy}, @@ -47,16 +48,16 @@ use zcash_client_backend::{ zip321::{self, Payment, TransactionRequest}, ShieldedProtocol, }; -use zcash_protocol::consensus::BlockHeight; +use zcash_protocol::consensus::{self, BlockHeight}; use super::TestFvk; use crate::{ error::SqliteClientError, testing::{ - input_selector, AddressType, BlockCache, FakeCompactOutput, InitialChainState, TestBuilder, - TestState, + db::{TestDb, TestDbFactory}, + input_selector, AddressType, FakeCompactOutput, InitialChainState, TestBuilder, TestState, }, - wallet::{block_max_scanned, commitment_tree, parse_scope, truncate_to_height}, + wallet::{commitment_tree, parse_scope, truncate_to_height}, AccountId, NoteId, ReceivedNoteId, }; @@ -90,7 +91,9 @@ pub(crate) trait ShieldedPoolTester { type MerkleTreeHash; type Note; - fn test_account_fvk(st: &TestState) -> Self::Fvk; + fn test_account_fvk( + st: &TestState, + ) -> Self::Fvk; fn usk_to_sk(usk: &UnifiedSpendingKey) -> &Self::Sk; fn sk(seed: &[u8]) -> Self::Sk; fn sk_to_fvk(sk: &Self::Sk) -> Self::Fvk; @@ -114,21 +117,22 @@ pub(crate) trait ShieldedPoolTester { fn empty_tree_leaf() -> Self::MerkleTreeHash; fn empty_tree_root(level: Level) -> Self::MerkleTreeHash; - fn put_subtree_roots( - st: &mut TestState, + fn put_subtree_roots( + st: &mut TestState, start_index: u64, roots: &[CommitmentTreeRoot], - ) -> Result<(), ShardTreeError>; + ) -> Result<(), ShardTreeError<::Error>>; fn next_subtree_index(s: &WalletSummary) -> u64; - fn select_spendable_notes( - st: &TestState, - account: AccountId, + #[allow(clippy::type_complexity)] + fn select_spendable_notes( + st: &TestState, + account: ::AccountId, target_value: NonNegativeAmount, anchor_height: BlockHeight, - exclude: &[ReceivedNoteId], - ) -> Result>, SqliteClientError>; + exclude: &[DbT::NoteRef], + ) -> Result>, ::Error>; fn decrypted_pool_outputs_count(d_tx: &DecryptedTransaction<'_, AccountId>) -> usize; @@ -137,8 +141,8 @@ pub(crate) trait ShieldedPoolTester { f: impl FnMut(&MemoBytes), ); - fn try_output_recovery( - st: &TestState, + fn try_output_recovery( + params: &P, height: BlockHeight, tx: &Transaction, fvk: &Self::Fvk, @@ -149,6 +153,7 @@ pub(crate) trait ShieldedPoolTester { pub(crate) fn send_single_step_proposed_transfer() { let mut st = TestBuilder::new() + .with_data_store_factory(TestDbFactory) .with_block_cache() .with_account_from_sapling_activation(BlockHash([0; 32])) .build(); @@ -162,11 +167,12 @@ pub(crate) fn send_single_step_proposed_transfer() { st.scan_cached_blocks(h, 1); // Spendable balance matches total balance - assert_eq!(st.get_total_balance(account.account_id()), value); - assert_eq!(st.get_spendable_balance(account.account_id(), 1), value); + assert_eq!(st.get_total_balance(account.id()), value); + assert_eq!(st.get_spendable_balance(account.id(), 1), value); assert_eq!( - block_max_scanned(&st.wallet().conn, &st.wallet().params) + st.wallet() + .block_max_scanned() .unwrap() .unwrap() .block_height(), @@ -176,7 +182,7 @@ pub(crate) fn send_single_step_proposed_transfer() { let to_extsk = T::sk(&[0xf5; 32]); let to: Address = T::sk_default_address(&to_extsk); let request = zip321::TransactionRequest::new(vec![Payment::without_memo( - to.to_zcash_address(&st.network()), + to.to_zcash_address(st.network()), NonNegativeAmount::const_from_u64(10000), )]) .unwrap(); @@ -193,7 +199,7 @@ pub(crate) fn send_single_step_proposed_transfer() { let proposal = st .propose_transfer( - account.account_id(), + account.id(), input_selector, request, NonZeroU32::new(1).unwrap(), @@ -215,13 +221,10 @@ pub(crate) fn send_single_step_proposed_transfer() { .get_transaction(sent_tx_id) .unwrap() .expect("Created transaction was stored."); - let ufvks = [( - account.account_id(), - account.usk().to_unified_full_viewing_key(), - )] - .into_iter() - .collect(); - let d_tx = decrypt_transaction(&st.network(), h + 1, &tx, &ufvks); + let ufvks = [(account.id(), account.usk().to_unified_full_viewing_key())] + .into_iter() + .collect(); + let d_tx = decrypt_transaction(st.network(), h + 1, &tx, &ufvks); assert_eq!(T::decrypted_pool_outputs_count(&d_tx), 2); let mut found_tx_change_memo = false; @@ -241,7 +244,7 @@ pub(crate) fn send_single_step_proposed_transfer() { let sent_note_ids = { let mut stmt_sent_notes = st .wallet() - .conn + .conn() .prepare( "SELECT output_index FROM sent_notes @@ -294,8 +297,9 @@ pub(crate) fn send_single_step_proposed_transfer() { let tx_history = st.get_tx_history().unwrap(); assert_eq!(tx_history.len(), 2); + let network = *st.network(); assert_matches!( - decrypt_and_store_transaction(&st.network(), st.wallet_mut(), &tx, None), + decrypt_and_store_transaction(&network, st.wallet_mut(), &tx, None), Ok(_) ); } @@ -321,21 +325,23 @@ pub(crate) fn send_multi_step_proposed_transfer() { }; let mut st = TestBuilder::new() + .with_data_store_factory(TestDbFactory) .with_block_cache() .with_account_from_sapling_activation(BlockHash([0; 32])) .build(); let account = st.test_account().cloned().unwrap(); - let account_id = account.account_id(); + let account_id = account.id(); let (default_addr, default_index) = account.usk().default_transparent_address(); let dfvk = T::test_account_fvk(&st); - let add_funds = |st: &mut TestState<_>, value| { + let add_funds = |st: &mut TestState<_, TestDb, _>, value| { let (h, _, _) = st.generate_next_block(&dfvk, AddressType::DefaultExternal, value); st.scan_cached_blocks(h, 1); assert_eq!( - block_max_scanned(&st.wallet().conn, &st.wallet().params) + st.wallet() + .block_max_scanned() .unwrap() .unwrap() .block_height(), @@ -348,7 +354,7 @@ pub(crate) fn send_multi_step_proposed_transfer() { let value = NonNegativeAmount::const_from_u64(100000); let transfer_amount = NonNegativeAmount::const_from_u64(50000); - let run_test = |st: &mut TestState<_>, expected_index| { + let run_test = |st: &mut TestState<_, TestDb, _>, expected_index| { // Add funds to the wallet. add_funds(st, value); @@ -407,7 +413,7 @@ pub(crate) fn send_multi_step_proposed_transfer() { // Verify that the stored sent outputs match what we're expecting. let mut stmt_sent = st .wallet() - .conn + .conn() .prepare( "SELECT value, to_address, ephemeral_addresses.address, ephemeral_addresses.address_index FROM sent_notes @@ -459,7 +465,7 @@ pub(crate) fn send_multi_step_proposed_transfer() { assert_matches!( confirmed_sent[1][0].clone(), (sent_v, sent_to_addr, None, None) - if sent_v == u64::try_from(transfer_amount).unwrap() && sent_to_addr == Some(tex_addr.encode(&st.wallet().params))); + if sent_v == u64::try_from(transfer_amount).unwrap() && sent_to_addr == Some(tex_addr.encode(st.network()))); // Check that the transaction history matches what we expect. let tx_history = st.get_tx_history().unwrap(); @@ -501,7 +507,7 @@ pub(crate) fn send_multi_step_proposed_transfer() { let height = add_funds(&mut st, value); - let ephemeral_taddr = Address::decode(&st.wallet().params, &ephemeral0).expect("valid address"); + let ephemeral_taddr = Address::decode(st.network(), &ephemeral0).expect("valid address"); assert_matches!( ephemeral_taddr, Address::Transparent(TransparentAddress::PublicKeyHash(_)) @@ -555,7 +561,7 @@ pub(crate) fn send_multi_step_proposed_transfer() { } let mut builder = Builder::new( - st.wallet().params, + *st.network(), height + 1, BuildConfig::Standard { sapling_anchor: None, @@ -614,7 +620,7 @@ pub(crate) fn send_multi_step_proposed_transfer() { // We call get_wallet_transparent_output with `allow_unspendable = true` to verify // storage because the decrypted transaction has not yet been mined. let utxo = - get_wallet_transparent_output(&st.db_data.conn, &OutPoint::new(txid.into(), 0), true) + get_wallet_transparent_output(st.wallet().conn(), &OutPoint::new(txid.into(), 0), true) .unwrap(); assert_matches!(utxo, Some(v) if v.value() == utxo_value); @@ -626,7 +632,7 @@ pub(crate) fn send_multi_step_proposed_transfer() { assert_eq!(new_known_addrs.len(), (GAP_LIMIT as usize) + 11); assert!(new_known_addrs.starts_with(&known_addrs)); - let reservation_should_succeed = |st: &mut TestState<_>, n| { + let reservation_should_succeed = |st: &mut TestState<_, TestDb, _>, n| { let reserved = st .wallet_mut() .reserve_next_n_ephemeral_addresses(account_id, n) @@ -634,7 +640,7 @@ pub(crate) fn send_multi_step_proposed_transfer() { assert_eq!(reserved.len(), n); reserved }; - let reservation_should_fail = |st: &mut TestState<_>, n, expected_bad_index| { + let reservation_should_fail = |st: &mut TestState<_, TestDb, _>, n, expected_bad_index| { assert_matches!(st .wallet_mut() .reserve_next_n_ephemeral_addresses(account_id, n), @@ -724,20 +730,22 @@ pub(crate) fn proposal_fails_if_not_all_ephemeral_outputs_consumed, value| { + let add_funds = |st: &mut TestState<_, TestDb, _>, value| { let (h, _, _) = st.generate_next_block(&dfvk, AddressType::DefaultExternal, value); st.scan_cached_blocks(h, 1); assert_eq!( - block_max_scanned(&st.wallet().conn, &st.wallet().params) + st.wallet() + .block_max_scanned() .unwrap() .unwrap() .block_height(), @@ -806,6 +814,7 @@ pub(crate) fn proposal_fails_if_not_all_ephemeral_outputs_consumed() { let mut st = TestBuilder::new() + .with_data_store_factory(TestDbFactory) .with_account_from_sapling_activation(BlockHash([0; 32])) .build(); let dfvk = T::test_account_fvk(&st); @@ -813,7 +822,7 @@ pub(crate) fn create_to_address_fails_on_incorrect_usk() // Create a USK that doesn't exist in the wallet let acct1 = zip32::AccountId::try_from(1).unwrap(); - let usk1 = UnifiedSpendingKey::from_seed(&st.network(), &[1u8; 32], acct1).unwrap(); + let usk1 = UnifiedSpendingKey::from_seed(st.network(), &[1u8; 32], acct1).unwrap(); // Attempting to spend with a USK that is not in the wallet results in an error assert_matches!( @@ -834,10 +843,11 @@ pub(crate) fn create_to_address_fails_on_incorrect_usk() #[allow(deprecated)] pub(crate) fn proposal_fails_with_no_blocks() { let mut st = TestBuilder::new() + .with_data_store_factory(TestDbFactory) .with_account_from_sapling_activation(BlockHash([0; 32])) .build(); - let account_id = st.test_account().unwrap().account_id(); + let account_id = st.test_account().unwrap().id(); let dfvk = T::test_account_fvk(&st); let to = T::fvk_default_address(&dfvk); @@ -862,12 +872,13 @@ pub(crate) fn proposal_fails_with_no_blocks() { pub(crate) fn spend_fails_on_unverified_notes() { let mut st = TestBuilder::new() + .with_data_store_factory(TestDbFactory) .with_block_cache() .with_account_from_sapling_activation(BlockHash([0; 32])) .build(); let account = st.test_account().cloned().unwrap(); - let account_id = account.account_id(); + let account_id = account.id(); let dfvk = T::test_account_fvk(&st); // Add funds to the wallet in a single note @@ -1013,12 +1024,13 @@ pub(crate) fn spend_fails_on_unverified_notes() { pub(crate) fn spend_fails_on_locked_notes() { let mut st = TestBuilder::new() + .with_data_store_factory(TestDbFactory) .with_block_cache() .with_account_from_sapling_activation(BlockHash([0; 32])) .build(); let account = st.test_account().cloned().unwrap(); - let account_id = account.account_id(); + let account_id = account.id(); let dfvk = T::test_account_fvk(&st); let fee_rule = StandardFeeRule::Zip317; @@ -1148,12 +1160,13 @@ pub(crate) fn spend_fails_on_locked_notes() { pub(crate) fn ovk_policy_prevents_recovery_from_chain() { let mut st = TestBuilder::new() + .with_data_store_factory(TestDbFactory) .with_block_cache() .with_account_from_sapling_activation(BlockHash([0; 32])) .build(); let account = st.test_account().cloned().unwrap(); - let account_id = account.account_id(); + let account_id = account.id(); let dfvk = T::test_account_fvk(&st); // Add funds to the wallet in a single note @@ -1171,7 +1184,7 @@ pub(crate) fn ovk_policy_prevents_recovery_from_chain() { let fee_rule = StandardFeeRule::Zip317; #[allow(clippy::type_complexity)] - let send_and_recover_with_policy = |st: &mut TestState, + let send_and_recover_with_policy = |st: &mut TestState<_, TestDb, _>, ovk_policy| -> Result< Option<(Note, Address, MemoBytes)>, @@ -1200,17 +1213,16 @@ pub(crate) fn ovk_policy_prevents_recovery_from_chain() { // Fetch the transaction from the database let raw_tx: Vec<_> = st .wallet() - .conn + .conn() .query_row( - "SELECT raw FROM transactions - WHERE txid = ?", + "SELECT raw FROM transactions WHERE txid = ?", [txid.as_ref()], |row| row.get(0), ) .unwrap(); let tx = Transaction::read(&raw_tx[..], BranchId::Canopy).unwrap(); - T::try_output_recovery(st, h1, &tx, &dfvk) + T::try_output_recovery(st.network(), h1, &tx, &dfvk) }; // Send some of the funds to another address, keeping history. @@ -1241,12 +1253,13 @@ pub(crate) fn ovk_policy_prevents_recovery_from_chain() { pub(crate) fn spend_succeeds_to_t_addr_zero_change() { let mut st = TestBuilder::new() + .with_data_store_factory(TestDbFactory) .with_block_cache() .with_account_from_sapling_activation(BlockHash([0; 32])) .build(); let account = st.test_account().cloned().unwrap(); - let account_id = account.account_id(); + let account_id = account.id(); let dfvk = T::test_account_fvk(&st); // Add funds to the wallet in a single note @@ -1285,12 +1298,13 @@ pub(crate) fn spend_succeeds_to_t_addr_zero_change() { pub(crate) fn change_note_spends_succeed() { let mut st = TestBuilder::new() + .with_data_store_factory(TestDbFactory) .with_block_cache() .with_account_from_sapling_activation(BlockHash([0; 32])) .build(); let account = st.test_account().cloned().unwrap(); - let account_id = account.account_id(); + let account_id = account.id(); let dfvk = T::test_account_fvk(&st); // Add funds to the wallet in a single note owned by the internal spending key @@ -1309,7 +1323,7 @@ pub(crate) fn change_note_spends_succeed() { NonNegativeAmount::ZERO ); - let change_note_scope = st.wallet().conn.query_row( + let change_note_scope = st.wallet().conn().query_row( &format!( "SELECT recipient_key_scope FROM {}_received_notes @@ -1349,11 +1363,14 @@ pub(crate) fn change_note_spends_succeed() { pub(crate) fn external_address_change_spends_detected_in_restore_from_seed< T: ShieldedPoolTester, >() { - let mut st = TestBuilder::new().with_block_cache().build(); + let mut st = TestBuilder::new() + .with_data_store_factory(TestDbFactory) + .with_block_cache() + .build(); // Add two accounts to the wallet. let seed = Secret::new([0u8; 32].to_vec()); - let birthday = AccountBirthday::from_sapling_activation(&st.network(), BlockHash([0; 32])); + let birthday = AccountBirthday::from_sapling_activation(st.network(), BlockHash([0; 32])); let (account_id, usk) = st.wallet_mut().create_account(&seed, &birthday).unwrap(); let dfvk = T::sk_to_fvk(T::usk_to_sk(&usk)); @@ -1376,9 +1393,9 @@ pub(crate) fn external_address_change_spends_detected_in_restore_from_seed< let addr2 = T::fvk_default_address(&dfvk2); let req = TransactionRequest::new(vec![ // payment to an external recipient - Payment::without_memo(addr2.to_zcash_address(&st.network()), amount_sent), + Payment::without_memo(addr2.to_zcash_address(st.network()), amount_sent), // payment back to the originating wallet, simulating legacy change - Payment::without_memo(addr.to_zcash_address(&st.network()), amount_legacy_change), + Payment::without_memo(addr.to_zcash_address(st.network()), amount_legacy_change), ]) .unwrap(); @@ -1437,12 +1454,13 @@ pub(crate) fn external_address_change_spends_detected_in_restore_from_seed< #[allow(dead_code)] pub(crate) fn zip317_spend() { let mut st = TestBuilder::new() + .with_data_store_factory(TestDbFactory) .with_block_cache() .with_account_from_sapling_activation(BlockHash([0; 32])) .build(); let account = st.test_account().cloned().unwrap(); - let account_id = account.account_id(); + let account_id = account.id(); let dfvk = T::test_account_fvk(&st); // Add funds to the wallet @@ -1472,7 +1490,7 @@ pub(crate) fn zip317_spend() { // This first request will fail due to insufficient non-dust funds let req = TransactionRequest::new(vec![Payment::without_memo( - T::fvk_default_address(&dfvk).to_zcash_address(&st.network()), + T::fvk_default_address(&dfvk).to_zcash_address(st.network()), NonNegativeAmount::const_from_u64(50000), )]) .unwrap(); @@ -1493,7 +1511,7 @@ pub(crate) fn zip317_spend() { // This request will succeed, spending a single dust input to pay the 10000 // ZAT fee in addition to the 41000 ZAT output to the recipient let req = TransactionRequest::new(vec![Payment::without_memo( - T::fvk_default_address(&dfvk).to_zcash_address(&st.network()), + T::fvk_default_address(&dfvk).to_zcash_address(st.network()), NonNegativeAmount::const_from_u64(41000), )]) .unwrap(); @@ -1523,6 +1541,7 @@ pub(crate) fn zip317_spend() { #[cfg(feature = "transparent-inputs")] pub(crate) fn shield_transparent() { let mut st = TestBuilder::new() + .with_data_store_factory(TestDbFactory) .with_block_cache() .with_account_from_sapling_activation(BlockHash([0; 32])) .build(); @@ -1532,7 +1551,7 @@ pub(crate) fn shield_transparent() { let uaddr = st .wallet() - .get_current_address(account.account_id()) + .get_current_address(account.id()) .unwrap() .unwrap(); let taddr = uaddr.transparent().unwrap(); @@ -1596,6 +1615,7 @@ pub(crate) fn birthday_in_anchor_shard() { // notes beyond the end of the first shard. let frontier_tree_size: u32 = (0x1 << 16) + 1234; let mut st = TestBuilder::new() + .with_data_store_factory(TestDbFactory) .with_block_cache() .with_initial_chain_state(|rng, network| { let birthday_height = network.activation_height(NetworkUpgrade::Nu5).unwrap() + 1000; @@ -1671,7 +1691,7 @@ pub(crate) fn birthday_in_anchor_shard() { // Verify that the received note is not considered spendable let account = st.test_account().unwrap(); - let account_id = account.account_id(); + let account_id = account.id(); let spendable = T::select_spendable_notes( &st, account_id, @@ -1701,6 +1721,7 @@ pub(crate) fn birthday_in_anchor_shard() { pub(crate) fn checkpoint_gaps() { let mut st = TestBuilder::new() + .with_data_store_factory(TestDbFactory) .with_block_cache() .with_account_from_sapling_activation(BlockHash([0; 32])) .build(); @@ -1738,14 +1759,14 @@ pub(crate) fn checkpoint_gaps() { // Fake that everything has been scanned st.wallet() - .conn + .conn() .execute_batch("UPDATE scan_queue SET priority = 10") .unwrap(); // Verify that our note is considered spendable let spendable = T::select_spendable_notes( &st, - account.account_id(), + account.id(), NonNegativeAmount::const_from_u64(300000), account.birthday().height() + 5, &[], @@ -1773,6 +1794,7 @@ pub(crate) fn checkpoint_gaps() { #[cfg(feature = "orchard")] pub(crate) fn pool_crossing_required() { let mut st = TestBuilder::new() + .with_data_store_factory(TestDbFactory) .with_block_cache() .with_account_from_sapling_activation(BlockHash([0; 32])) // TODO: Allow for Orchard // activation after Sapling @@ -1790,15 +1812,12 @@ pub(crate) fn pool_crossing_required() { let mut st = TestBuilder::new() + .with_data_store_factory(TestDbFactory) .with_block_cache() .with_account_from_sapling_activation(BlockHash([0; 32])) // TODO: Allow for Orchard // activation after Sapling @@ -1880,15 +1900,12 @@ pub(crate) fn fully_funded_fully_private() { let mut st = TestBuilder::new() + .with_data_store_factory(TestDbFactory) .with_block_cache() .with_account_from_sapling_activation(BlockHash([0; 32])) // TODO: Allow for Orchard // activation after Sapling @@ -1970,15 +1988,12 @@ pub(crate) fn fully_funded_send_to_t() { let mut st = TestBuilder::new() + .with_data_store_factory(TestDbFactory) .with_block_cache() .with_account_from_sapling_activation(BlockHash([0; 32])) // TODO: Allow for Orchard // activation after Sapling .build(); let account = st.test_account().cloned().unwrap(); - let acct_id = account.account_id(); + let acct_id = account.id(); let p0_fvk = P0::test_account_fvk(&st); let p1_fvk = P1::test_account_fvk(&st); @@ -2089,7 +2105,7 @@ pub(crate) fn multi_pool_checkpoint() { let mut st = TestBuilder::new() + .with_data_store_factory(TestDbFactory) .with_block_cache() .with_account_from_sapling_activation(BlockHash([0; 32])) // TODO: Allow for Orchard // activation after Sapling @@ -2253,6 +2270,7 @@ pub(crate) fn multi_pool_checkpoints_with_pruning< pub(crate) fn valid_chain_states() { let mut st = TestBuilder::new() + .with_data_store_factory(TestDbFactory) .with_block_cache() .with_account_from_sapling_activation(BlockHash([0; 32])) .build(); @@ -2287,6 +2305,7 @@ pub(crate) fn valid_chain_states() { #[allow(dead_code)] pub(crate) fn invalid_chain_cache_disconnected() { let mut st = TestBuilder::new() + .with_data_store_factory(TestDbFactory) .with_block_cache() .with_account_from_sapling_activation(BlockHash([0; 32])) .build(); @@ -2341,6 +2360,7 @@ pub(crate) fn invalid_chain_cache_disconnected() { pub(crate) fn data_db_truncation() { let mut st = TestBuilder::new() + .with_data_store_factory(TestDbFactory) .with_block_cache() .with_account_from_sapling_activation(BlockHash([0; 32])) .build(); @@ -2362,46 +2382,46 @@ pub(crate) fn data_db_truncation() { // Spendable balance should reflect both received notes assert_eq!( - st.get_spendable_balance(account.account_id(), 1), + st.get_spendable_balance(account.id(), 1), (value + value2).unwrap() ); // "Rewind" to height of last scanned block (this is a no-op) st.wallet_mut() + .db_mut() .transactionally(|wdb| truncate_to_height(wdb.conn.0, &wdb.params, h + 1)) .unwrap(); // Spendable balance should be unaltered assert_eq!( - st.get_spendable_balance(account.account_id(), 1), + st.get_spendable_balance(account.id(), 1), (value + value2).unwrap() ); // Rewind so that one block is dropped st.wallet_mut() + .db_mut() .transactionally(|wdb| truncate_to_height(wdb.conn.0, &wdb.params, h)) .unwrap(); // Spendable balance should only contain the first received note; // the rest should be pending. - assert_eq!(st.get_spendable_balance(account.account_id(), 1), value); - assert_eq!( - st.get_pending_shielded_balance(account.account_id(), 1), - value2 - ); + assert_eq!(st.get_spendable_balance(account.id(), 1), value); + assert_eq!(st.get_pending_shielded_balance(account.id(), 1), value2); // Scan the cache again st.scan_cached_blocks(h, 2); // Account balance should again reflect both received notes assert_eq!( - st.get_spendable_balance(account.account_id(), 1), + st.get_spendable_balance(account.id(), 1), (value + value2).unwrap() ); } pub(crate) fn scan_cached_blocks_allows_blocks_out_of_order() { let mut st = TestBuilder::new() + .with_data_store_factory(TestDbFactory) .with_block_cache() .with_account_from_sapling_activation(BlockHash([0; 32])) .build(); @@ -2412,7 +2432,7 @@ pub(crate) fn scan_cached_blocks_allows_blocks_out_of_order() { let mut st = TestBuilder::new() + .with_data_store_factory(TestDbFactory) .with_block_cache() .with_account_from_sapling_activation(BlockHash([0; 32])) .build(); @@ -2479,7 +2500,7 @@ pub(crate) fn scan_cached_blocks_finds_received_notes() { assert_eq!(T::received_note_count(&summary), 1); // Account balance should reflect the received note - assert_eq!(st.get_total_balance(account.account_id()), value); + assert_eq!(st.get_total_balance(account.id()), value); // Create a second fake CompactBlock sending more value to the address let value2 = NonNegativeAmount::const_from_u64(7); @@ -2493,7 +2514,7 @@ pub(crate) fn scan_cached_blocks_finds_received_notes() { // Account balance should reflect both received notes assert_eq!( - st.get_total_balance(account.account_id()), + st.get_total_balance(account.id()), (value + value2).unwrap() ); } @@ -2501,6 +2522,7 @@ pub(crate) fn scan_cached_blocks_finds_received_notes() { // TODO: This test can probably be entirely removed, as the following test duplicates it entirely. pub(crate) fn scan_cached_blocks_finds_change_notes() { let mut st = TestBuilder::new() + .with_data_store_factory(TestDbFactory) .with_block_cache() .with_account_from_sapling_activation(BlockHash([0; 32])) .build(); @@ -2520,7 +2542,7 @@ pub(crate) fn scan_cached_blocks_finds_change_notes() { st.scan_cached_blocks(received_height, 1); // Account balance should reflect the received note - assert_eq!(st.get_total_balance(account.account_id()), value); + assert_eq!(st.get_total_balance(account.id()), value); // Create a second fake CompactBlock spending value from the address let not_our_key = T::sk_to_fvk(&T::sk(&[0xf5; 32])); @@ -2533,13 +2555,14 @@ pub(crate) fn scan_cached_blocks_finds_change_notes() { // Account balance should equal the change assert_eq!( - st.get_total_balance(account.account_id()), + st.get_total_balance(account.id()), (value - value2).unwrap() ); } pub(crate) fn scan_cached_blocks_detects_spends_out_of_order() { let mut st = TestBuilder::new() + .with_data_store_factory(TestDbFactory) .with_block_cache() .with_account_from_sapling_activation(BlockHash([0; 32])) .build(); @@ -2566,7 +2589,7 @@ pub(crate) fn scan_cached_blocks_detects_spends_out_of_order(dsf: DsF) { let mut st = TestBuilder::new() + .with_data_store_factory(dsf) .with_block_cache() .with_account_from_sapling_activation(BlockHash([0; 32])) .build(); - let block_fully_scanned = |st: &TestState| { + let block_fully_scanned = |st: &TestState<_, DsF::DataStore, _>| { st.wallet() .block_fully_scanned() .unwrap() @@ -3316,13 +3319,14 @@ mod tests { #[test] fn test_account_birthday() { let st = TestBuilder::new() + .with_data_store_factory(TestDbFactory) .with_block_cache() .with_account_from_sapling_activation(BlockHash([0; 32])) .build(); - let account_id = st.test_account().unwrap().account_id(); + let account_id = st.test_account().unwrap().id(); assert_matches!( - account_birthday(&st.wallet().conn, account_id), + account_birthday(st.wallet().conn(), account_id), Ok(birthday) if birthday == st.sapling_activation_height() ) } diff --git a/zcash_client_sqlite/src/wallet/init.rs b/zcash_client_sqlite/src/wallet/init.rs index ca6c4b52ef..4eb3c6e2bd 100644 --- a/zcash_client_sqlite/src/wallet/init.rs +++ b/zcash_client_sqlite/src/wallet/init.rs @@ -429,7 +429,11 @@ mod tests { zip32::AccountId, }; - use crate::{testing::TestBuilder, wallet::db, WalletDb, UA_TRANSPARENT}; + use crate::{ + testing::{db::TestDbFactory, TestBuilder}, + wallet::db, + WalletDb, UA_TRANSPARENT, + }; use super::init_wallet_db; @@ -453,7 +457,9 @@ mod tests { #[test] fn verify_schema() { - let st = TestBuilder::new().build(); + let st = TestBuilder::new() + .with_data_store_factory(TestDbFactory) + .build(); use regex::Regex; let re = Regex::new(r"\s+").unwrap(); @@ -489,7 +495,7 @@ mod tests { db::TABLE_TX_RETRIEVAL_QUEUE, ]; - let rows = describe_tables(&st.wallet().conn).unwrap(); + let rows = describe_tables(&st.wallet().db().conn).unwrap(); assert_eq!(rows.len(), expected_tables.len()); for (actual, expected) in rows.iter().zip(expected_tables.iter()) { assert_eq!( @@ -515,6 +521,7 @@ mod tests { ]; let mut indices_query = st .wallet() + .db() .conn .prepare("SELECT sql FROM sqlite_master WHERE type = 'index' AND sql != '' ORDER BY tbl_name, name") .unwrap(); @@ -530,12 +537,12 @@ mod tests { } let expected_views = vec![ - db::view_orchard_shard_scan_ranges(&st.network()), + db::view_orchard_shard_scan_ranges(st.network()), db::view_orchard_shard_unscanned_ranges(), db::VIEW_ORCHARD_SHARDS_SCAN_STATE.to_owned(), db::VIEW_RECEIVED_OUTPUT_SPENDS.to_owned(), db::VIEW_RECEIVED_OUTPUTS.to_owned(), - db::view_sapling_shard_scan_ranges(&st.network()), + db::view_sapling_shard_scan_ranges(st.network()), db::view_sapling_shard_unscanned_ranges(), db::VIEW_SAPLING_SHARDS_SCAN_STATE.to_owned(), db::VIEW_TRANSACTIONS.to_owned(), @@ -544,6 +551,7 @@ mod tests { let mut views_query = st .wallet() + .db() .conn .prepare("SELECT sql FROM sqlite_schema WHERE type = 'view' ORDER BY tbl_name") .unwrap(); diff --git a/zcash_client_sqlite/src/wallet/orchard.rs b/zcash_client_sqlite/src/wallet/orchard.rs index cbd8b23b4b..8402644be9 100644 --- a/zcash_client_sqlite/src/wallet/orchard.rs +++ b/zcash_client_sqlite/src/wallet/orchard.rs @@ -393,10 +393,12 @@ pub(crate) mod tests { note_encryption::OrchardDomain, tree::MerkleHashOrchard, }; + use shardtree::error::ShardTreeError; use zcash_client_backend::{ data_api::{ - chain::CommitmentTreeRoot, DecryptedTransaction, WalletCommitmentTrees, WalletSummary, + chain::CommitmentTreeRoot, DecryptedTransaction, InputSource, WalletCommitmentTrees, + WalletRead, WalletSummary, }, wallet::{Note, ReceivedNote}, }; @@ -406,17 +408,20 @@ pub(crate) mod tests { }; use zcash_note_encryption::try_output_recovery_with_ovk; use zcash_primitives::transaction::Transaction; - use zcash_protocol::{consensus::BlockHeight, memo::MemoBytes, ShieldedProtocol}; + use zcash_protocol::{ + consensus::{self, BlockHeight}, + memo::MemoBytes, + value::Zatoshis, + ShieldedProtocol, + }; - use super::select_spendable_orchard_notes; use crate::{ - error::SqliteClientError, testing::{ self, pool::{OutputRecoveryError, ShieldedPoolTester}, TestState, }, - wallet::{commitment_tree, sapling::tests::SaplingPoolTester}, + wallet::sapling::tests::SaplingPoolTester, ORCHARD_TABLES_PREFIX, }; @@ -431,8 +436,10 @@ pub(crate) mod tests { type MerkleTreeHash = MerkleHashOrchard; type Note = orchard::note::Note; - fn test_account_fvk(st: &TestState) -> Self::Fvk { - st.test_account_orchard().unwrap() + fn test_account_fvk( + st: &TestState, + ) -> Self::Fvk { + st.test_account_orchard().unwrap().clone() } fn usk_to_sk(usk: &UnifiedSpendingKey) -> &Self::Sk { @@ -479,11 +486,11 @@ pub(crate) mod tests { MerkleHashOrchard::empty_root(level) } - fn put_subtree_roots( - st: &mut TestState, + fn put_subtree_roots( + st: &mut TestState, start_index: u64, roots: &[CommitmentTreeRoot], - ) -> Result<(), ShardTreeError> { + ) -> Result<(), ShardTreeError<::Error>> { st.wallet_mut() .put_orchard_subtree_roots(start_index, roots) } @@ -492,22 +499,23 @@ pub(crate) mod tests { s.next_orchard_subtree_index() } - fn select_spendable_notes( - st: &TestState, - account: crate::AccountId, - target_value: zcash_protocol::value::Zatoshis, + fn select_spendable_notes( + st: &TestState, + account: ::AccountId, + target_value: Zatoshis, anchor_height: BlockHeight, - exclude: &[crate::ReceivedNoteId], - ) -> Result>, SqliteClientError> + exclude: &[DbT::NoteRef], + ) -> Result>, ::Error> { - select_spendable_orchard_notes( - &st.wallet().conn, - &st.wallet().params, - account, - target_value, - anchor_height, - exclude, - ) + st.wallet() + .select_spendable_notes( + account, + target_value, + &[ShieldedProtocol::Orchard], + anchor_height, + exclude, + ) + .map(|n| n.take_orchard()) } fn decrypted_pool_outputs_count( @@ -525,8 +533,8 @@ pub(crate) mod tests { } } - fn try_output_recovery( - _: &TestState, + fn try_output_recovery( + _params: &P, _: BlockHeight, tx: &Transaction, fvk: &Self::Fvk, diff --git a/zcash_client_sqlite/src/wallet/sapling.rs b/zcash_client_sqlite/src/wallet/sapling.rs index 1e51df25c3..36c893fe9a 100644 --- a/zcash_client_sqlite/src/wallet/sapling.rs +++ b/zcash_client_sqlite/src/wallet/sapling.rs @@ -401,6 +401,7 @@ pub(crate) fn put_received_note( #[cfg(test)] pub(crate) mod tests { use incrementalmerkletree::{Hashable, Level}; + use shardtree::error::ShardTreeError; use zcash_proofs::prover::LocalTxProver; @@ -423,22 +424,22 @@ pub(crate) mod tests { use zcash_client_backend::{ address::Address, data_api::{ - chain::CommitmentTreeRoot, DecryptedTransaction, WalletCommitmentTrees, WalletSummary, + chain::CommitmentTreeRoot, DecryptedTransaction, InputSource, WalletCommitmentTrees, + WalletRead, WalletSummary, }, keys::UnifiedSpendingKey, wallet::{Note, ReceivedNote}, ShieldedProtocol, }; + use zcash_protocol::consensus; use crate::{ - error::SqliteClientError, testing::{ self, pool::{OutputRecoveryError, ShieldedPoolTester}, TestState, }, - wallet::{commitment_tree, sapling::select_spendable_sapling_notes}, - AccountId, ReceivedNoteId, SAPLING_TABLES_PREFIX, + AccountId, SAPLING_TABLES_PREFIX, }; pub(crate) struct SaplingPoolTester; @@ -452,8 +453,10 @@ pub(crate) mod tests { type MerkleTreeHash = sapling::Node; type Note = sapling::Note; - fn test_account_fvk(st: &TestState) -> Self::Fvk { - st.test_account_sapling().unwrap() + fn test_account_fvk( + st: &TestState, + ) -> Self::Fvk { + st.test_account_sapling().unwrap().clone() } fn usk_to_sk(usk: &UnifiedSpendingKey) -> &Self::Sk { @@ -488,11 +491,11 @@ pub(crate) mod tests { sapling::Node::empty_root(level) } - fn put_subtree_roots( - st: &mut TestState, + fn put_subtree_roots( + st: &mut TestState, start_index: u64, roots: &[CommitmentTreeRoot], - ) -> Result<(), ShardTreeError> { + ) -> Result<(), ShardTreeError<::Error>> { st.wallet_mut() .put_sapling_subtree_roots(start_index, roots) } @@ -501,21 +504,23 @@ pub(crate) mod tests { s.next_sapling_subtree_index() } - fn select_spendable_notes( - st: &TestState, - account: AccountId, + fn select_spendable_notes( + st: &TestState, + account: ::AccountId, target_value: NonNegativeAmount, anchor_height: BlockHeight, - exclude: &[ReceivedNoteId], - ) -> Result>, SqliteClientError> { - select_spendable_sapling_notes( - &st.wallet().conn, - &st.wallet().params, - account, - target_value, - anchor_height, - exclude, - ) + exclude: &[DbT::NoteRef], + ) -> Result>, ::Error> + { + st.wallet() + .select_spendable_notes( + account, + target_value, + &[ShieldedProtocol::Sapling], + anchor_height, + exclude, + ) + .map(|n| n.take_sapling()) } fn decrypted_pool_outputs_count(d_tx: &DecryptedTransaction<'_, AccountId>) -> usize { @@ -531,8 +536,8 @@ pub(crate) mod tests { } } - fn try_output_recovery( - st: &TestState, + fn try_output_recovery( + params: &P, height: BlockHeight, tx: &Transaction, fvk: &Self::Fvk, @@ -542,7 +547,7 @@ pub(crate) mod tests { let result = try_sapling_output_recovery( &fvk.to_ovk(Scope::External), output, - zip212_enforcement(&st.network(), height), + zip212_enforcement(params, height), ); if result.is_some() { diff --git a/zcash_client_sqlite/src/wallet/scanning.rs b/zcash_client_sqlite/src/wallet/scanning.rs index 64dede7e9e..9367a6829e 100644 --- a/zcash_client_sqlite/src/wallet/scanning.rs +++ b/zcash_client_sqlite/src/wallet/scanning.rs @@ -594,12 +594,14 @@ pub(crate) mod tests { consensus::{BlockHeight, NetworkUpgrade, Parameters}, transaction::components::amount::NonNegativeAmount, }; + use zcash_protocol::local_consensus::LocalNetwork; use crate::{ error::SqliteClientError, testing::{ - pool::ShieldedPoolTester, AddressType, BlockCache, FakeCompactOutput, - InitialChainState, TestBuilder, TestState, + db::{TestDb, TestDbFactory}, + pool::ShieldedPoolTester, + AddressType, BlockCache, FakeCompactOutput, InitialChainState, TestBuilder, TestState, }, wallet::{ sapling::tests::SaplingPoolTester, @@ -646,6 +648,7 @@ pub(crate) mod tests { let initial_height_offset = 310; let mut st = TestBuilder::new() + .with_data_store_factory(TestDbFactory) .with_block_cache() .with_initial_chain_state(|rng, network| { let sapling_activation_height = @@ -728,7 +731,7 @@ pub(crate) mod tests { // Verify the that adjacent range needed to make the note spendable has been prioritized. let sap_active = u32::from(sapling_activation_height); assert_matches!( - st.wallet().suggest_scan_ranges(), + suggest_scan_ranges(st.wallet().conn(), Historic), Ok(scan_ranges) if scan_ranges == vec![ scan_range((sap_active + 300)..(sap_active + 310), FoundNote) ] @@ -736,7 +739,7 @@ pub(crate) mod tests { // Check that the scanned range has been properly persisted. assert_matches!( - suggest_scan_ranges(&st.wallet().conn, Scanned), + suggest_scan_ranges(st.wallet().conn(), Scanned), Ok(scan_ranges) if scan_ranges == vec![ scan_range((sap_active + 300)..(sap_active + 310), FoundNote), scan_range((sap_active + 310)..(sap_active + 320), Scanned) @@ -754,7 +757,7 @@ pub(crate) mod tests { // Check the scan range again, we should see a `ChainTip` range for the period we've been // offline. assert_matches!( - st.wallet().suggest_scan_ranges(), + suggest_scan_ranges(st.wallet().conn(), Historic), Ok(scan_ranges) if scan_ranges == vec![ scan_range((sap_active + 320)..(sap_active + 341), ChainTip), scan_range((sap_active + 300)..(sap_active + 310), ChainTip) @@ -771,7 +774,7 @@ pub(crate) mod tests { // Check the scan range again, we should see a `Validate` range for the previous wallet // tip, and then a `ChainTip` for the remaining range. assert_matches!( - st.wallet().suggest_scan_ranges(), + suggest_scan_ranges(st.wallet().conn(), Historic), Ok(scan_ranges) if scan_ranges == vec![ scan_range((sap_active + 320)..(sap_active + 330), Verify), scan_range((sap_active + 330)..(sap_active + 451), ChainTip), @@ -805,8 +808,14 @@ pub(crate) mod tests { birthday_offset: u32, prior_block_hash: BlockHash, insert_prior_roots: bool, - ) -> (TestState, T::Fvk, AccountBirthday, u32) { + ) -> ( + TestState, + T::Fvk, + AccountBirthday, + u32, + ) { let st = TestBuilder::new() + .with_data_store_factory(TestDbFactory) .with_block_cache() .with_initial_chain_state(|rng, network| { // We set the Sapling and Orchard frontiers at the birthday height to be @@ -892,7 +901,7 @@ pub(crate) mod tests { // The range up to the wallet's birthday height is ignored. scan_range(sap_active..birthday_height, Ignored), ]; - let actual = suggest_scan_ranges(&st.wallet().conn, Ignored).unwrap(); + let actual = suggest_scan_ranges(st.wallet().conn(), Ignored).unwrap(); assert_eq!(actual, expected); } @@ -900,7 +909,10 @@ pub(crate) mod tests { fn update_chain_tip_before_create_account() { use ScanPriority::*; - let mut st = TestBuilder::new().with_block_cache().build(); + let mut st = TestBuilder::new() + .with_data_store_factory(TestDbFactory) + .with_block_cache() + .build(); let sap_active = st.sapling_activation_height(); // Update the chain tip. @@ -912,7 +924,7 @@ pub(crate) mod tests { // The range up to the chain end is ignored. scan_range(sap_active.into()..chain_end, Ignored), ]; - let actual = suggest_scan_ranges(&st.wallet().conn, Ignored).unwrap(); + let actual = suggest_scan_ranges(st.wallet().conn(), Ignored).unwrap(); assert_eq!(actual, expected); // Now add an account. @@ -933,7 +945,7 @@ pub(crate) mod tests { // The range up to the wallet's birthday height is ignored. scan_range(sap_active.into()..wallet_birthday.into(), Ignored), ]; - let actual = suggest_scan_ranges(&st.wallet().conn, Ignored).unwrap(); + let actual = suggest_scan_ranges(st.wallet().conn(), Ignored).unwrap(); assert_eq!(actual, expected); } @@ -978,7 +990,7 @@ pub(crate) mod tests { scan_range(sap_active..wallet_birthday, Ignored), ]; - let actual = suggest_scan_ranges(&st.wallet().conn, Ignored).unwrap(); + let actual = suggest_scan_ranges(st.wallet().conn(), Ignored).unwrap(); assert_eq!(actual, expected); } @@ -1022,7 +1034,7 @@ pub(crate) mod tests { scan_range(sap_active..birthday.height().into(), Ignored), ]; - let actual = suggest_scan_ranges(&st.wallet().conn, Ignored).unwrap(); + let actual = suggest_scan_ranges(st.wallet().conn(), Ignored).unwrap(); assert_eq!(actual, expected); } @@ -1051,6 +1063,7 @@ pub(crate) mod tests { // notes beyond the end of the first shard. let frontier_tree_size: u32 = (0x1 << 16) + 1234; let mut st = TestBuilder::new() + .with_data_store_factory(TestDbFactory) .with_block_cache() .with_initial_chain_state(|rng, network| { let birthday_height = @@ -1123,7 +1136,7 @@ pub(crate) mod tests { ), pre_birthday_range.clone(), ]; - let actual = suggest_scan_ranges(&st.wallet().conn, Ignored).unwrap(); + let actual = suggest_scan_ranges(st.wallet().conn(), Ignored).unwrap(); assert_eq!(actual, expected); // Simulate that in the blocks between the wallet birthday and the max_scanned height, @@ -1154,7 +1167,7 @@ pub(crate) mod tests { pre_birthday_range.clone(), ]; - let actual = suggest_scan_ranges(&st.wallet().conn, Ignored).unwrap(); + let actual = suggest_scan_ranges(st.wallet().conn(), Ignored).unwrap(); assert_eq!(actual, expected); // Now simulate shutting down, and then restarting 90 blocks later, after a shard @@ -1180,7 +1193,7 @@ pub(crate) mod tests { .unwrap(); // Just inserting the subtree roots doesn't affect the scan ranges. - let actual = suggest_scan_ranges(&st.wallet().conn, Ignored).unwrap(); + let actual = suggest_scan_ranges(st.wallet().conn(), Ignored).unwrap(); assert_eq!(actual, expected); let new_tip = last_shard_start + 20; @@ -1213,7 +1226,7 @@ pub(crate) mod tests { pre_birthday_range, ]; - let actual = suggest_scan_ranges(&st.wallet().conn, Ignored).unwrap(); + let actual = suggest_scan_ranges(st.wallet().conn(), Ignored).unwrap(); assert_eq!(actual, expected); } @@ -1243,6 +1256,7 @@ pub(crate) mod tests { // notes beyond the end of the first shard. let frontier_tree_size: u32 = (0x1 << 16) + 1234; let mut st = TestBuilder::new() + .with_data_store_factory(TestDbFactory) .with_block_cache() .with_initial_chain_state(|rng, network| { let birthday_height = @@ -1313,7 +1327,7 @@ pub(crate) mod tests { scan_range(sap_active.into()..birthday.height().into(), Ignored), ]; - let actual = suggest_scan_ranges(&st.wallet().conn, Ignored).unwrap(); + let actual = suggest_scan_ranges(st.wallet().conn(), Ignored).unwrap(); assert_eq!(actual, expected); // Simulate that in the blocks between the wallet birthday and the max_scanned height, @@ -1366,6 +1380,7 @@ pub(crate) mod tests { { let mut shard_stmt = st .wallet_mut() + .db_mut() .conn .prepare("SELECT shard_index, subtree_end_height FROM sapling_tree_shards") .unwrap(); @@ -1381,6 +1396,7 @@ pub(crate) mod tests { { let mut shard_stmt = st .wallet_mut() + .db_mut() .conn .prepare("SELECT shard_index, subtree_end_height FROM orchard_tree_shards") .unwrap(); @@ -1409,7 +1425,7 @@ pub(crate) mod tests { scan_range(sap_active.into()..birthday.height().into(), Ignored), ]; - let actual = suggest_scan_ranges(&st.wallet().conn, Ignored).unwrap(); + let actual = suggest_scan_ranges(st.wallet().conn(), Ignored).unwrap(); assert_eq!(actual, expected); // We've crossed a subtree boundary, but only in one pool. We still only have one scanned @@ -1427,7 +1443,9 @@ pub(crate) mod tests { fn replace_queue_entries_merges_previous_range() { use ScanPriority::*; - let mut st = TestBuilder::new().build(); + let mut st = TestBuilder::new() + .with_data_store_factory(TestDbFactory) + .build(); let ranges = vec![ scan_range(150..200, ChainTip), @@ -1436,16 +1454,16 @@ pub(crate) mod tests { ]; { - let tx = st.wallet_mut().conn.transaction().unwrap(); + let tx = st.wallet_mut().conn_mut().transaction().unwrap(); insert_queue_entries(&tx, ranges.iter()).unwrap(); tx.commit().unwrap(); } - let actual = suggest_scan_ranges(&st.wallet().conn, Ignored).unwrap(); + let actual = suggest_scan_ranges(st.wallet().conn(), Ignored).unwrap(); assert_eq!(actual, ranges); { - let tx = st.wallet_mut().conn.transaction().unwrap(); + let tx = st.wallet_mut().conn_mut().transaction().unwrap(); replace_queue_entries::( &tx, &(BlockHeight::from(150)..BlockHeight::from(160)), @@ -1462,7 +1480,7 @@ pub(crate) mod tests { scan_range(0..100, Ignored), ]; - let actual = suggest_scan_ranges(&st.wallet().conn, Ignored).unwrap(); + let actual = suggest_scan_ranges(st.wallet().conn(), Ignored).unwrap(); assert_eq!(actual, expected); } @@ -1470,7 +1488,9 @@ pub(crate) mod tests { fn replace_queue_entries_merges_subsequent_range() { use ScanPriority::*; - let mut st = TestBuilder::new().build(); + let mut st = TestBuilder::new() + .with_data_store_factory(TestDbFactory) + .build(); let ranges = vec![ scan_range(150..200, ChainTip), @@ -1479,16 +1499,16 @@ pub(crate) mod tests { ]; { - let tx = st.wallet_mut().conn.transaction().unwrap(); + let tx = st.wallet_mut().conn_mut().transaction().unwrap(); insert_queue_entries(&tx, ranges.iter()).unwrap(); tx.commit().unwrap(); } - let actual = suggest_scan_ranges(&st.wallet().conn, Ignored).unwrap(); + let actual = suggest_scan_ranges(st.wallet().conn(), Ignored).unwrap(); assert_eq!(actual, ranges); { - let tx = st.wallet_mut().conn.transaction().unwrap(); + let tx = st.wallet_mut().conn_mut().transaction().unwrap(); replace_queue_entries::( &tx, &(BlockHeight::from(90)..BlockHeight::from(100)), @@ -1505,7 +1525,7 @@ pub(crate) mod tests { scan_range(0..90, Ignored), ]; - let actual = suggest_scan_ranges(&st.wallet().conn, Ignored).unwrap(); + let actual = suggest_scan_ranges(st.wallet().conn(), Ignored).unwrap(); assert_eq!(actual, expected); } @@ -1534,13 +1554,14 @@ pub(crate) mod tests { #[cfg(feature = "orchard")] fn prepare_orchard_block_spanning_test( with_birthday_subtree_root: bool, - ) -> TestState { + ) -> TestState { let birthday_nu5_offset = 5000; let birthday_prior_block_hash = BlockHash([0; 32]); // We set the Sapling and Orchard frontiers at the birthday block initial state to 50 // notes back from the end of the second shard. let birthday_tree_size: u32 = (0x1 << 17) - 50; let mut st = TestBuilder::new() + .with_data_store_factory(TestDbFactory) .with_block_cache() .with_initial_chain_state(|rng, network| { let birthday_height = @@ -1674,6 +1695,8 @@ pub(crate) mod tests { #[test] #[cfg(feature = "orchard")] fn orchard_block_spanning_tip_boundary_complete() { + use zcash_client_backend::data_api::Account as _; + let mut st = prepare_orchard_block_spanning_test(true); let account = st.test_account().cloned().unwrap(); let birthday = account.birthday(); @@ -1701,27 +1724,24 @@ pub(crate) mod tests { ), ]; - let actual = suggest_scan_ranges(&st.wallet().conn, ScanPriority::Ignored).unwrap(); + let actual = suggest_scan_ranges(st.wallet().conn(), ScanPriority::Ignored).unwrap(); assert_eq!(actual, expected); // Scan the chain-tip range. st.scan_cached_blocks(birthday.height() + 12, 112); // We haven't yet discovered our note, so balances should still be zero - assert_eq!( - st.get_total_balance(account.account_id()), - NonNegativeAmount::ZERO - ); + assert_eq!(st.get_total_balance(account.id()), NonNegativeAmount::ZERO); // Now scan the historic range; this should discover our note, which should now be // spendable. st.scan_cached_blocks(birthday.height(), 12); assert_eq!( - st.get_total_balance(account.account_id()), + st.get_total_balance(account.id()), NonNegativeAmount::const_from_u64(100000) ); assert_eq!( - st.get_spendable_balance(account.account_id(), 10), + st.get_spendable_balance(account.id(), 10), NonNegativeAmount::const_from_u64(100000) ); @@ -1729,7 +1749,7 @@ pub(crate) mod tests { let to_extsk = OrchardPoolTester::sk(&[0xf5; 32]); let to = OrchardPoolTester::sk_default_address(&to_extsk); let request = zip321::TransactionRequest::new(vec![zip321::Payment::without_memo( - to.to_zcash_address(&st.network()), + to.to_zcash_address(st.network()), NonNegativeAmount::const_from_u64(10000), )]) .unwrap(); @@ -1747,7 +1767,7 @@ pub(crate) mod tests { let proposal = st .propose_transfer( - account.account_id(), + account.id(), input_selector, request, NonZeroU32::new(10).unwrap(), @@ -1767,6 +1787,8 @@ pub(crate) mod tests { #[test] #[cfg(feature = "orchard")] fn orchard_block_spanning_tip_boundary_incomplete() { + use zcash_client_backend::data_api::Account as _; + let mut st = prepare_orchard_block_spanning_test(false); let account = st.test_account().cloned().unwrap(); let birthday = account.birthday(); @@ -1790,27 +1812,24 @@ pub(crate) mod tests { ), ]; - let actual = suggest_scan_ranges(&st.wallet().conn, ScanPriority::Ignored).unwrap(); + let actual = suggest_scan_ranges(st.wallet().conn(), ScanPriority::Ignored).unwrap(); assert_eq!(actual, expected); // Scan the chain-tip range, but omitting the spanning block. st.scan_cached_blocks(birthday.height() + 13, 112); // We haven't yet discovered our note, so balances should still be zero - assert_eq!( - st.get_total_balance(account.account_id()), - NonNegativeAmount::ZERO - ); + assert_eq!(st.get_total_balance(account.id()), NonNegativeAmount::ZERO); // Now scan the historic range; this should discover our note but not // complete the tree. The note should not be considered spendable. st.scan_cached_blocks(birthday.height(), 12); assert_eq!( - st.get_total_balance(account.account_id()), + st.get_total_balance(account.id()), NonNegativeAmount::const_from_u64(100000) ); assert_eq!( - st.get_spendable_balance(account.account_id(), 10), + st.get_spendable_balance(account.id(), 10), NonNegativeAmount::ZERO ); @@ -1818,7 +1837,7 @@ pub(crate) mod tests { let to_extsk = OrchardPoolTester::sk(&[0xf5; 32]); let to = OrchardPoolTester::sk_default_address(&to_extsk); let request = zip321::TransactionRequest::new(vec![zip321::Payment::without_memo( - to.to_zcash_address(&st.network()), + to.to_zcash_address(st.network()), NonNegativeAmount::const_from_u64(10000), )]) .unwrap(); @@ -1835,7 +1854,7 @@ pub(crate) mod tests { &GreedyInputSelector::new(change_strategy, DustOutputPolicy::default()); let proposal = st.propose_transfer( - account.account_id(), + account.id(), input_selector, request.clone(), NonZeroU32::new(10).unwrap(), @@ -1848,7 +1867,7 @@ pub(crate) mod tests { // Verify that it's now possible to create the proposal let proposal = st.propose_transfer( - account.account_id(), + account.id(), input_selector, request, NonZeroU32::new(10).unwrap(), diff --git a/zcash_client_sqlite/src/wallet/transparent.rs b/zcash_client_sqlite/src/wallet/transparent.rs index 47e70ed572..433b3bed99 100644 --- a/zcash_client_sqlite/src/wallet/transparent.rs +++ b/zcash_client_sqlite/src/wallet/transparent.rs @@ -822,11 +822,16 @@ pub(crate) fn queue_transparent_spend_detection( #[cfg(test)] mod tests { - use crate::testing::{AddressType, TestBuilder, TestState}; + use crate::testing::{ + db::{TestDb, TestDbFactory}, + AddressType, TestBuilder, TestState, + }; + use sapling::zip32::ExtendedSpendingKey; use zcash_client_backend::{ data_api::{ - wallet::input_selection::GreedyInputSelector, InputSource, WalletRead, WalletWrite, + wallet::input_selection::GreedyInputSelector, Account as _, InputSource, WalletRead, + WalletWrite, }, encoding::AddressCodec, fees::{fixed, DustOutputPolicy}, @@ -845,11 +850,12 @@ mod tests { use crate::testing::TestBuilder; let mut st = TestBuilder::new() + .with_data_store_factory(TestDbFactory) .with_account_from_sapling_activation(BlockHash([0; 32])) .build(); let birthday = st.test_account().unwrap().birthday().height(); - let account_id = st.test_account().unwrap().account_id(); + let account_id = st.test_account().unwrap().id(); let uaddr = st .wallet() .get_current_address(account_id) @@ -933,10 +939,10 @@ mod tests { // Artificially delete the address from the addresses table so that // we can ensure the update fails if the join doesn't work. st.wallet() - .conn + .conn() .execute( "DELETE FROM addresses WHERE cached_transparent_receiver_address = ?", - [Some(taddr.encode(&st.wallet().params))], + [Some(taddr.encode(st.network()))], ) .unwrap(); @@ -949,6 +955,7 @@ mod tests { use zcash_client_backend::ShieldedProtocol; let mut st = TestBuilder::new() + .with_data_store_factory(TestDbFactory) .with_block_cache() .with_account_from_sapling_activation(BlockHash([0; 32])) .build(); @@ -956,7 +963,7 @@ mod tests { let account = st.test_account().cloned().unwrap(); let uaddr = st .wallet() - .get_current_address(account.account_id()) + .get_current_address(account.id()) .unwrap() .unwrap(); let taddr = uaddr.transparent().unwrap(); @@ -971,17 +978,14 @@ mod tests { } st.scan_cached_blocks(start_height, 10); - let check_balance = |st: &TestState<_>, min_confirmations: u32, expected| { + let check_balance = |st: &TestState<_, TestDb, _>, min_confirmations: u32, expected| { // Check the wallet summary returns the expected transparent balance. let summary = st .wallet() .get_wallet_summary(min_confirmations) .unwrap() .unwrap(); - let balance = summary - .account_balances() - .get(&account.account_id()) - .unwrap(); + let balance = summary.account_balances().get(&account.id()).unwrap(); // TODO: in the future, we will distinguish between available and total // balance according to `min_confirmations` assert_eq!(balance.unshielded(), expected); @@ -990,7 +994,7 @@ mod tests { let mempool_height = st.wallet().chain_height().unwrap().unwrap() + 1; assert_eq!( st.wallet() - .get_transparent_balances(account.account_id(), mempool_height) + .get_transparent_balances(account.id(), mempool_height) .unwrap() .get(taddr) .cloned()