Skip to content

Commit

Permalink
zcash_client_sqlite: Generalize the test framework to enable it to be…
Browse files Browse the repository at this point in the history
… moved to `zcash_client_backend`
  • Loading branch information
nuttycom committed Sep 5, 2024
1 parent 0ae5ac1 commit 6b6b310
Show file tree
Hide file tree
Showing 11 changed files with 742 additions and 457 deletions.
1 change: 1 addition & 0 deletions zcash_client_sqlite/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ document-features.workspace = true
maybe-rayon.workspace = true

[dev-dependencies]
ambassador.workspace = true
assert_matches.workspace = true
bls12_381.workspace = true
incrementalmerkletree = { workspace = true, features = ["test-dependencies"] }
Expand Down
144 changes: 86 additions & 58 deletions zcash_client_sqlite/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -264,28 +264,36 @@ impl<C: Borrow<rusqlite::Connection>, P: consensus::Parameters> InputSource for
&self,
account: AccountId,
target_value: NonNegativeAmount,
_sources: &[ShieldedProtocol],
sources: &[ShieldedProtocol],
anchor_height: BlockHeight,
exclude: &[Self::NoteRef],
) -> Result<SpendableNotes<Self::NoteRef>, Self::Error> {
Ok(SpendableNotes::new(
wallet::sapling::select_spendable_sapling_notes(
self.conn.borrow(),
&self.params,
account,
target_value,
anchor_height,
exclude,
)?,
if sources.contains(&ShieldedProtocol::Sapling) {
wallet::sapling::select_spendable_sapling_notes(
self.conn.borrow(),
&self.params,
account,
target_value,
anchor_height,
exclude,
)?
} else {
vec![]
},
#[cfg(feature = "orchard")]
wallet::orchard::select_spendable_orchard_notes(
self.conn.borrow(),
&self.params,
account,
target_value,
anchor_height,
exclude,
)?,
if sources.contains(&ShieldedProtocol::Orchard) {
wallet::orchard::select_spendable_orchard_notes(
self.conn.borrow(),
&self.params,
account,
target_value,
anchor_height,
exclude,
)?
} else {
vec![]
},
))
}

Expand Down Expand Up @@ -1687,10 +1695,11 @@ mod tests {
};
use zcash_keys::keys::{UnifiedFullViewingKey, UnifiedSpendingKey};
use zcash_primitives::block::BlockHash;
use zcash_protocol::consensus;

use crate::{
error::SqliteClientError,
testing::{TestBuilder, TestState},
testing::{db::TestDbFactory, TestBuilder, TestState},
AccountId, DEFAULT_UA_REQUEST,
};

Expand All @@ -1703,13 +1712,14 @@ mod tests {
#[test]
fn validate_seed() {
let st = TestBuilder::new()
.with_data_store_factory(TestDbFactory)
.with_account_from_sapling_activation(BlockHash([0; 32]))
.build();
let account = st.test_account().unwrap();

assert!({
st.wallet()
.validate_seed(account.account_id(), st.test_seed().unwrap())
.validate_seed(account.id(), st.test_seed().unwrap())
.unwrap()
});

Expand All @@ -1724,41 +1734,37 @@ mod tests {
// check that passing an invalid seed results in a failure
assert!({
!st.wallet()
.validate_seed(account.account_id(), &SecretVec::new(vec![1u8; 32]))
.validate_seed(account.id(), &SecretVec::new(vec![1u8; 32]))
.unwrap()
});
}

#[test]
pub(crate) fn get_next_available_address() {
let mut st = TestBuilder::new()
.with_data_store_factory(TestDbFactory)
.with_account_from_sapling_activation(BlockHash([0; 32]))
.build();
let account = st.test_account().cloned().unwrap();

let current_addr = st
.wallet()
.get_current_address(account.account_id())
.unwrap();
let current_addr = st.wallet().get_current_address(account.id()).unwrap();
assert!(current_addr.is_some());

let addr2 = st
.wallet_mut()
.get_next_available_address(account.account_id(), DEFAULT_UA_REQUEST)
.get_next_available_address(account.id(), DEFAULT_UA_REQUEST)
.unwrap();
assert!(addr2.is_some());
assert_ne!(current_addr, addr2);

let addr2_cur = st
.wallet()
.get_current_address(account.account_id())
.unwrap();
let addr2_cur = st.wallet().get_current_address(account.id()).unwrap();
assert_eq!(addr2, addr2_cur);
}

#[test]
pub(crate) fn import_account_hd_0() {
let st = TestBuilder::new()
.with_data_store_factory(TestDbFactory)
.with_account_from_sapling_activation(BlockHash([0; 32]))
.set_account_index(zip32::AccountId::ZERO)
.build();
Expand All @@ -1769,10 +1775,12 @@ mod tests {

#[test]
pub(crate) fn import_account_hd_1_then_2() {
let mut st = TestBuilder::new().build();
let mut st = TestBuilder::new()
.with_data_store_factory(TestDbFactory)
.build();

let birthday = AccountBirthday::from_parts(
ChainState::empty(st.wallet().params.sapling.unwrap() - 1, BlockHash([0; 32])),
ChainState::empty(st.network().sapling.unwrap() - 1, BlockHash([0; 32])),
None,
);

Expand All @@ -1797,15 +1805,19 @@ mod tests {
AccountSource::Derived { seed_fingerprint: _, account_index } if account_index == zip32_index_2);
}

fn check_collisions<C>(
st: &mut TestState<C>,
fn check_collisions<C, DbT: WalletWrite, P: consensus::Parameters>(
st: &mut TestState<C, DbT, P>,
ufvk: &UnifiedFullViewingKey,
birthday: &AccountBirthday,
existing_id: AccountId,
) {
_existing_id: AccountId,
) where
DbT::Account: core::fmt::Debug,
{
assert_matches!(
st.wallet_mut().import_account_ufvk(ufvk, birthday, AccountPurpose::Spending),
Err(SqliteClientError::AccountCollision(id)) if id == existing_id);
st.wallet_mut()
.import_account_ufvk(ufvk, birthday, AccountPurpose::Spending),
Err(_)
);

// Remove the transparent component so that we don't have a match on the full UFVK.
// That should still produce an AccountCollision error.
Expand All @@ -1820,8 +1832,13 @@ mod tests {
)
.unwrap();
assert_matches!(
st.wallet_mut().import_account_ufvk(&subset_ufvk, birthday, AccountPurpose::Spending),
Err(SqliteClientError::AccountCollision(id)) if id == existing_id);
st.wallet_mut().import_account_ufvk(
&subset_ufvk,
birthday,
AccountPurpose::Spending
),
Err(_)
);
}

// Remove the Orchard component so that we don't have a match on the full UFVK.
Expand All @@ -1837,17 +1854,24 @@ mod tests {
)
.unwrap();
assert_matches!(
st.wallet_mut().import_account_ufvk(&subset_ufvk, birthday, AccountPurpose::Spending),
Err(SqliteClientError::AccountCollision(id)) if id == existing_id);
st.wallet_mut().import_account_ufvk(
&subset_ufvk,
birthday,
AccountPurpose::Spending
),
Err(_)
);
}
}

