diff --git a/core/primitives/src/state_record.rs b/core/primitives/src/state_record.rs index fec0d8de2c6..61212dcf090 100644 --- a/core/primitives/src/state_record.rs +++ b/core/primitives/src/state_record.rs @@ -95,6 +95,19 @@ impl StateRecord { _ => unreachable!(), } } + + pub fn get_type_string(&self) -> String { + match self { + StateRecord::Account { .. } => "Account", + StateRecord::Data { .. } => "Data", + StateRecord::Contract { .. } => "Contract", + StateRecord::AccessKey { .. } => "AccessKey", + StateRecord::PostponedReceipt { .. } => "PostponedReceipt", + StateRecord::ReceivedData { .. } => "ReceivedData", + StateRecord::DelayedReceipt { .. } => "DelayedReceipt", + } + .to_string() + } } impl Display for StateRecord { diff --git a/core/primitives/src/trie_key.rs b/core/primitives/src/trie_key.rs index e947d01f727..98daf87b438 100644 --- a/core/primitives/src/trie_key.rs +++ b/core/primitives/src/trie_key.rs @@ -10,7 +10,7 @@ use crate::types::AccountId; pub(crate) const ACCOUNT_DATA_SEPARATOR: u8 = b','; /// Type identifiers used for DB key generation to store values in the key-value storage. -pub(crate) mod col { +pub mod col { /// This column id is used when storing `primitives::account::Account` type about a given /// `account_id`. pub const ACCOUNT: u8 = 0; @@ -375,7 +375,7 @@ pub mod trie_key_parsers { Ok(None) } - fn parse_account_id_from_trie_key_with_separator( + pub fn parse_account_id_from_trie_key_with_separator( col: u8, raw_key: &[u8], col_name: &str, diff --git a/core/store/src/trie/iterator.rs b/core/store/src/trie/iterator.rs index 7af30d4910e..200931f7508 100644 --- a/core/store/src/trie/iterator.rs +++ b/core/store/src/trie/iterator.rs @@ -45,7 +45,7 @@ impl Crumb { /// There are two stacks that we track while iterating: the trail and the key_nibbles. /// The trail is a vector of trie nodes on the path from root node to the node that is /// currently being processed together with processing status - the Crumb. -/// The key_nibbles is a vector of nibbles from the state root not to the node that is +/// The key_nibbles is a vector of nibbles from the state root node to the node that is /// currently being processed. /// The trail and the key_nibbles may have different lengths e.g. an extension trie node /// will add only a single item to the trail but may add multiple nibbles to the key_nibbles. @@ -57,8 +57,17 @@ pub struct TrieIterator<'a> { /// If not `None`, a list of all nodes that the iterator has visited. visited_nodes: Option>>, - /// Max depth of iteration. - max_depth: Option, + /// Prune condition is an optional closure that given the key nibbles + /// decides if the given trie node should be pruned. + /// + /// If the prune conditions returns true for a given node, this node and the + /// whole sub-tree rooted at this node will be pruned and skipped in iteration. + /// + /// Please note that since the iterator supports seeking the prune condition + /// should have the property that if a prefix of a key should be pruned then + /// the key also should be pruned. Otherwise it would be possible to bypass + /// the pruning by seeking inside of the pruned sub-tree. + prune_condition: Option) -> bool>>, } /// The TrieTiem is a tuple of (key, value) of the node. @@ -76,19 +85,22 @@ pub struct TrieTraversalItem { impl<'a> TrieIterator<'a> { #![allow(clippy::new_ret_no_self)] /// Create a new iterator. - pub(super) fn new(trie: &'a Trie, max_depth: Option) -> Result { + pub(super) fn new( + trie: &'a Trie, + prune_condition: Option) -> bool>>, + ) -> Result { let mut r = TrieIterator { trie, trail: Vec::with_capacity(8), key_nibbles: Vec::with_capacity(64), visited_nodes: None, - max_depth, + prune_condition, }; r.descend_into_node(&trie.root)?; Ok(r) } - /// Position the iterator on the first element with key => `key`. + /// Position the iterator on the first element with key >= `key`. pub fn seek_prefix>(&mut self, key: K) -> Result<(), StorageError> { self.seek_nibble_slice(NibbleSlice::new(key.as_ref()), true).map(drop) } @@ -387,8 +399,8 @@ impl<'a> Iterator for TrieIterator<'a> { loop { let iter_step = self.iter_step()?; - let can_process = match self.max_depth { - Some(max_depth) => self.key_nibbles.len() <= max_depth, + let can_process = match &self.prune_condition { + Some(prune_condition) => !prune_condition(&self.key_nibbles), None => true, }; @@ -420,6 +432,7 @@ impl<'a> Iterator for TrieIterator<'a> { mod tests { use std::collections::BTreeMap; + use itertools::Itertools; use rand::seq::SliceRandom; use rand::Rng; @@ -431,6 +444,10 @@ mod tests { use crate::Trie; use near_primitives::shard_layout::ShardUId; + fn value() -> Option> { + Some(vec![0]) + } + /// Checks that for visiting interval of trie nodes first state key is /// included and the last one is excluded. #[test] @@ -452,20 +469,7 @@ mod tests { fn test_iterator() { let mut rng = rand::thread_rng(); for _ in 0..100 { - let tries = create_tries_complex(1, 2); - let shard_uid = ShardUId { version: 1, shard_id: 0 }; - let trie_changes = gen_changes(&mut rng, 10); - let trie_changes = simplify_changes(&trie_changes); - - let mut map = BTreeMap::new(); - for (key, value) in trie_changes.iter() { - if let Some(value) = value { - map.insert(key.clone(), value.clone()); - } - } - let state_root = - test_populate_trie(&tries, &Trie::EMPTY_ROOT, shard_uid, trie_changes.clone()); - let trie = tries.get_trie_for_shard(shard_uid, state_root); + let (trie_changes, map, trie) = gen_random_trie(&mut rng); { let result1: Vec<_> = trie.iter().unwrap().map(Result::unwrap).collect(); @@ -500,6 +504,149 @@ mod tests { } } + #[test] + fn test_iterator_with_prune_condition_base() { + let mut rng = rand::thread_rng(); + for _ in 0..100 { + let (trie_changes, map, trie) = gen_random_trie(&mut rng); + + // Check that pruning just one key (and it's subtree) works as expected. + for (prune_key, _) in &trie_changes { + let prune_key = prune_key.clone(); + let prune_key_nibbles = NibbleSlice::new(prune_key.as_slice()).iter().collect_vec(); + let prune_condition = + move |key_nibbles: &Vec| key_nibbles.starts_with(&prune_key_nibbles); + + let result1 = trie + .iter_with_prune_condition(Some(Box::new(prune_condition.clone()))) + .unwrap() + .map(Result::unwrap) + .collect_vec(); + + let result2 = map + .iter() + .filter(|(key, _)| { + !prune_condition(&NibbleSlice::new(key).iter().collect_vec()) + }) + .map(|(key, value)| (key.clone(), value.clone())) + .collect_vec(); + + assert_eq!(result1, result2); + } + } + } + + // Check that pruning a node doesn't descend into it's subtree. + // A buggy pruning implementation could still iterate over all the + // nodes but simply not return them. This test makes sure this is + // not the case. + #[test] + fn test_iterator_with_prune_condition_subtree() { + let mut rng = rand::thread_rng(); + for _ in 0..100 { + let (trie_changes, map, trie) = gen_random_trie(&mut rng); + + // Test pruning by all keys that are present in the trie. + for (prune_key, _) in &trie_changes { + // This prune condition is not valid in a sense that it only + // prunes a single node but not it's subtree. This is + // intentional to test that iterator won't descend into the + // subtree. + let prune_key_nibbles = NibbleSlice::new(prune_key.as_slice()).iter().collect_vec(); + let prune_condition = + move |key_nibbles: &Vec| key_nibbles == &prune_key_nibbles; + // This is how the prune condition should work. + let prune_key_nibbles = NibbleSlice::new(prune_key.as_slice()).iter().collect_vec(); + let proper_prune_condition = + move |key_nibbles: &Vec| key_nibbles.starts_with(&prune_key_nibbles); + + let result1 = trie + .iter_with_prune_condition(Some(Box::new(prune_condition.clone()))) + .unwrap() + .map(Result::unwrap) + .collect_vec(); + let result2 = map + .iter() + .filter(|(key, _)| { + !proper_prune_condition(&NibbleSlice::new(key).iter().collect_vec()) + }) + .map(|(key, value)| (key.clone(), value.clone())) + .collect_vec(); + + assert_eq!(result1, result2); + } + } + } + + // Utility function for testing trie iteration with the prune condition set. + // * `keys` is a list of keys to be inserted into the trie + // * `pruned_keys` is the expected list of keys that should be the result of iteration + fn test_prune_max_depth_impl( + keys: &Vec>, + pruned_keys: &Vec>, + max_depth: usize, + ) { + let shard_uid = ShardUId::single_shard(); + let tries = create_tries(); + let trie_changes = keys.iter().map(|key| (key.clone(), value())).collect(); + let state_root = test_populate_trie(&tries, &Trie::EMPTY_ROOT, shard_uid, trie_changes); + let trie = tries.get_trie_for_shard(shard_uid, state_root); + let iter = trie.iter_with_max_depth(max_depth).unwrap(); + let keys: Vec<_> = iter.map(|item| item.unwrap().0).collect(); + + assert_eq!(&keys, pruned_keys); + } + + #[test] + fn test_prune_max_depth() { + // simple trie with an extension + // extension(11111) + // branch(5, 6) + // leaf(5) leaf(6) + let extension_keys = vec![vec![0x11, 0x11, 0x15], vec![0x11, 0x11, 0x16]]; + // max_depth is expressed in nibbles + // both leaf nodes are at depth 6 (11 11 15) and (11 11 16) + + // pruning by max depth 5 should return an empty result + test_prune_max_depth_impl(&extension_keys, &vec![], 5); + // pruning by max depth 6 should return both leaves + test_prune_max_depth_impl(&extension_keys, &extension_keys, 6); + + // long chain of branches + let chain_keys = vec![ + vec![0x11], + vec![0x11, 0x11], + vec![0x11, 0x11, 0x11], + vec![0x11, 0x11, 0x11, 0x11], + vec![0x11, 0x11, 0x11, 0x11, 0x11], + ]; + test_prune_max_depth_impl(&chain_keys, &vec![], 1); + test_prune_max_depth_impl(&chain_keys, &vec![vec![0x11]], 2); + test_prune_max_depth_impl(&chain_keys, &vec![vec![0x11]], 3); + test_prune_max_depth_impl(&chain_keys, &vec![vec![0x11], vec![0x11, 0x11]], 4); + test_prune_max_depth_impl(&chain_keys, &vec![vec![0x11], vec![0x11, 0x11]], 5); + } + + fn gen_random_trie( + rng: &mut rand::rngs::ThreadRng, + ) -> (Vec<(Vec, Option>)>, BTreeMap, Vec>, Trie) { + let tries = create_tries_complex(1, 2); + let shard_uid = ShardUId { version: 1, shard_id: 0 }; + let trie_changes = gen_changes(rng, 10); + let trie_changes = simplify_changes(&trie_changes); + + let mut map = BTreeMap::new(); + for (key, value) in trie_changes.iter() { + if let Some(value) = value { + map.insert(key.clone(), value.clone()); + } + } + let state_root = + test_populate_trie(&tries, &Trie::EMPTY_ROOT, shard_uid, trie_changes.clone()); + let trie = tries.get_trie_for_shard(shard_uid, state_root); + (trie_changes, map, trie) + } + fn test_get_trie_items( trie: &Trie, map: &BTreeMap, Vec>, diff --git a/core/store/src/trie/mod.rs b/core/store/src/trie/mod.rs index 52568a2509d..6e64af32102 100644 --- a/core/store/src/trie/mod.rs +++ b/core/store/src/trie/mod.rs @@ -879,7 +879,17 @@ impl Trie { &'a self, max_depth: usize, ) -> Result, StorageError> { - TrieIterator::new(self, Some(max_depth)) + TrieIterator::new( + self, + Some(Box::new(move |key_nibbles: &Vec| key_nibbles.len() > max_depth)), + ) + } + + pub fn iter_with_prune_condition<'a>( + &'a self, + prune_condition: Option) -> bool>>, + ) -> Result, StorageError> { + TrieIterator::new(self, prune_condition) } pub fn get_trie_nodes_count(&self) -> TrieNodesCount { diff --git a/tools/state-viewer/src/cli.rs b/tools/state-viewer/src/cli.rs index c734be6dd5c..a29b5e2c777 100644 --- a/tools/state-viewer/src/cli.rs +++ b/tools/state-viewer/src/cli.rs @@ -1,15 +1,26 @@ use crate::commands::*; use crate::contract_accounts::ContractAccountFilter; use crate::rocksdb_stats::get_rocksdb_stats; +use near_chain::{ChainStore, ChainStoreAccess}; use near_chain_configs::{GenesisChangeConfig, GenesisValidationMode}; +use near_epoch_manager::EpochManager; use near_primitives::account::id::AccountId; use near_primitives::hash::CryptoHash; use near_primitives::sharding::ChunkHash; +use near_primitives::state_record::{state_record_to_account_id, StateRecord}; +use near_primitives::trie_key::col; +use near_primitives::trie_key::trie_key_parsers::{ + parse_account_id_from_access_key_key, parse_account_id_from_trie_key_with_separator, +}; use near_primitives::types::{BlockHeight, ShardId}; -use near_store::{Mode, NodeStorage, Store, Temperature}; +use near_store::{Mode, NodeStorage, ShardUId, Store, Temperature, Trie, TrieDBStorage}; use nearcore::{load_config, NearConfig}; +use std::cell::RefCell; +use std::collections::HashMap; use std::path::{Path, PathBuf}; +use std::rc::Rc; use std::str::FromStr; +use std::time::Instant; #[derive(clap::Subcommand)] #[clap(subcommand_required = true, arg_required_else_help = true)] @@ -75,6 +86,8 @@ pub enum StateViewerSubCommand { StateChanges(StateChangesCmd), /// Dump or apply state parts. StateParts(StatePartsCmd), + /// Benchmark how long does it take to iterate the trie. + TrieIterationBenchmark(TrieIterationBenchmarkCmd), /// View head of the storage. #[clap(alias = "view_chain")] ViewChain(ViewChainCmd), @@ -138,6 +151,9 @@ impl StateViewerSubCommand { StateViewerSubCommand::StateParts(cmd) => cmd.run(home_dir, near_config, store), StateViewerSubCommand::ViewChain(cmd) => cmd.run(near_config, store), StateViewerSubCommand::ViewTrie(cmd) => cmd.run(store), + StateViewerSubCommand::TrieIterationBenchmark(cmd) => { + cmd.run(home_dir, near_config, store) + } } } } @@ -606,3 +622,299 @@ impl ViewTrieCmd { } } } + +#[derive(Clone)] +pub enum TrieIterationType { + Full, + Shallow, +} + +impl clap::ValueEnum for TrieIterationType { + fn value_variants<'a>() -> &'a [Self] { + &[Self::Full, Self::Shallow] + } + + fn to_possible_value(&self) -> Option { + match self { + Self::Full => Some(clap::builder::PossibleValue::new("full")), + Self::Shallow => Some(clap::builder::PossibleValue::new("shallow")), + } + } +} + +#[derive(Default)] +struct ColumnCountMap(HashMap); + +impl ColumnCountMap { + fn col_to_string(col: u8) -> &'static str { + match col { + col::ACCOUNT => "ACCOUNT", + col::CONTRACT_CODE => "CONTRACT_CODE", + col::DELAYED_RECEIPT => "DELAYED_RECEIPT", + col::DELAYED_RECEIPT_INDICES => "DELAYED_RECEIPT_INDICES", + col::ACCESS_KEY => "ACCESS KEY", + col::CONTRACT_DATA => "CONTRACT DATA", + col::RECEIVED_DATA => "RECEIVED DATA", + col::POSTPONED_RECEIPT_ID => "POSTPONED RECEIPT ID", + col::PENDING_DATA_COUNT => "PENDING DATA COUNT", + col::POSTPONED_RECEIPT => "POSTPONED RECEIPT", + _ => unreachable!(), + } + } +} + +impl std::fmt::Debug for ColumnCountMap { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut map = f.debug_map(); + for (col, count) in &self.0 { + map.entry(&Self::col_to_string(*col), &count); + } + map.finish() + } +} + +#[derive(Debug)] +pub struct TrieIterationBenchmarkStats { + visited_map: ColumnCountMap, + pruned_map: ColumnCountMap, +} + +impl TrieIterationBenchmarkStats { + pub fn new() -> Self { + Self { visited_map: ColumnCountMap::default(), pruned_map: ColumnCountMap::default() } + } + + pub fn bump_visited(&mut self, col: u8) { + let entry = self.visited_map.0.entry(col).or_insert(0); + *entry += 1; + } + + pub fn bump_pruned(&mut self, col: u8) { + let entry = self.pruned_map.0.entry(col).or_insert(0); + *entry += 1; + } +} + +#[derive(clap::Parser)] +pub struct TrieIterationBenchmarkCmd { + /// The type of trie iteration. + /// - Full will iterate over all trie keys. + /// - Shallow will only iterate until full account id prefix can be parsed + /// in the trie key. Most notably this will skip any keys or data + /// belonging to accounts. + #[clap(long, default_value = "full")] + iteration_type: TrieIterationType, + + /// Limit the number of trie nodes to be iterated. + #[clap(long)] + limit: Option, + + /// Print the trie nodes to stdout. + #[clap(long, default_value = "false")] + print: bool, +} + +impl TrieIterationBenchmarkCmd { + pub fn run(self, _home_dir: &Path, near_config: NearConfig, store: Store) { + let genesis_config = &near_config.genesis.config; + let chain_store = ChainStore::new( + store.clone(), + genesis_config.genesis_height, + near_config.client_config.save_trie_changes, + ); + let head = chain_store.head().unwrap(); + let block = chain_store.get_block(&head.last_block_hash).unwrap(); + let epoch_manager = + EpochManager::new_from_genesis_config(store.clone(), &genesis_config).unwrap(); + let shard_layout = epoch_manager.get_shard_layout(block.header().epoch_id()).unwrap(); + + for (shard_id, chunk_header) in block.chunks().iter().enumerate() { + if chunk_header.height_included() != block.header().height() { + println!("chunk for shard {shard_id} is missing and will be skipped"); + } + } + + for (shard_id, chunk_header) in block.chunks().iter().enumerate() { + let shard_id = shard_id as ShardId; + if chunk_header.height_included() != block.header().height() { + println!("chunk for shard {shard_id} is missing, skipping it"); + continue; + } + let trie = self.get_trie(shard_id, &shard_layout, &chunk_header, &store); + + println!("shard id {shard_id:#?} benchmark starting"); + self.iter_trie(&trie); + println!("shard id {shard_id:#?} benchmark finished"); + } + } + + fn get_trie( + &self, + shard_id: ShardId, + shard_layout: &near_primitives::shard_layout::ShardLayout, + chunk_header: &near_primitives::sharding::ShardChunkHeader, + store: &Store, + ) -> Trie { + let shard_uid = ShardUId::from_shard_id_and_layout(shard_id, shard_layout); + // Note: here we get the previous state root but the shard layout + // corresponds to the current epoch id. In practice shouldn't + // matter as the shard layout doesn't change. + let state_root = chunk_header.prev_state_root(); + let storage = TrieDBStorage::new(store.clone(), shard_uid); + let flat_storage_chunk_view = None; + Trie::new(Rc::new(storage), state_root, flat_storage_chunk_view) + } + + fn iter_trie(&self, trie: &Trie) { + let stats = Rc::new(RefCell::new(TrieIterationBenchmarkStats::new())); + let stats_clone = Rc::clone(&stats); + + let prune_condition: Option) -> bool>> = match &self.iteration_type { + TrieIterationType::Full => None, + TrieIterationType::Shallow => Some(Box::new(move |key_nibbles| -> bool { + Self::shallow_iter_prune_condition(key_nibbles, &stats_clone) + })), + }; + + let start = Instant::now(); + let mut node_count = 0; + let mut error_count = 0; + let iter = trie.iter_with_prune_condition(prune_condition); + let iter = match iter { + Ok(iter) => iter, + Err(err) => { + println!("iter error {err:#?}"); + return; + } + }; + for item in iter { + node_count += 1; + + let (key, value) = match item { + Ok((key, value)) => (key, value), + Err(err) => { + println!("Failed to iterate node with error: {err}"); + error_count += 1; + continue; + } + }; + + stats.borrow_mut().bump_visited(key[0]); + + if self.print { + let state_record = StateRecord::from_raw_key_value(key.clone(), value); + Self::print_state_record(&state_record); + } + + if let Some(limit) = self.limit { + if limit < node_count { + break; + } + } + } + let duration = start.elapsed(); + println!("node count {node_count}"); + println!("error count {error_count}"); + println!("time {duration:?}"); + println!("stats\n{:#?}", stats.borrow()); + } + + fn shallow_iter_prune_condition( + key_nibbles: &Vec, + stats: &Rc>, + ) -> bool { + // Need at least 2 nibbles for the column type byte. + if key_nibbles.len() < 2 { + return false; + } + + // The key method will drop the last nibble if there is an odd number of + // them. This is on purpose because the interesting keys have even length. + let key = Self::key(key_nibbles); + let col: u8 = key[0]; + let result = match col { + // key for account only contains account id, nothing to prune + col::ACCOUNT => false, + // key for contract code only contains account id, nothing to prune + col::CONTRACT_CODE => false, + // key for delayed receipt only contains account id, nothing to prune + col::DELAYED_RECEIPT => false, + // key for delayed receipt indices is a shard singleton, nothing to prune + col::DELAYED_RECEIPT_INDICES => false, + + // Most columns use the ACCOUNT_DATA_SEPARATOR to indicate the end + // of the accound id in the trie key. For those columns the + // partial_parse_account_id method should be used. + // The only exception is the ACCESS_KEY and dedicated method + // partial_parse_account_id_from_access_key should be used. + col::ACCESS_KEY => Self::partial_parse_account_id_from_access_key(&key, "ACCESS KEY"), + col::CONTRACT_DATA => Self::partial_parse_account_id(col, &key, "CONTRACT DATA"), + col::RECEIVED_DATA => Self::partial_parse_account_id(col, &key, "RECEIVED DATA"), + col::POSTPONED_RECEIPT_ID => { + Self::partial_parse_account_id(col, &key, "POSTPONED RECEIPT ID") + } + col::PENDING_DATA_COUNT => { + Self::partial_parse_account_id(col, &key, "PENDING DATA COUNT") + } + col::POSTPONED_RECEIPT => { + Self::partial_parse_account_id(col, &key, "POSTPONED RECEIPT") + } + _ => unreachable!(), + }; + + if result { + stats.borrow_mut().bump_pruned(col); + } + + result + + // TODO - this can be optimized, we really only need to look at the last + // byte of the key and check if it is the separator. This only works + // when doing full iteration as seeking inside of the trie would break + // the invariant that parent node key was already checked. + } + + fn key(key_nibbles: &Vec) -> Vec { + // Intentionally ignoring the odd nibble at the end. + let mut result = >::with_capacity(key_nibbles.len() / 2); + for i in (1..key_nibbles.len()).step_by(2) { + result.push(key_nibbles[i - 1] * 16 + key_nibbles[i]); + } + result + } + + fn partial_parse_account_id(col: u8, key: &Vec, col_name: &str) -> bool { + match parse_account_id_from_trie_key_with_separator(col, &key, "") { + Ok(account_id) => { + tracing::trace!(target: "trie-iteration-benchmark", "pruning column {col_name} account id {account_id:?}"); + true + } + Err(_) => false, + } + } + + // returns true if the partial key contains full account id + fn partial_parse_account_id_from_access_key(key: &Vec, col_name: &str) -> bool { + match parse_account_id_from_access_key_key(&key) { + Ok(account_id) => { + tracing::trace!(target: "trie-iteration-benchmark", "pruning column {col_name} account id {account_id:?}"); + true + } + Err(_) => false, + } + } + + fn print_state_record(state_record: &Option) { + let state_record_string = match state_record { + None => "none".to_string(), + Some(state_record) => { + format!( + "{} {:?}", + &state_record.get_type_string(), + state_record_to_account_id(&state_record) + ) + } + }; + println!("{state_record_string}"); + } +}