diff --git a/core/store/src/trie/state_parts.rs b/core/store/src/trie/state_parts.rs index 186c7628532..fd3be3ecc59 100644 --- a/core/store/src/trie/state_parts.rs +++ b/core/store/src/trie/state_parts.rs @@ -347,14 +347,9 @@ impl Trie { Ok(key_nibbles) } - /// Validate state part - /// - /// # Panics - /// part_id must be in [0..num_parts) - /// - /// # Errors - /// StorageError::TrieNodeWithMissing if some nodes are missing - pub fn validate_trie_nodes_for_part( + /// Validates state part for given state root. + /// Returns error if state part is invalid and Ok otherwise. + pub fn validate_state_part( state_root: &StateRoot, part_id: PartId, partial_state: PartialState, @@ -368,9 +363,9 @@ impl Trie { let storage = trie.storage.as_partial_storage().unwrap(); if storage.visited_nodes.borrow().len() != num_nodes { - // TODO #1603 not actually TrieNodeMissing. - // The error is that the proof has more nodes than needed. - return Err(StorageError::TrieNodeMissing); + // As all nodes belonging to state part were visited, there is some + // unexpected data in downloaded state part. + return Err(StorageError::UnexpectedTrieValue); } Ok(()) } @@ -447,6 +442,7 @@ mod tests { use std::sync::Arc; use rand::prelude::ThreadRng; + use rand::seq::SliceRandom; use rand::Rng; use near_primitives::hash::{hash, CryptoHash}; @@ -660,7 +656,7 @@ mod tests { fn test_combine_empty_trie_parts() { let state_root = Trie::EMPTY_ROOT; let _ = Trie::combine_state_parts_naive(&state_root, &[]).unwrap(); - let _ = Trie::validate_trie_nodes_for_part( + let _ = Trie::validate_state_part( &state_root, PartId::new(0, 1), PartialState::TrieValues(vec![]), @@ -888,12 +884,14 @@ mod tests { assert_eq!(all_nodes.len(), trie_changes.insertions.len()); let size_of_all = all_nodes.iter().map(|node| node.len()).sum::(); let num_nodes = all_nodes.len(); - Trie::validate_trie_nodes_for_part( - trie.get_root(), - PartId::new(0, 1), - PartialState::TrieValues(all_nodes), - ) - .expect("validate ok"); + assert_eq!( + Trie::validate_state_part( + trie.get_root(), + PartId::new(0, 1), + PartialState::TrieValues(all_nodes), + ), + Ok(()) + ); let sum_of_sizes = sizes_vec.iter().sum::(); // Manually check that sizes are reasonable @@ -965,9 +963,78 @@ mod tests { trie_changes } + /// Checks that state part with unexpected data or not enough data doesn't + /// pass validation. + #[test] + fn invalid_state_parts() { + let tries = create_tries(); + let shard_uid = ShardUId::single_shard(); + let block_hash = CryptoHash::default(); + let part_id = PartId::new(1, 2); + let trie = tries.get_trie_for_shard(shard_uid, Trie::EMPTY_ROOT); + + let state_items = vec![ + (b"a".to_vec(), vec![1]), + (b"aa".to_vec(), vec![2]), + (b"ab".to_vec(), vec![3]), + (b"b".to_vec(), vec![4]), + (b"ba".to_vec(), vec![5]), + ]; + + let changes_for_trie = state_items.iter().cloned().map(|(k, v)| (k, Some(v))); + let trie_changes = trie.update(changes_for_trie).unwrap(); + let mut store_update = tries.store_update(); + let root = tries.apply_all(&trie_changes, shard_uid, &mut store_update); + store_update.commit().unwrap(); + + let trie = tries.get_view_trie_for_shard(shard_uid, root); + let PartialState::TrieValues(trie_values) = trie + .get_trie_nodes_for_part(&block_hash, part_id) + .expect("State part generation using Trie must work"); + let num_trie_values = trie_values.len(); + assert!(num_trie_values >= 2); + + // Check that shuffled state part also passes validation. + let mut rng = rand::thread_rng(); + for _ in 0..5 { + let mut trie_values_shuffled = trie_values.clone(); + trie_values_shuffled.shuffle(&mut rng); + let state_part = PartialState::TrieValues(trie_values_shuffled); + assert_eq!(Trie::validate_state_part(&root, part_id, state_part), Ok(())); + } + + // Remove middle element from state part, check that validation fails. + let mut trie_values_missing = trie_values.clone(); + trie_values_missing.remove(num_trie_values / 2); + let wrong_state_part = PartialState::TrieValues(trie_values_missing); + assert_eq!( + Trie::validate_state_part(&root, part_id, wrong_state_part), + Err(StorageError::MissingTrieValue) + ); + + // Add extra value to the state part, check that validation fails. + let mut trie_values_extra = trie_values.clone(); + trie_values_extra.push(vec![11].into()); + let wrong_state_part = PartialState::TrieValues(trie_values_extra); + assert_eq!( + Trie::validate_state_part(&root, part_id, wrong_state_part), + Err(StorageError::UnexpectedTrieValue) + ); + + // Duplicate a value in the state part, check that validation fails, because + // values in state part must be deduplicated. + let mut trie_values_extra_same = trie_values; + trie_values_extra_same + .push(trie_values_extra_same[trie_values_extra_same.len() / 2].clone()); + let wrong_state_part = PartialState::TrieValues(trie_values_extra_same); + assert_eq!( + Trie::validate_state_part(&root, part_id, wrong_state_part), + Err(StorageError::UnexpectedTrieValue) + ); + } + /// Check on random samples that state parts can be validated independently /// from the entire trie. - /// TODO (#8997): add custom tests where incorrect parts don't pass validation. #[test] fn test_get_trie_nodes_for_part() { let mut rng = rand::thread_rng(); @@ -993,16 +1060,19 @@ mod tests { PartId::new(part_id, num_parts), ) .unwrap(); - Trie::validate_trie_nodes_for_part( - trie.get_root(), - PartId::new(part_id, num_parts), - trie_nodes, - ) - .expect("validate ok"); + assert_eq!( + Trie::validate_state_part( + trie.get_root(), + PartId::new(part_id, num_parts), + trie_nodes, + ), + Ok(()) + ); } } } + /// Checks sanity of generating state part using flat storage. #[test] fn get_trie_nodes_for_part_with_flat_storage() { let value_len = 1000usize; @@ -1013,7 +1083,7 @@ mod tests { let part_id = PartId::new(1, 3); let trie = tries.get_trie_for_shard(shard_uid, Trie::EMPTY_ROOT); - // Corner case when trie is a single path from empty string to "aaaa". + // Trie with three big independent children. let state_items = vec![ (b"a".to_vec(), vec![1; value_len]), (b"aa".to_vec(), vec![2; value_len]), @@ -1043,7 +1113,7 @@ mod tests { let state_part = trie_without_flat .get_trie_nodes_for_part(&block_hash, part_id) .expect("State part generation using Trie must work"); - assert_eq!(Trie::validate_trie_nodes_for_part(&root, part_id, state_part.clone()), Ok(())); + assert_eq!(Trie::validate_state_part(&root, part_id, state_part.clone()), Ok(())); assert!(state_part.len() > 0); // Check that if we try to use flat storage but it is empty, state part @@ -1051,7 +1121,7 @@ mod tests { let trie = tries.get_trie_with_block_hash_for_shard(shard_uid, root, &block_hash, true); assert_eq!( trie.get_trie_nodes_for_part(&block_hash, part_id), - Err(StorageError::TrieNodeMissing) + Err(StorageError::MissingTrieValue) ); // Fill flat storage and check that state part creation succeeds. @@ -1080,7 +1150,7 @@ mod tests { assert_eq!( trie_without_flat.get_trie_nodes_for_part(&block_hash, part_id), - Err(StorageError::TrieNodeMissing) + Err(StorageError::MissingTrieValue) ); assert_eq!(trie_with_flat.get_trie_nodes_for_part(&block_hash, part_id), Ok(state_part)); @@ -1094,7 +1164,7 @@ mod tests { assert_eq!( trie_with_flat.get_trie_nodes_for_part(&block_hash, part_id), - Err(StorageError::TrieNodeMissing) + Err(StorageError::MissingTrieValue) ); } } diff --git a/core/store/src/trie/trie_storage.rs b/core/store/src/trie/trie_storage.rs index 48fb9a0e4d4..cd3f0e64277 100644 --- a/core/store/src/trie/trie_storage.rs +++ b/core/store/src/trie/trie_storage.rs @@ -341,7 +341,7 @@ pub struct TrieMemoryPartialStorage { impl TrieStorage for TrieMemoryPartialStorage { fn retrieve_raw_bytes(&self, hash: &CryptoHash) -> Result, StorageError> { - let result = self.recorded_storage.get(hash).cloned().ok_or(StorageError::TrieNodeMissing); + let result = self.recorded_storage.get(hash).cloned().ok_or(StorageError::MissingTrieValue); if result.is_ok() { self.visited_nodes.borrow_mut().insert(*hash); } @@ -639,7 +639,7 @@ fn read_node_from_db( let val = store .get(DBCol::State, key.as_ref()) .map_err(|_| StorageError::StorageInternalError)? - .ok_or_else(|| StorageError::TrieNodeMissing)?; + .ok_or_else(|| StorageError::MissingTrieValue)?; Ok(val.into()) } diff --git a/core/store/src/trie/trie_tests.rs b/core/store/src/trie/trie_tests.rs index 305684a4f8c..8ae1aef7f91 100644 --- a/core/store/src/trie/trie_tests.rs +++ b/core/store/src/trie/trie_tests.rs @@ -35,14 +35,14 @@ impl IncompletePartialStorage { impl TrieStorage for IncompletePartialStorage { fn retrieve_raw_bytes(&self, hash: &CryptoHash) -> Result, StorageError> { - let result = self.recorded_storage.get(hash).cloned().ok_or(StorageError::TrieNodeMissing); + let result = self.recorded_storage.get(hash).cloned().ok_or(StorageError::MissingTrieValue); if result.is_ok() { self.visited_nodes.borrow_mut().insert(*hash); } if self.visited_nodes.borrow().len() > self.node_count_to_fail_after { - Err(StorageError::TrieNodeMissing) + Err(StorageError::MissingTrieValue) } else { result } @@ -84,7 +84,7 @@ where flat_storage_chunk_view: None, }; let expected_result = - if i < size { Err(&StorageError::TrieNodeMissing) } else { Ok(&expected) }; + if i < size { Err(&StorageError::MissingTrieValue) } else { Ok(&expected) }; assert_eq!(test(new_trie).map(|v| v.1).as_ref(), expected_result); } println!("Success"); @@ -278,7 +278,7 @@ mod trie_storage_tests { let key = hash(&value); let result = trie_caching_storage.retrieve_raw_bytes(&key); - assert_matches!(result, Err(StorageError::TrieNodeMissing)); + assert_matches!(result, Err(StorageError::MissingTrieValue)); } /// Check that large values does not fall into shard cache, but fall into chunk cache. diff --git a/nearcore/src/runtime/mod.rs b/nearcore/src/runtime/mod.rs index f1356b14af3..a2f546de82b 100644 --- a/nearcore/src/runtime/mod.rs +++ b/nearcore/src/runtime/mod.rs @@ -1198,7 +1198,7 @@ impl RuntimeAdapter for NightshadeRuntime { fn validate_state_part(&self, state_root: &StateRoot, part_id: PartId, data: &[u8]) -> bool { match BorshDeserialize::try_from_slice(data) { Ok(trie_nodes) => { - match Trie::validate_trie_nodes_for_part(state_root, part_id, trie_nodes) { + match Trie::validate_state_part(state_root, part_id, trie_nodes) { Ok(_) => true, // Storage error should not happen Err(err) => { diff --git a/tools/state-viewer/src/state_dump.rs b/tools/state-viewer/src/state_dump.rs index 536498d9cfc..9e216c770d4 100644 --- a/tools/state-viewer/src/state_dump.rs +++ b/tools/state-viewer/src/state_dump.rs @@ -669,7 +669,7 @@ mod test { /// If the node does not track a shard, state dump will not give the correct result. #[test] - #[should_panic(expected = "TrieNodeMissing")] + #[should_panic(expected = "MissingTrieValue")] fn test_dump_state_not_track_shard() { let epoch_length = 4; let mut genesis =