diff --git a/crates/trie/db/tests/witness.rs b/crates/trie/db/tests/witness.rs index 59656383d..20f8cfbb9 100644 --- a/crates/trie/db/tests/witness.rs +++ b/crates/trie/db/tests/witness.rs @@ -6,6 +6,8 @@ use alloy_primitives::{ Address, Bytes, B256, U256, }; use alloy_rlp::EMPTY_STRING_CODE; +use reth_db::{cursor::DbCursorRW, tables}; +use reth_db_api::transaction::DbTxMut; use reth_primitives::{constants::EMPTY_ROOT_HASH, Account, StorageEntry}; use reth_provider::{test_utils::create_test_provider_factory, HashingWriter}; use reth_trie::{proof::Proof, witness::TrieWitness, HashedPostState, HashedStorage, StateRoot}; @@ -91,3 +93,53 @@ fn includes_nodes_for_destroyed_storage_nodes() { assert_eq!(witness.get(&keccak256(node)), Some(node)); } } + +#[test] +fn correctly_decodes_branch_node_values() { + let factory = create_test_provider_factory(); + let provider = factory.provider_rw().unwrap(); + + let address = Address::random(); + let hashed_address = keccak256(address); + let hashed_slot1 = B256::with_last_byte(1); + let hashed_slot2 = B256::with_last_byte(2); + + // Insert account and slots into database + provider.insert_account_for_hashing([(address, Some(Account::default()))]).unwrap(); + let mut hashed_storage_cursor = + provider.tx_ref().cursor_dup_write::().unwrap(); + hashed_storage_cursor + .upsert(hashed_address, StorageEntry { key: hashed_slot1, value: U256::from(1) }) + .unwrap(); + hashed_storage_cursor + .upsert(hashed_address, StorageEntry { key: hashed_slot2, value: U256::from(1) }) + .unwrap(); + + let state_root = StateRoot::from_tx(provider.tx_ref()).root().unwrap(); + let multiproof = Proof::from_tx(provider.tx_ref()) + .multiproof(HashMap::from_iter([( + hashed_address, + HashSet::from_iter([hashed_slot1, hashed_slot2]), + )])) + .unwrap(); + + let witness = TrieWitness::from_tx(provider.tx_ref()) + .compute(HashedPostState { + accounts: HashMap::from([(hashed_address, Some(Account::default()))]), + storages: HashMap::from([( + hashed_address, + HashedStorage::from_iter( + false, + [hashed_slot1, hashed_slot2].map(|hashed_slot| (hashed_slot, U256::from(2))), + ), + )]), + }) + .unwrap(); + assert!(witness.contains_key(&state_root)); + for node in multiproof.account_subtree.values() { + assert_eq!(witness.get(&keccak256(node)), Some(node)); + } + for node in multiproof.storages.iter().flat_map(|(_, storage)| storage.subtree.values()) { + assert_eq!(witness.get(&keccak256(node)), Some(node)); + } +} diff --git a/crates/trie/trie/src/witness.rs b/crates/trie/trie/src/witness.rs index c042a0d82..3238047c7 100644 --- a/crates/trie/trie/src/witness.rs +++ b/crates/trie/trie/src/witness.rs @@ -218,9 +218,14 @@ where TrieNode::Branch(branch) => { next_path.push(key[path.len()]); let children = branch_node_children(path.clone(), &branch); - for (child_path, node_hash) in children { + for (child_path, value) in children { if !key.starts_with(&child_path) { - trie_nodes.insert(child_path, Either::Left(node_hash)); + let value = if value.len() < B256::len_bytes() { + Either::Right(value.to_vec()) + } else { + Either::Left(B256::from_slice(&value[1..])) + }; + trie_nodes.insert(child_path, value); } } } @@ -311,8 +316,13 @@ where match TrieNode::decode(&mut &node[..])? { TrieNode::Branch(branch) => { let children = branch_node_children(path, &branch); - for (child_path, branch_hash) in children { - hash_builder.add_branch(child_path, branch_hash, false); + for (child_path, value) in children { + if value.len() < B256::len_bytes() { + hash_builder.add_leaf(child_path, value); + } else { + let hash = B256::from_slice(&value[1..]); + hash_builder.add_branch(child_path, hash, false); + } } break } @@ -342,14 +352,14 @@ where } /// Returned branch node children with keys in order. -fn branch_node_children(prefix: Nibbles, node: &BranchNode) -> Vec<(Nibbles, B256)> { +fn branch_node_children(prefix: Nibbles, node: &BranchNode) -> Vec<(Nibbles, &[u8])> { let mut children = Vec::with_capacity(node.state_mask.count_ones() as usize); let mut stack_ptr = node.as_ref().first_child_index(); for index in CHILD_INDEX_RANGE { if node.state_mask.is_bit_set(index) { let mut child_path = prefix.clone(); child_path.push(index); - children.push((child_path, B256::from_slice(&node.stack[stack_ptr][1..]))); + children.push((child_path, &node.stack[stack_ptr][..])); stack_ptr += 1; } }