diff --git a/src/hash_builder/mod.rs b/src/hash_builder/mod.rs index 8ddd860..a461272 100644 --- a/src/hash_builder/mod.rs +++ b/src/hash_builder/mod.rs @@ -200,8 +200,18 @@ impl HashBuilder { let preceding_len = self.groups.len().saturating_sub(1); let common_prefix_len = succeeding.common_prefix_length(current.as_slice()); - let len = cmp::max(preceding_len, common_prefix_len); - assert!(len < current.len(), "len {} current.len {}", len, current.len()); + + // guard against accessing beyond the current key length + if common_prefix_len >= current.len() { + return; + } + + // calculate len ensuring we don't exceed current key boundaries + let len = if current.is_empty() { + 0 + } else { + cmp::min(cmp::max(preceding_len, common_prefix_len), current.len() - 1) + }; trace!( target: "trie::hash_builder", @@ -231,10 +241,38 @@ impl HashBuilder { self.resize_masks(current.len()); } - let mut len_from = len; - if !succeeding.is_empty() || preceding_exists { - len_from += 1; + // handle intermediate nodes if needed, when current path is more + // than 1 nibble longer than common prefix + let needs_update = len + 1 < current.len() - 1; + if needs_update { + for intermediate_len in (len + 1)..current.len() { + let intermediate_path = current.slice(..intermediate_len); + + if self + .updated_branch_nodes + .as_ref() + .map_or(true, |nodes| !nodes.contains_key(&intermediate_path)) + { + let next_nibble = current[intermediate_len]; + let state_mask = TrieMask::from_nibble(next_nibble); + let tree_mask = + if self.stored_in_database { state_mask } else { TrieMask::default() }; + + let node = BranchNodeCompact::new( + state_mask, + tree_mask, + TrieMask::default(), + Vec::new(), + None, + ); + + if let Some(ref mut nodes) = self.updated_branch_nodes { + nodes.insert(intermediate_path, node); + } + } + } } + let len_from = if !succeeding.is_empty() || preceding_exists { len + 1 } else { len }; trace!(target: "trie::hash_builder", "skipping {len_from} nibbles"); // The key without the common prefix @@ -372,7 +410,18 @@ impl HashBuilder { self.hash_masks[parent_index] |= TrieMask::from_nibble(current[parent_index]); } - let store_in_db_trie = !self.tree_masks[len].is_empty() || !self.hash_masks[len].is_empty(); + // first set the state mask - must contain all child positions + let my_state_mask = self.groups[len]; + + // then determine tree mask - must be subset of state mask + let my_tree_mask = if self.stored_in_database { + // only set tree mask bits for positions that exist in state mask + my_state_mask & self.tree_masks[len] + } else { + TrieMask::default() + }; + + let store_in_db_trie = !my_state_mask.is_empty() || self.stored_in_database; if store_in_db_trie { if len > 0 { let parent_index = len - 1; @@ -382,8 +431,8 @@ impl HashBuilder { if self.updated_branch_nodes.is_some() { let common_prefix = current.slice(..len); let node = BranchNodeCompact::new( - self.groups[len], - self.tree_masks[len], + my_state_mask, + my_tree_mask, self.hash_masks[len], children, (len == 0).then(|| self.current_root()), @@ -611,4 +660,276 @@ mod tests { assert_eq!(hb.root(), expected); assert_eq!(hb2.root(), expected); } + + #[test] + fn test_intermediate_nodes_creation() { + let mut hb = HashBuilder::default().with_updates(true); + + let deep_key = Nibbles::from_nibbles_unchecked(vec![0, 1, 2, 3]); + + let value = alloy_rlp::encode(U256::from(42)).to_vec(); + hb.add_leaf(deep_key.clone(), &value); + + // calculate root to ensure all nodes are created + let _root = hb.root(); + + let (_, updates) = hb.split(); + + let expected_paths = [ + Nibbles::from_nibbles_unchecked(vec![0]), + Nibbles::from_nibbles_unchecked(vec![0, 1]), + Nibbles::from_nibbles_unchecked(vec![0, 1, 2]), + ]; + + for path in &expected_paths { + assert!( + updates.contains_key(path), + "Missing node at path: {:?}\nAvailable paths: {:?}", + path, + updates.keys().collect::>() + ); + } + + let first_node = updates.get(&expected_paths[0]).unwrap(); + let expected_state_mask = 0b0000_0000_0000_0010; + assert_eq!( + first_node.state_mask.get(), + expected_state_mask, + "First level node state mask mismatch. Expected: {:016b}, Got: {:016b}", + expected_state_mask, + first_node.state_mask.get() + ); + assert!(first_node.tree_mask.is_subset_of(first_node.state_mask)); + + let second_node = updates.get(&expected_paths[1]).unwrap(); + let expected_state_mask = 0b0000_0000_0000_0100; + assert_eq!( + second_node.state_mask.get(), + expected_state_mask, + "Second level node state mask mismatch. Expected: {:016b}, Got: {:016b}", + expected_state_mask, + second_node.state_mask.get() + ); + assert!(second_node.tree_mask.is_subset_of(second_node.state_mask)); + } + + #[test] + fn test_tree_mask_is_subset_of_state_mask() { + let mut hb = HashBuilder::default().with_updates(true); + + let key1 = Nibbles::from_nibbles_unchecked(vec![1, 2, 3]); + let key2 = Nibbles::from_nibbles_unchecked(vec![1, 2, 4]); + + let value = alloy_rlp::encode(U256::from(42)).to_vec(); + hb.add_leaf(key1, &value); + hb.add_leaf(key2, &value); + + // calculate root to ensure all nodes are created + let _root = hb.root(); + + let (_, updates) = hb.split(); + + for (_, node) in updates.iter() { + assert!( + node.tree_mask.is_subset_of(node.state_mask), + "Tree mask {:016b} is not a subset of state mask {:016b}", + node.tree_mask.get(), + node.state_mask.get() + ); + } + } + + #[test] + fn test_stored_in_database_flag_impact() { + let mut hb = HashBuilder::default().with_updates(true); + let key1 = Nibbles::from_nibbles_unchecked(vec![0, 1]); + let value = B256::with_last_byte(1); + + hb.stored_in_database = true; + hb.add_branch(key1.clone(), value, true); + + // calculate root to ensure all nodes are created + let _root = hb.root(); + let (_, updates1) = hb.split(); + + if let Some(node) = updates1.get(&key1) { + assert!( + node.tree_mask.get() > 0, + "Tree mask should be non-zero for stored node at {:?}", + key1 + ); + assert!( + node.tree_mask.is_subset_of(node.state_mask), + "Tree mask should be subset of state mask for path: {:?}\n\ + state_mask: {:016b}\n\ + tree_mask: {:016b}", + key1, + node.state_mask.get(), + node.tree_mask.get() + ); + } + + let mut hb = HashBuilder::default().with_updates(true); + hb.stored_in_database = false; + hb.add_branch(key1.clone(), value, false); + + // calculate root to ensure all nodes are created + let _root = hb.root(); + let (_, updates2) = hb.split(); + + if let Some(node) = updates2.get(&key1) { + assert_eq!( + node.tree_mask.get(), + 0, + "Tree mask should be zero for non-stored node at {:?}", + key1 + ); + assert!( + node.tree_mask.is_subset_of(node.state_mask), + "Tree mask should be subset of state mask for path: {:?}\n\ + state_mask: {:016b}\n\ + tree_mask: {:016b}", + key1, + node.state_mask.get(), + node.tree_mask.get() + ); + } + } + + #[test] + fn test_complex_trie_structure() { + let mut hb = HashBuilder::default().with_updates(true); + + let data = BTreeMap::from([ + ( + hex!("1234000000000000000000000000000000000000000000000000000000000000").to_vec(), + Vec::new(), + ), + ( + hex!("1235000000000000000000000000000000000000000000000000000000000000").to_vec(), + Vec::new(), + ), + ( + hex!("1245000000000000000000000000000000000000000000000000000000000000").to_vec(), + Vec::new(), + ), + ( + hex!("1345000000000000000000000000000000000000000000000000000000000000").to_vec(), + Vec::new(), + ), + ]); + + // Add each value to the trie + for (key, val) in &data { + let nibbles = Nibbles::unpack(key); + hb.add_leaf(nibbles, val.as_slice()); + } + + let root = hb.root(); + let (_, updates) = hb.split(); + + let expected_paths = [ + (vec![0x1], 0b0000_0000_0000_1100), + (vec![0x1, 0x2], 0b0000_0000_0001_1000), + (vec![0x1, 0x3], 0b0000_0000_0001_0000), + (vec![0x1, 0x2, 0x3], 0b0000_0000_0011_0000), + (vec![0x1, 0x2, 0x4], 0b0000_0000_0010_0000), + (vec![0x1, 0x3, 0x4], 0b0000_0000_0010_0000), + ]; + + for (path_nibbles, expected_mask) in expected_paths { + let path = Nibbles::from_nibbles_unchecked(path_nibbles); + let node = updates + .get(&path) + .unwrap_or_else(|| panic!("Missing expected node at path: {:?}", path)); + assert_eq!( + node.state_mask.get(), + expected_mask, + "Wrong state mask for path {:?}. Expected {:016b}, got {:016b}", + path, + expected_mask, + node.state_mask.get() + ); + assert!(node.tree_mask.is_subset_of(node.state_mask)); + } + + for (key, node) in updates.iter() { + assert!( + node.tree_mask.is_subset_of(node.state_mask), + "Tree mask not subset of state mask for key {:?}", + key + ); + + if key.len() < 3 { + assert!( + node.state_mask.get() > 0, + "State mask should not be empty for non-leaf node at {:?}", + key + ); + } + } + + assert_eq!(root, triehash_trie_root(data)); + } + + fn verify_tree_mask_invariant(nodes: &HashMap) -> bool { + for (path, branch_node) in nodes { + let child_paths = (0..16) + .filter(|&pos| (branch_node.tree_mask.get() & (1 << pos)) != 0) + .map(|pos| { + let mut child_path = path.clone(); + child_path.push(pos as u8); + child_path + }) + .collect::>(); + + for child_path in child_paths { + if !nodes.contains_key(&child_path) { + return false; + } + } + } + true + } + + #[test] + fn test_regression_reth_issue_12129() { + let mut hb = HashBuilder::default().with_updates(true); + + let path1 = Nibbles::from_nibbles_unchecked(hex!("0b")); + let path2 = Nibbles::from_nibbles_unchecked(hex!("0b0605")); + let path3 = Nibbles::from_nibbles_unchecked(hex!("0b060502")); + + hb.add_leaf(path1.clone(), &[1]); + hb.add_leaf(path2.clone(), &[2]); + hb.add_leaf(path3.clone(), &[3]); + + // calculate root to ensure all nodes are created + let _root = hb.root(); + let (_, updates) = hb.split(); + + let intermediate_paths = [ + Nibbles::from_nibbles_unchecked(hex!("0b")), + Nibbles::from_nibbles_unchecked(hex!("0b06")), + Nibbles::from_nibbles_unchecked(hex!("0b0605")), + ]; + + for path in intermediate_paths.iter() { + assert!(updates.contains_key(path), "Missing intermediate node at path: {:?}", path); + } + + for (path, node) in updates.iter() { + assert!( + node.tree_mask.is_subset_of(node.state_mask), + "Tree mask not subset of state mask for path: {:?}\n\ + state_mask: {:016b}\n\ + tree_mask: {:016b}", + path, + node.state_mask.get(), + node.tree_mask.get() + ); + } + + assert!(verify_tree_mask_invariant(&updates)); + } }