Skip to content

Commit

Permalink
fix: create intermediate nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
fgimenez committed Oct 28, 2024
1 parent cf984af commit cc94266
Showing 1 changed file with 339 additions and 8 deletions.
347 changes: 339 additions & 8 deletions src/hash_builder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,18 @@ impl HashBuilder {
let preceding_len = self.groups.len().saturating_sub(1);

let common_prefix_len = succeeding.common_prefix_length(current.as_slice());
let len = cmp::max(preceding_len, common_prefix_len);
assert!(len < current.len(), "len {} current.len {}", len, current.len());

// guard against accessing beyond the current key length
if common_prefix_len >= current.len() {
return;
}

// calculate len ensuring we don't exceed current key boundaries
let len = if current.is_empty() {
0
} else {
cmp::min(cmp::max(preceding_len, common_prefix_len), current.len() - 1)
};

trace!(
target: "trie::hash_builder",
Expand Down Expand Up @@ -231,10 +241,38 @@ impl HashBuilder {
self.resize_masks(current.len());
}

let mut len_from = len;
if !succeeding.is_empty() || preceding_exists {
len_from += 1;
// handle intermediate nodes if needed, when current path is more
// than 1 nibble longer than common prefix
let needs_update = len + 1 < current.len() - 1;
if needs_update {
for intermediate_len in (len + 1)..current.len() {
let intermediate_path = current.slice(..intermediate_len);

if self
.updated_branch_nodes
.as_ref()
.map_or(true, |nodes| !nodes.contains_key(&intermediate_path))
{
let next_nibble = current[intermediate_len];
let state_mask = TrieMask::from_nibble(next_nibble);
let tree_mask =
if self.stored_in_database { state_mask } else { TrieMask::default() };

let node = BranchNodeCompact::new(
state_mask,
tree_mask,
TrieMask::default(),
Vec::new(),
None,
);

if let Some(ref mut nodes) = self.updated_branch_nodes {
nodes.insert(intermediate_path, node);
}
}
}
}
let len_from = if !succeeding.is_empty() || preceding_exists { len + 1 } else { len };
trace!(target: "trie::hash_builder", "skipping {len_from} nibbles");

// The key without the common prefix
Expand Down Expand Up @@ -372,7 +410,18 @@ impl HashBuilder {
self.hash_masks[parent_index] |= TrieMask::from_nibble(current[parent_index]);
}

let store_in_db_trie = !self.tree_masks[len].is_empty() || !self.hash_masks[len].is_empty();
// first set the state mask - must contain all child positions
let my_state_mask = self.groups[len];

// then determine tree mask - must be subset of state mask
let my_tree_mask = if self.stored_in_database {
// only set tree mask bits for positions that exist in state mask
my_state_mask & self.tree_masks[len]
} else {
TrieMask::default()
};

let store_in_db_trie = !my_state_mask.is_empty() || self.stored_in_database;
if store_in_db_trie {
if len > 0 {
let parent_index = len - 1;
Expand All @@ -382,8 +431,8 @@ impl HashBuilder {
if self.updated_branch_nodes.is_some() {
let common_prefix = current.slice(..len);
let node = BranchNodeCompact::new(
self.groups[len],
self.tree_masks[len],
my_state_mask,
my_tree_mask,
self.hash_masks[len],
children,
(len == 0).then(|| self.current_root()),
Expand Down Expand Up @@ -432,6 +481,7 @@ mod tests {
use alloc::collections::BTreeMap;
use alloy_primitives::{b256, hex, U256};
use alloy_rlp::Encodable;
use std::collections::HashSet;

// Hashes the keys, RLP encodes the values, compares the trie builder with the upstream root.
fn assert_hashed_trie_root<'a, I, K>(iter: I)
Expand Down Expand Up @@ -611,4 +661,285 @@ mod tests {
assert_eq!(hb.root(), expected);
assert_eq!(hb2.root(), expected);
}

/// Adds a single leaf four nibbles deep and checks that the collected updates contain a
/// branch node at every proper prefix of the leaf path, with the expected state masks on the
/// first two levels.
#[test]
fn test_intermediate_nodes_creation() {
    let mut hb = HashBuilder::default().with_updates(true);

    let deep_key = Nibbles::from_nibbles_unchecked(vec![0, 1, 2, 3]);
    let value = alloy_rlp::encode(&U256::from(42)).to_vec();

    // The key is not read again after insertion, so it is moved rather than cloned.
    hb.add_leaf(deep_key, &value);

    // calculate root to ensure all nodes are created
    let _root = hb.root();

    let (_, updates) = hb.split();

    // Every proper prefix of the leaf path is expected to appear in the updates.
    let expected_paths = [
        Nibbles::from_nibbles_unchecked(vec![0]),
        Nibbles::from_nibbles_unchecked(vec![0, 1]),
        Nibbles::from_nibbles_unchecked(vec![0, 1, 2]),
    ];

    for path in &expected_paths {
        assert!(
            updates.contains_key(path),
            "Missing node at path: {:?}\nAvailable paths: {:?}",
            path,
            updates.keys().collect::<Vec<_>>()
        );
    }

    // Node at [0]: only the bit for child nibble 1 is expected in the state mask.
    let first_node = updates.get(&expected_paths[0]).unwrap();
    let expected_state_mask = 0b0000_0000_0000_0010;
    assert_eq!(
        first_node.state_mask.get(),
        expected_state_mask,
        "First level node state mask mismatch. Expected: {:016b}, Got: {:016b}",
        expected_state_mask,
        first_node.state_mask.get()
    );
    assert!(first_node.tree_mask.is_subset_of(first_node.state_mask));

    // Node at [0, 1]: only the bit for child nibble 2 is expected in the state mask.
    let second_node = updates.get(&expected_paths[1]).unwrap();
    let expected_state_mask = 0b0000_0000_0000_0100;
    assert_eq!(
        second_node.state_mask.get(),
        expected_state_mask,
        "Second level node state mask mismatch. Expected: {:016b}, Got: {:016b}",
        expected_state_mask,
        second_node.state_mask.get()
    );
    assert!(second_node.tree_mask.is_subset_of(second_node.state_mask));
}

/// Inserts two leaves that differ only in their last nibble and checks that every branch
/// node in the collected updates has a tree mask contained in its state mask.
#[test]
fn test_tree_mask_is_subset_of_state_mask() {
    let encoded = alloy_rlp::encode(&U256::from(42)).to_vec();

    let mut builder = HashBuilder::default().with_updates(true);
    // Sibling leaves [1, 2, 3] and [1, 2, 4], added in ascending key order.
    for last_nibble in [3, 4] {
        builder.add_leaf(Nibbles::from_nibbles_unchecked(vec![1, 2, last_nibble]), &encoded);
    }

    // The root must be computed before splitting so that all nodes are created.
    let _root = builder.root();
    let (_, branch_updates) = builder.split();

    // Only the nodes matter here; the paths they are stored under are not inspected.
    for branch in branch_updates.values() {
        assert!(
            branch.tree_mask.is_subset_of(branch.state_mask),
            "Tree mask {:016b} is not a subset of state mask {:016b}",
            branch.tree_mask.get(),
            branch.state_mask.get()
        );
    }
}

/// Adds the same branch once with `stored_in_database` set and once with it cleared, then
/// inspects the tree mask of the update recorded at the branch path, if there is one.
///
/// NOTE(review): both scenarios only assert when an update exists at the branch path; if the
/// builder records none there, the test passes without checking anything. Confirm whether
/// the lookup is expected to succeed.
#[test]
fn test_stored_in_database_flag_impact() {
    let branch_path = Nibbles::from_nibbles_unchecked(vec![0, 1]);
    let branch_hash = B256::with_last_byte(1);

    // Scenario 1: flag set on the builder and on the branch being added.
    let mut stored_builder = HashBuilder::default().with_updates(true);
    stored_builder.stored_in_database = true;
    stored_builder.add_branch(branch_path.clone(), branch_hash, true);

    // The root must be computed before splitting so that all nodes are created.
    let _root = stored_builder.root();
    let (_, stored_updates) = stored_builder.split();

    if let Some(stored_node) = stored_updates.get(&branch_path) {
        assert!(
            stored_node.tree_mask.get() > 0,
            "Tree mask should be non-zero for stored node at {:?}",
            branch_path
        );
        assert!(
            stored_node.tree_mask.is_subset_of(stored_node.state_mask),
            "Tree mask should be subset of state mask for path: {:?}\n\
             state_mask: {:016b}\n\
             tree_mask: {:016b}",
            branch_path,
            stored_node.state_mask.get(),
            stored_node.tree_mask.get()
        );
    }

    // Scenario 2: flag cleared on a fresh builder and on the branch being added.
    let mut unstored_builder = HashBuilder::default().with_updates(true);
    unstored_builder.stored_in_database = false;
    unstored_builder.add_branch(branch_path.clone(), branch_hash, false);

    // The root must be computed before splitting so that all nodes are created.
    let _root = unstored_builder.root();
    let (_, unstored_updates) = unstored_builder.split();

    if let Some(unstored_node) = unstored_updates.get(&branch_path) {
        assert_eq!(
            unstored_node.tree_mask.get(),
            0,
            "Tree mask should be zero for non-stored node at {:?}",
            branch_path
        );
        assert!(
            unstored_node.tree_mask.is_subset_of(unstored_node.state_mask),
            "Tree mask should be subset of state mask for path: {:?}\n\
             state_mask: {:016b}\n\
             tree_mask: {:016b}",
            branch_path,
            unstored_node.state_mask.get(),
            unstored_node.tree_mask.get()
        );
    }
}

/// Builds a trie from four 32-byte keys that diverge at the second, third and fourth nibble,
/// then checks the state mask recorded at each listed prefix, the tree-mask/state-mask subset
/// relation for every update, and the root against `triehash_trie_root`.
#[test]
fn test_complex_trie_structure() {
    // All leaves carry an empty value; only the key layout matters for this test.
    let leaves: BTreeMap<Vec<u8>, Vec<u8>> = [
        hex!("1234000000000000000000000000000000000000000000000000000000000000"),
        hex!("1235000000000000000000000000000000000000000000000000000000000000"),
        hex!("1245000000000000000000000000000000000000000000000000000000000000"),
        hex!("1345000000000000000000000000000000000000000000000000000000000000"),
    ]
    .iter()
    .map(|key| (key.to_vec(), Vec::new()))
    .collect();

    // Feed the leaves to the builder in the map's (sorted) iteration order.
    let mut builder = HashBuilder::default().with_updates(true);
    for (key, value) in &leaves {
        builder.add_leaf(Nibbles::unpack(key), value);
    }

    let computed_root = builder.root();
    let (_, updates) = builder.split();

    // (prefix path, expected state mask of the node stored at that path)
    let expected_nodes = [
        (vec![0x1], 0b0000_0000_0000_1100),
        (vec![0x1, 0x2], 0b0000_0000_0001_1000),
        (vec![0x1, 0x3], 0b0000_0000_0001_0000),
        (vec![0x1, 0x2, 0x3], 0b0000_0000_0011_0000),
        (vec![0x1, 0x2, 0x4], 0b0000_0000_0010_0000),
        (vec![0x1, 0x3, 0x4], 0b0000_0000_0010_0000),
    ];

    for (prefix, expected_mask) in expected_nodes {
        let path = Nibbles::from_nibbles_unchecked(prefix);
        let node = match updates.get(&path) {
            Some(node) => node,
            None => panic!("Missing expected node at path: {:?}", path),
        };
        assert_eq!(
            node.state_mask.get(),
            expected_mask,
            "Wrong state mask for path {:?}. Expected {:016b}, got {:016b}",
            path,
            expected_mask,
            node.state_mask.get()
        );
        assert!(node.tree_mask.is_subset_of(node.state_mask));
    }

    for (key, node) in &updates {
        assert!(
            node.tree_mask.is_subset_of(node.state_mask),
            "Tree mask not subset of state mask for key {:?}",
            key
        );

        // Only paths shorter than three nibbles are required to have a non-empty state mask.
        assert!(
            key.len() >= 3 || node.state_mask.get() > 0,
            "State mask should not be empty for non-leaf node at {:?}",
            key
        );
    }

    assert_eq!(computed_root, triehash_trie_root(leaves));
}

/// Checks that every bit set in a node's tree mask points at a known child path.
///
/// For each `(path, node)` in `nodes` and each nibble `0..16` whose bit is set in
/// `node.tree_mask`, the child path `path + nibble` must be present either as a key of
/// `nodes` or in `removed`.
///
/// Returns `true` when no violation is found. On the first violation encountered, prints the
/// offending child path, its parent path and the parent's tree mask, and returns `false`.
fn verify_tree_mask_invariant(
    nodes: &HashMap<Nibbles, BranchNodeCompact>,
    removed: &HashSet<Nibbles>,
) -> bool {
    for (path, branch_node) in nodes {
        let tree_bits = branch_node.tree_mask.get();

        // Iterating over `u8` nibbles directly avoids a cast, and checking each child as it
        // is built avoids collecting all child paths into a temporary `Vec`.
        for nibble in 0u8..16 {
            if tree_bits & (1 << nibble) == 0 {
                continue;
            }

            let mut child_path = path.clone();
            child_path.push(nibble);

            if !nodes.contains_key(&child_path) && !removed.contains(&child_path) {
                println!(
                    "Missing child node at path: {:?} for parent: {:?} with mask: {:016b}",
                    child_path,
                    path,
                    branch_node.tree_mask.get()
                );
                return false;
            }
        }
    }
    true
}

/// Regression test named after reth issue 12129.
///
/// Adds three leaves whose paths are successive extensions of one another, then checks that
/// an update exists at each of the listed paths, that every update's tree mask is a subset
/// of its state mask, and that `verify_tree_mask_invariant` holds with no removed nodes.
#[test]
fn test_regression_reth_issue_12129() {
    let mut hb = HashBuilder::default().with_updates(true);

    let path1 = Nibbles::from_nibbles_unchecked(hex!("0b"));
    let path2 = Nibbles::from_nibbles_unchecked(hex!("0b0605"));
    let path3 = Nibbles::from_nibbles_unchecked(hex!("0b060502"));

    // None of the paths is read again after insertion, so each is moved rather than cloned.
    hb.add_leaf(path1, &[1]);
    hb.add_leaf(path2, &[2]);
    hb.add_leaf(path3, &[3]);

    // calculate root to ensure all nodes are created
    let _root = hb.root();
    let (_, updates) = hb.split();

    let intermediate_paths = [
        Nibbles::from_nibbles_unchecked(hex!("0b")),
        Nibbles::from_nibbles_unchecked(hex!("0b06")),
        Nibbles::from_nibbles_unchecked(hex!("0b0605")),
    ];

    for path in &intermediate_paths {
        assert!(updates.contains_key(path), "Missing intermediate node at path: {:?}", path);
    }

    for (path, node) in &updates {
        assert!(
            node.tree_mask.is_subset_of(node.state_mask),
            "Tree mask not subset of state mask for path: {:?}\n\
             state_mask: {:016b}\n\
             tree_mask: {:016b}",
            path,
            node.state_mask.get(),
            node.tree_mask.get()
        );
    }

    // No nodes were removed in this scenario, hence the empty set.
    assert!(verify_tree_mask_invariant(&updates, &HashSet::new()));
}
}

0 comments on commit cc94266

Please sign in to comment.