#[test]
pub(crate) fn import_account_hd_1_then_conflicts() {
let mut st = TestBuilder::new().build();
let mut st = TestBuilder::new()
.with_data_store_factory(TestDbFactory)
.build();

let birthday = AccountBirthday::from_parts(
ChainState::empty(st.wallet().params.sapling.unwrap() - 1, BlockHash([0; 32])),
ChainState::empty(st.network().sapling.unwrap() - 1, BlockHash([0; 32])),
None,
);

Expand All @@ -1869,27 +1893,28 @@ mod tests {

#[test]
pub(crate) fn import_account_ufvk_then_conflicts() {
let mut st = TestBuilder::new().build();
let mut st = TestBuilder::new()
.with_data_store_factory(TestDbFactory)
.build();

let birthday = AccountBirthday::from_parts(
ChainState::empty(st.wallet().params.sapling.unwrap() - 1, BlockHash([0; 32])),
ChainState::empty(st.network().sapling.unwrap() - 1, BlockHash([0; 32])),
None,
);

let seed = Secret::new(vec![0u8; 32]);
let zip32_index_0 = zip32::AccountId::ZERO;
let usk =
UnifiedSpendingKey::from_seed(&st.wallet().params, seed.expose_secret(), zip32_index_0)
.unwrap();
let usk = UnifiedSpendingKey::from_seed(st.network(), seed.expose_secret(), zip32_index_0)
.unwrap();
let ufvk = usk.to_unified_full_viewing_key();

let account = st
.wallet_mut()
.import_account_ufvk(&ufvk, &birthday, AccountPurpose::Spending)
.unwrap();
assert_eq!(
ufvk.encode(&st.wallet().params),
account.ufvk().unwrap().encode(&st.wallet().params)
ufvk.encode(st.network()),
account.ufvk().unwrap().encode(st.network())
);

assert_matches!(
Expand All @@ -1908,10 +1933,12 @@ mod tests {

#[test]
pub(crate) fn create_account_then_conflicts() {
let mut st = TestBuilder::new().build();
let mut st = TestBuilder::new()
.with_data_store_factory(TestDbFactory)
.build();

let birthday = AccountBirthday::from_parts(
ChainState::empty(st.wallet().params.sapling.unwrap() - 1, BlockHash([0; 32])),
ChainState::empty(st.network().sapling.unwrap() - 1, BlockHash([0; 32])),
None,
);

Expand All @@ -1933,17 +1960,15 @@ mod tests {
fn transparent_receivers() {
// Add an account to the wallet.
let st = TestBuilder::new()
.with_data_store_factory(TestDbFactory)
.with_block_cache()
.with_account_from_sapling_activation(BlockHash([0; 32]))
.build();
let account = st.test_account().unwrap();
let ufvk = account.usk().to_unified_full_viewing_key();
let (taddr, _) = account.usk().default_transparent_address();

let receivers = st
.wallet()
.get_transparent_receivers(account.account_id())
.unwrap();
let receivers = st.wallet().get_transparent_receivers(account.id()).unwrap();

// The receiver for the default UA should be in the set.
assert!(receivers.contains_key(
Expand All @@ -1964,15 +1989,18 @@ mod tests {
use zcash_primitives::consensus::NetworkConstants;
use zcash_primitives::zip32;

let mut st = TestBuilder::new().with_fs_block_cache().build();
let mut st = TestBuilder::new()
.with_data_store_factory(TestDbFactory)
.with_fs_block_cache()
.build();

// The BlockMeta DB starts off empty.
assert_eq!(st.cache().get_max_cached_height().unwrap(), None);

// Generate some fake CompactBlocks.
let seed = [0u8; 32];
let hd_account_index = zip32::AccountId::ZERO;
let extsk = sapling::spending_key(&seed, st.wallet().params.coin_type(), hd_account_index);
let extsk = sapling::spending_key(&seed, st.network().coin_type(), hd_account_index);
let dfvk = extsk.to_diversifiable_full_viewing_key();
let (h1, meta1, _) = st.generate_next_block(
&dfvk,
Expand Down
Loading

0 comments on commit 6b6b310

Please sign in to comment.