diff --git a/tools/state-viewer/src/state_parts.rs b/tools/state-viewer/src/state_parts.rs index 07e6e439541..d985da5668d 100644 --- a/tools/state-viewer/src/state_parts.rs +++ b/tools/state-viewer/src/state_parts.rs @@ -1,8 +1,9 @@ use crate::epoch_info::iterate_and_filter; use borsh::BorshDeserialize; use near_chain::{Chain, ChainGenesis, ChainStoreAccess, DoomslugThresholdMode}; -use near_client::sync::state::StateSync; -use near_primitives::challenge::PartialState; +use near_client::sync::state::{ + get_num_parts_from_filename, is_part_filename, location_prefix, part_filename, StateSync, +}; use near_primitives::epoch_manager::epoch_info::EpochInfo; use near_primitives::state_part::PartId; use near_primitives::state_record::StateRecord; @@ -19,26 +20,13 @@ use std::path::{Path, PathBuf}; use std::str::FromStr; use std::time::Instant; -#[derive(clap::ArgEnum, Debug, Clone)] -pub(crate) enum ApplyAction { - Apply, - Validate, - Print, -} - -impl Default for ApplyAction { - fn default() -> Self { - ApplyAction::Apply - } -} - #[derive(clap::Subcommand, Debug, Clone)] pub(crate) enum StatePartsSubCommand { /// Apply all or a single state part of a shard. Apply { - /// Apply, validate or print. - #[clap(arg_enum, long)] - action: ApplyAction, + /// If true, validate the state part but don't write it to the DB. + #[clap(long)] + dry_run: bool, /// If provided, this value will be used instead of looking it up in the headers. /// Use if those headers or blocks are not available. #[clap(long)] @@ -93,12 +81,12 @@ impl StatePartsSubCommand { .unwrap(); let chain_id = &near_config.genesis.config.chain_id; match self { - StatePartsSubCommand::Apply { action, state_root, part_id, epoch_selection } => { + StatePartsSubCommand::Apply { dry_run, state_root, part_id, epoch_selection } => { apply_state_parts( - action, epoch_selection, shard_id, part_id, + dry_run, state_root, &mut chain, chain_id, @@ -146,7 +134,7 @@ impl EpochSelection { chain.runtime_adapter.get_epoch_id(&chain.head().unwrap().last_block_hash).unwrap() } EpochSelection::EpochId { epoch_id } => { - EpochId(CryptoHash::from_str(&epoch_id).unwrap()) + EpochId(CryptoHash::from_str(epoch_id).unwrap()) } EpochSelection::EpochHeight { epoch_height } => { // Fetch epochs at the given height. @@ -159,7 +147,7 @@ impl EpochSelection { epoch_ids[0].clone() } EpochSelection::BlockHash { block_hash } => { - let block_hash = CryptoHash::from_str(&block_hash).unwrap(); + let block_hash = CryptoHash::from_str(block_hash).unwrap(); chain.runtime_adapter.get_epoch_id(&block_hash).unwrap() } EpochSelection::BlockHeight { block_height } => { @@ -233,10 +221,10 @@ fn get_any_block_hash_of_epoch(epoch_info: &EpochInfo, chain: &Chain) -> CryptoH } fn apply_state_parts( - action: ApplyAction, epoch_selection: EpochSelection, shard_id: ShardId, part_id: Option, + dry_run: bool, maybe_state_root: Option, chain: &mut Chain, chain_id: &str, @@ -249,11 +237,11 @@ fn apply_state_parts( { (state_root, *epoch_height, None, None) } else { - let epoch_id = epoch_selection.to_epoch_id(store, &chain); + let epoch_id = epoch_selection.to_epoch_id(store, chain); let epoch = chain.runtime_adapter.get_epoch_info(&epoch_id).unwrap(); - let sync_hash = get_any_block_hash_of_epoch(&epoch, &chain); - let sync_hash = StateSync::get_epoch_start_sync_hash(&chain, &sync_hash).unwrap(); + let sync_hash = get_any_block_hash_of_epoch(&epoch, chain); + let sync_hash = StateSync::get_epoch_start_sync_hash(chain, &sync_hash).unwrap(); let state_header = chain.get_state_response_header(shard_id, sync_hash).unwrap(); let state_root = state_header.chunk_prev_state_root(); @@ -261,7 +249,7 @@ fn apply_state_parts( (state_root, epoch.epoch_height(), Some(epoch_id), Some(sync_hash)) }; - let part_storage = get_state_part_reader(location, &chain_id, epoch_height, shard_id); + let part_storage = get_state_part_reader(location, chain_id, epoch_height, shard_id); let num_parts = part_storage.num_parts(); assert_ne!(num_parts, 0, "Too few num_parts: {}", num_parts); @@ -282,50 +270,38 @@ fn apply_state_parts( assert!(part_id < num_parts, "part_id: {}, num_parts: {}", part_id, num_parts); let part = part_storage.read(part_id, num_parts); - match action { - ApplyAction::Apply => { - chain - .set_state_part( - shard_id, - sync_hash.unwrap(), - PartId::new(part_id, num_parts), - &part, - ) - .unwrap(); - chain - .runtime_adapter - .apply_state_part( - shard_id, - &state_root, - PartId::new(part_id, num_parts), - &part, - epoch_id.as_ref().unwrap(), - ) - .unwrap(); - tracing::info!(target: "state-parts", part_id, part_length = part.len(), elapsed_sec = timer.elapsed().as_secs_f64(), "Applied a state part"); - } - ApplyAction::Validate => { - assert!(chain.runtime_adapter.validate_state_part( + if dry_run { + assert!(chain.runtime_adapter.validate_state_part( + &state_root, + PartId::new(part_id, num_parts), + &part + )); + tracing::info!(target: "state-parts", part_id, part_length = part.len(), elapsed_sec = timer.elapsed().as_secs_f64(), "Validated a state part"); + } else { + chain + .set_state_part( + shard_id, + sync_hash.unwrap(), + PartId::new(part_id, num_parts), + &part, + ) + .unwrap(); + chain + .runtime_adapter + .apply_state_part( + shard_id, &state_root, PartId::new(part_id, num_parts), - &part - )); - tracing::info!(target: "state-parts", part_id, part_length = part.len(), elapsed_sec = timer.elapsed().as_secs_f64(), "Validated a state part"); - } - ApplyAction::Print => { - print_state_part(&state_root, PartId::new(part_id, num_parts), &part) - } + &part, + epoch_id.as_ref().unwrap(), + ) + .unwrap(); + tracing::info!(target: "state-parts", part_id, part_length = part.len(), elapsed_sec = timer.elapsed().as_secs_f64(), "Applied a state part"); } } tracing::info!(target: "state-parts", total_elapsed_sec = timer.elapsed().as_secs_f64(), "Applied all requested state parts"); } -fn print_state_part(state_root: &StateRoot, _part_id: PartId, data: &[u8]) { - let trie_nodes: PartialState = BorshDeserialize::try_from_slice(data).unwrap(); - let trie = Trie::from_recorded_storage(PartialStorage { nodes: trie_nodes }, *state_root); - trie.print_recursive(&mut std::io::stdout().lock(), &state_root, u32::MAX); -} - fn dump_state_parts( epoch_selection: EpochSelection, shard_id: ShardId, @@ -336,10 +312,12 @@ fn dump_state_parts( store: Store, location: Location, ) { - let epoch_id = epoch_selection.to_epoch_id(store, &chain); + let epoch_id = epoch_selection.to_epoch_id(store, chain); let epoch = chain.runtime_adapter.get_epoch_info(&epoch_id).unwrap(); - let sync_hash = get_any_block_hash_of_epoch(&epoch, &chain); - let sync_hash = StateSync::get_epoch_start_sync_hash(&chain, &sync_hash).unwrap(); + let sync_hash = get_any_block_hash_of_epoch(&epoch, chain); + let sync_hash = StateSync::get_epoch_start_sync_hash(chain, &sync_hash).unwrap(); + let sync_block = chain.get_block_header(&sync_hash).unwrap(); + let sync_prev_hash = sync_block.prev_hash(); let state_header = chain.compute_state_response_header(shard_id, sync_hash).unwrap(); let state_root = state_header.chunk_prev_state_root(); @@ -366,7 +344,12 @@ fn dump_state_parts( assert!(part_id < num_parts, "part_id: {}, num_parts: {}", part_id, num_parts); let state_part = chain .runtime_adapter - .obtain_state_part(shard_id, &sync_hash, &state_root, PartId::new(part_id, num_parts)) + .obtain_state_part( + shard_id, + &sync_prev_hash, + &state_root, + PartId::new(part_id, num_parts), + ) .unwrap(); part_storage.write(&state_part, part_id, num_parts); let elapsed_sec = timer.elapsed().as_secs_f64(); @@ -376,7 +359,7 @@ fn dump_state_parts( part_id, part_length = state_part.len(), elapsed_sec, - first_state_record = ?first_state_record.map(|sr| format!("{}", sr)), + ?first_state_record, "Wrote a state part"); } tracing::info!(target: "state-parts", total_elapsed_sec = timer.elapsed().as_secs_f64(), "Wrote all requested state parts"); @@ -387,11 +370,9 @@ fn get_first_state_record(state_root: &StateRoot, data: &[u8]) -> Option, part_to: Option, num_parts: u64) -> part_from.unwrap_or(0)..part_to.unwrap_or(num_parts) } -// Needs to be in sync with `fn s3_location()`. -fn location_prefix(chain_id: &str, epoch_height: u64, shard_id: u64) -> String { - format!("chain_id={}/epoch_height={}/shard_id={}", chain_id, epoch_height, shard_id) -} - -fn match_filename(s: &str) -> Option { - let re = regex::Regex::new(r"^state_part_(\d{6})_of_(\d{6})$").unwrap(); - re.captures(s) -} - -fn is_part_filename(s: &str) -> bool { - match_filename(s).is_some() -} - -fn get_num_parts_from_filename(s: &str) -> Option { - if let Some(captures) = match_filename(s) { - if let Some(num_parts) = captures.get(2) { - if let Ok(num_parts) = num_parts.as_str().parse::() { - return Some(num_parts); - } - } - } - None -} - -fn part_filename(part_id: u64, num_parts: u64) -> String { - format!("state_part_{:06}_of_{:06}", part_id, num_parts) -} - trait StatePartWriter { fn write(&self, state_part: &[u8], part_id: u64, num_parts: u64); } @@ -511,7 +463,7 @@ impl FileSystemStorage { } fn get_location(&self, part_id: u64, num_parts: u64) -> PathBuf { - (&self.state_parts_dir).join(part_filename(part_id, num_parts)) + self.state_parts_dir.join(part_filename(part_id, num_parts)) } } @@ -527,8 +479,7 @@ impl StatePartReader for FileSystemStorage { fn read(&self, part_id: u64, num_parts: u64) -> Vec { let filename = self.get_location(part_id, num_parts); tracing::debug!(target: "state-parts", part_id, num_parts, ?filename, "Reading state part file"); - let part = std::fs::read(filename).unwrap(); - part + std::fs::read(filename).unwrap() } fn num_parts(&self) -> u64 { @@ -576,7 +527,7 @@ impl S3Storage { ) -> Self { let location = location_prefix(chain_id, epoch_height, shard_id); let bucket = s3::Bucket::new( - &s3_bucket, + s3_bucket, s3_region.parse::().unwrap(), s3::creds::Credentials::default().unwrap(), ) @@ -594,7 +545,7 @@ impl S3Storage { impl StatePartWriter for S3Storage { fn write(&self, state_part: &[u8], part_id: u64, num_parts: u64) { let location = self.get_location(part_id, num_parts); - self.bucket.put_object_blocking(&location, &state_part).unwrap(); + self.bucket.put_object_blocking(&location, state_part).unwrap(); tracing::info!(target: "state-parts", part_id, part_length = state_part.len(), ?location, "Wrote a state part to S3"); } }