Skip to content

Commit

Permalink
fix: tests for invalid state parts (near#9122)
Browse files Browse the repository at this point in the history
Check that invalid state parts don't pass validation. Add the missing `StorageError` kind for one of the test cases. We now check two scenarios:
* MissingTrieValue - should be returned if some trie value is missing during validation, e.g. some intermediate node wasn't sent;
* UnexpectedTrieValue - should be returned if some extra trie value was included in the state part. This check is necessary because otherwise a malicious actor could spam nodes with large amounts of non-existent trie nodes.
  • Loading branch information
Longarithm authored and nikurt committed Jun 8, 2023
1 parent 57a13b2 commit b1ea5e7
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 38 deletions.
130 changes: 100 additions & 30 deletions core/store/src/trie/state_parts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -347,14 +347,9 @@ impl Trie {
Ok(key_nibbles)
}

/// Validate state part
///
/// # Panics
/// part_id must be in [0..num_parts)
///
/// # Errors
/// StorageError::TrieNodeWithMissing if some nodes are missing
pub fn validate_trie_nodes_for_part(
/// Validates state part for given state root.
/// Returns error if state part is invalid and Ok otherwise.
pub fn validate_state_part(
state_root: &StateRoot,
part_id: PartId,
partial_state: PartialState,
Expand All @@ -368,9 +363,9 @@ impl Trie {
let storage = trie.storage.as_partial_storage().unwrap();

if storage.visited_nodes.borrow().len() != num_nodes {
// TODO #1603 not actually TrieNodeMissing.
// The error is that the proof has more nodes than needed.
return Err(StorageError::TrieNodeMissing);
// As all nodes belonging to state part were visited, there is some
// unexpected data in downloaded state part.
return Err(StorageError::UnexpectedTrieValue);
}
Ok(())
}
Expand Down Expand Up @@ -447,6 +442,7 @@ mod tests {
use std::sync::Arc;

use rand::prelude::ThreadRng;
use rand::seq::SliceRandom;
use rand::Rng;

use near_primitives::hash::{hash, CryptoHash};
Expand Down Expand Up @@ -660,7 +656,7 @@ mod tests {
fn test_combine_empty_trie_parts() {
let state_root = Trie::EMPTY_ROOT;
let _ = Trie::combine_state_parts_naive(&state_root, &[]).unwrap();
let _ = Trie::validate_trie_nodes_for_part(
let _ = Trie::validate_state_part(
&state_root,
PartId::new(0, 1),
PartialState::TrieValues(vec![]),
Expand Down Expand Up @@ -888,12 +884,14 @@ mod tests {
assert_eq!(all_nodes.len(), trie_changes.insertions.len());
let size_of_all = all_nodes.iter().map(|node| node.len()).sum::<usize>();
let num_nodes = all_nodes.len();
Trie::validate_trie_nodes_for_part(
trie.get_root(),
PartId::new(0, 1),
PartialState::TrieValues(all_nodes),
)
.expect("validate ok");
assert_eq!(
Trie::validate_state_part(
trie.get_root(),
PartId::new(0, 1),
PartialState::TrieValues(all_nodes),
),
Ok(())
);

let sum_of_sizes = sizes_vec.iter().sum::<usize>();
// Manually check that sizes are reasonable
Expand Down Expand Up @@ -965,9 +963,78 @@ mod tests {
trie_changes
}

/// Verifies that state part validation rejects parts which lack a required
/// trie value or carry superfluous ones, while still accepting a correct
/// part regardless of the order of its values.
#[test]
fn invalid_state_parts() {
    let tries = create_tries();
    let shard_uid = ShardUId::single_shard();
    let block_hash = CryptoHash::default();
    let part_id = PartId::new(1, 2);
    let trie = tries.get_trie_for_shard(shard_uid, Trie::EMPTY_ROOT);

    // Populate the empty trie with a handful of short keys and commit it.
    let state_items = vec![
        (b"a".to_vec(), vec![1]),
        (b"aa".to_vec(), vec![2]),
        (b"ab".to_vec(), vec![3]),
        (b"b".to_vec(), vec![4]),
        (b"ba".to_vec(), vec![5]),
    ];
    let trie_changes = trie
        .update(state_items.iter().cloned().map(|(key, value)| (key, Some(value))))
        .unwrap();
    let mut store_update = tries.store_update();
    let root = tries.apply_all(&trie_changes, shard_uid, &mut store_update);
    store_update.commit().unwrap();

    // Generate the reference state part which all corrupted variants below
    // are derived from.
    let trie = tries.get_view_trie_for_shard(shard_uid, root);
    let PartialState::TrieValues(trie_values) = trie
        .get_trie_nodes_for_part(&block_hash, part_id)
        .expect("State part generation using Trie must work");
    let num_trie_values = trie_values.len();
    assert!(num_trie_values >= 2);

    // Every scenario validates against the same root and part id, so wrap
    // the call once and only vary the list of trie values.
    let validate =
        |values| Trie::validate_state_part(&root, part_id, PartialState::TrieValues(values));

    // Order of values must not matter: any permutation of a valid part is valid.
    let mut rng = rand::thread_rng();
    for _ in 0..5 {
        let mut shuffled = trie_values.clone();
        shuffled.shuffle(&mut rng);
        assert_eq!(validate(shuffled), Ok(()));
    }

    // Dropping the middle value leaves the part incomplete.
    let mut incomplete = trie_values.clone();
    incomplete.remove(num_trie_values / 2);
    assert_eq!(validate(incomplete), Err(StorageError::MissingTrieValue));

    // Appending a value which does not belong to the part must be rejected.
    let mut with_foreign_value = trie_values.clone();
    with_foreign_value.push(vec![11].into());
    assert_eq!(validate(with_foreign_value), Err(StorageError::UnexpectedTrieValue));

    // Appending a copy of an existing value must be rejected as well, because
    // values in a state part must be deduplicated.
    let mut with_duplicate = trie_values;
    let duplicate = with_duplicate[with_duplicate.len() / 2].clone();
    with_duplicate.push(duplicate);
    assert_eq!(validate(with_duplicate), Err(StorageError::UnexpectedTrieValue));
}

/// Check on random samples that state parts can be validated independently
/// from the entire trie.
/// TODO (#8997): add custom tests where incorrect parts don't pass validation.
#[test]
fn test_get_trie_nodes_for_part() {
let mut rng = rand::thread_rng();
Expand All @@ -993,16 +1060,19 @@ mod tests {
PartId::new(part_id, num_parts),
)
.unwrap();
Trie::validate_trie_nodes_for_part(
trie.get_root(),
PartId::new(part_id, num_parts),
trie_nodes,
)
.expect("validate ok");
assert_eq!(
Trie::validate_state_part(
trie.get_root(),
PartId::new(part_id, num_parts),
trie_nodes,
),
Ok(())
);
}
}
}

/// Checks sanity of generating state part using flat storage.
#[test]
fn get_trie_nodes_for_part_with_flat_storage() {
let value_len = 1000usize;
Expand All @@ -1013,7 +1083,7 @@ mod tests {
let part_id = PartId::new(1, 3);
let trie = tries.get_trie_for_shard(shard_uid, Trie::EMPTY_ROOT);

// Corner case when trie is a single path from empty string to "aaaa".
// Trie with three big independent children.
let state_items = vec![
(b"a".to_vec(), vec![1; value_len]),
(b"aa".to_vec(), vec![2; value_len]),
Expand Down Expand Up @@ -1043,15 +1113,15 @@ mod tests {
let state_part = trie_without_flat
.get_trie_nodes_for_part(&block_hash, part_id)
.expect("State part generation using Trie must work");
assert_eq!(Trie::validate_trie_nodes_for_part(&root, part_id, state_part.clone()), Ok(()));
assert_eq!(Trie::validate_state_part(&root, part_id, state_part.clone()), Ok(()));
assert!(state_part.len() > 0);

// Check that if we try to use flat storage but it is empty, state part
// creation fails.
let trie = tries.get_trie_with_block_hash_for_shard(shard_uid, root, &block_hash, true);
assert_eq!(
trie.get_trie_nodes_for_part(&block_hash, part_id),
Err(StorageError::TrieNodeMissing)
Err(StorageError::MissingTrieValue)
);

// Fill flat storage and check that state part creation succeeds.
Expand Down Expand Up @@ -1080,7 +1150,7 @@ mod tests {

assert_eq!(
trie_without_flat.get_trie_nodes_for_part(&block_hash, part_id),
Err(StorageError::TrieNodeMissing)
Err(StorageError::MissingTrieValue)
);
assert_eq!(trie_with_flat.get_trie_nodes_for_part(&block_hash, part_id), Ok(state_part));

Expand All @@ -1094,7 +1164,7 @@ mod tests {

assert_eq!(
trie_with_flat.get_trie_nodes_for_part(&block_hash, part_id),
Err(StorageError::TrieNodeMissing)
Err(StorageError::MissingTrieValue)
);
}
}
4 changes: 2 additions & 2 deletions core/store/src/trie/trie_storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ pub struct TrieMemoryPartialStorage {

impl TrieStorage for TrieMemoryPartialStorage {
fn retrieve_raw_bytes(&self, hash: &CryptoHash) -> Result<Arc<[u8]>, StorageError> {
let result = self.recorded_storage.get(hash).cloned().ok_or(StorageError::TrieNodeMissing);
let result = self.recorded_storage.get(hash).cloned().ok_or(StorageError::MissingTrieValue);
if result.is_ok() {
self.visited_nodes.borrow_mut().insert(*hash);
}
Expand Down Expand Up @@ -639,7 +639,7 @@ fn read_node_from_db(
let val = store
.get(DBCol::State, key.as_ref())
.map_err(|_| StorageError::StorageInternalError)?
.ok_or_else(|| StorageError::TrieNodeMissing)?;
.ok_or_else(|| StorageError::MissingTrieValue)?;
Ok(val.into())
}

Expand Down
8 changes: 4 additions & 4 deletions core/store/src/trie/trie_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,14 @@ impl IncompletePartialStorage {

impl TrieStorage for IncompletePartialStorage {
fn retrieve_raw_bytes(&self, hash: &CryptoHash) -> Result<Arc<[u8]>, StorageError> {
let result = self.recorded_storage.get(hash).cloned().ok_or(StorageError::TrieNodeMissing);
let result = self.recorded_storage.get(hash).cloned().ok_or(StorageError::MissingTrieValue);

if result.is_ok() {
self.visited_nodes.borrow_mut().insert(*hash);
}

if self.visited_nodes.borrow().len() > self.node_count_to_fail_after {
Err(StorageError::TrieNodeMissing)
Err(StorageError::MissingTrieValue)
} else {
result
}
Expand Down Expand Up @@ -84,7 +84,7 @@ where
flat_storage_chunk_view: None,
};
let expected_result =
if i < size { Err(&StorageError::TrieNodeMissing) } else { Ok(&expected) };
if i < size { Err(&StorageError::MissingTrieValue) } else { Ok(&expected) };
assert_eq!(test(new_trie).map(|v| v.1).as_ref(), expected_result);
}
println!("Success");
Expand Down Expand Up @@ -278,7 +278,7 @@ mod trie_storage_tests {
let key = hash(&value);

let result = trie_caching_storage.retrieve_raw_bytes(&key);
assert_matches!(result, Err(StorageError::TrieNodeMissing));
assert_matches!(result, Err(StorageError::MissingTrieValue));
}

/// Check that large values do not fall into the shard cache, but do fall into the chunk cache.
Expand Down
2 changes: 1 addition & 1 deletion nearcore/src/runtime/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1198,7 +1198,7 @@ impl RuntimeAdapter for NightshadeRuntime {
fn validate_state_part(&self, state_root: &StateRoot, part_id: PartId, data: &[u8]) -> bool {
match BorshDeserialize::try_from_slice(data) {
Ok(trie_nodes) => {
match Trie::validate_trie_nodes_for_part(state_root, part_id, trie_nodes) {
match Trie::validate_state_part(state_root, part_id, trie_nodes) {
Ok(_) => true,
// Storage error should not happen
Err(err) => {
Expand Down
2 changes: 1 addition & 1 deletion tools/state-viewer/src/state_dump.rs
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,7 @@ mod test {

/// If the node does not track a shard, state dump will not give the correct result.
#[test]
#[should_panic(expected = "TrieNodeMissing")]
#[should_panic(expected = "MissingTrieValue")]
fn test_dump_state_not_track_shard() {
let epoch_length = 4;
let mut genesis =
Expand Down

0 comments on commit b1ea5e7

Please sign in to comment.