Skip to content

Commit

Permalink
flat storage for state part
Browse files Browse the repository at this point in the history
  • Loading branch information
Longarithm committed May 12, 2023
1 parent 62921b9 commit e16c6a2
Show file tree
Hide file tree
Showing 9 changed files with 296 additions and 35 deletions.
8 changes: 7 additions & 1 deletion core/primitives/src/state.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use borsh::{BorshDeserialize, BorshSerialize};

use near_primitives_core::hash::{hash, CryptoHash};
use near_primitives_core::hash::{CryptoHash, hash};

/// State value reference. Used to charge fees for value length before retrieving the value itself.
#[derive(BorshSerialize, BorshDeserialize, Clone, PartialEq, Eq, Hash)]
Expand Down Expand Up @@ -61,3 +61,9 @@ mod tests {
assert_eq!(value_ref.hash, hash(&value));
}
}

/// Value kept for a key in flat state.
///
/// At the moment the only representation is a [`ValueRef`]; the TODO below
/// tracks adding a variant that carries the value inline.
///
/// NOTE(review): the type derives Borsh (de)serialization, so variant order
/// presumably matters for already stored data — confirm before reordering.
#[derive(BorshSerialize, BorshDeserialize, Debug, Clone, PartialEq, Eq)]
pub enum FlatStateValue {
/// Reference to the value instead of the value bytes themselves.
Ref(ValueRef),
// TODO(8243): add variant here for the inlined value
}
11 changes: 11 additions & 0 deletions core/primitives/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use near_crypto::PublicKey;
pub use near_primitives_core::types::*;
use once_cell::sync::Lazy;
use std::sync::Arc;
use std::time::Duration;

/// Hash used to store the state root.
pub type StateRoot = CryptoHash;
Expand Down Expand Up @@ -997,3 +998,13 @@ pub struct StateChangesForShard {
pub shard_id: ShardId,
pub state_changes: Vec<RawStateChangesWithTrieKey>,
}

/// Wall-clock timings of the individual phases of
/// `Trie::get_trie_nodes_for_part_with_flat_storage`.
///
/// All fields default to a zero duration.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Stats {
    /// Time spent finding the paths to both part boundaries and visiting the
    /// boundary nodes.
    pub boundaries_read_duration: Duration,
    /// Time spent iterating over the flat state entries of the part and
    /// collecting them.
    pub value_refs_read_duration: Duration,
    /// Time spent resolving the collected value references into raw value bytes.
    pub values_read_duration: Duration,
    /// Time spent creating the in-memory trie and generating its updates from
    /// the key-value pairs.
    pub trie_updates_gen_duration: Duration,
    /// Time spent merging boundary nodes with generated nodes and constructing
    /// the final trie over them.
    pub final_trie_gen_duration: Duration,
    /// Time spent extracting the state part from the final trie (this span also
    /// covers the diagnostic counting done right after the extraction).
    pub internal_get_duration: Duration,
}
5 changes: 5 additions & 0 deletions core/store/src/flat/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,11 @@ impl FlatStorage {

Ok(())
}

/// Returns the `hash` of the current `flat_head`.
///
/// The head is only read here, so a shared read lock is taken instead of an
/// exclusive write lock; this avoids blocking other concurrent readers.
///
/// # Panics
///
/// Panics with `POISONED_LOCK_ERR` if the lock is poisoned.
pub fn get_head_hash(&self) -> CryptoHash {
    let guard = self.0.read().expect(super::POISONED_LOCK_ERR);
    guard.flat_head.hash
}
}

#[cfg(feature = "protocol_feature_flat_state")]
Expand Down
6 changes: 3 additions & 3 deletions core/store/src/flat/store_helper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ use crate::{Store, StoreUpdate};
use near_primitives::errors::StorageError;
use near_primitives::hash::CryptoHash;
use near_primitives::shard_layout::{ShardLayout, ShardUId};
use near_primitives::state::ValueRef;
use near_primitives::state::{FlatStateValue, ValueRef};

use super::delta::{FlatStateDelta, FlatStateDeltaMetadata};
use super::types::{FlatStateValue, FlatStorageStatus};
use super::types::FlatStorageStatus;

/// Prefixes for keys in `FlatStateMisc` DB column.
pub const FLAT_STATE_HEAD_KEY_PREFIX: &[u8; 4] = b"HEAD";
Expand Down Expand Up @@ -89,7 +89,7 @@ pub fn remove_all_deltas(store_update: &mut StoreUpdate, shard_uid: ShardUId) {
store_update.delete_range(FlatStateColumn::DeltaMetadata.to_db_col(), &key_from, &key_to);
}

fn encode_flat_state_db_key(shard_uid: ShardUId, key: &[u8]) -> Vec<u8> {
pub fn encode_flat_state_db_key(shard_uid: ShardUId, key: &[u8]) -> Vec<u8> {
let mut buffer = vec![];
buffer.extend_from_slice(&shard_uid.to_bytes());
buffer.extend_from_slice(key);
Expand Down
7 changes: 0 additions & 7 deletions core/store/src/flat/types.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,8 @@
use borsh::{BorshDeserialize, BorshSerialize};
use near_primitives::errors::StorageError;
use near_primitives::hash::CryptoHash;
use near_primitives::state::ValueRef;
use near_primitives::types::BlockHeight;

#[derive(BorshSerialize, BorshDeserialize, Debug, Clone, PartialEq, Eq)]
pub enum FlatStateValue {
Ref(ValueRef),
// TODO(8243): add variant here for the inlined value
}

#[derive(BorshSerialize, BorshDeserialize, Debug, Clone, PartialEq, Eq)]
pub struct BlockInfo {
pub hash: CryptoHash,
Expand Down
3 changes: 3 additions & 0 deletions core/store/src/trie/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -839,13 +839,16 @@ impl Trie {
{
let mut memory = NodesStorage::new();
let mut root_node = self.move_node_to_mutable(&mut memory, &self.root)?;
let mut changes_num = 0;
for (key, value) in changes {
changes_num += 1;
let key = NibbleSlice::new(&key);
root_node = match value {
Some(arr) => self.insert(&mut memory, root_node, key, arr),
None => self.delete(&mut memory, root_node, key),
}?;
}
eprintln!("FS KV pairs = {}", changes_num);

#[cfg(test)]
{
Expand Down
157 changes: 154 additions & 3 deletions core/store/src/trie/state_parts.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,28 @@
use std::cell::RefCell;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;

use borsh::BorshDeserialize;

use near_primitives::challenge::{PartialState, StateItem};
use near_primitives::state_part::PartId;
use near_primitives::types::StateRoot;
use near_primitives::types::{ShardId, StateRoot, Stats};
use tracing::error;

use crate::flat::FlatStateChanges;
use crate::flat::store_helper::encode_flat_state_db_key;
use crate::flat::{store_helper, FlatStateChanges};
use crate::trie::iterator::TrieTraversalItem;
use crate::trie::nibble_slice::NibbleSlice;
use crate::trie::trie_storage::TrieMemoryPartialStorage;
use crate::trie::{
ApplyStatePartResult, NodeHandle, RawTrieNodeWithSize, TrieNode, TrieNodeWithSize,
};
use crate::{PartialStorage, StorageError, Trie, TrieChanges};
use near_primitives::contract::ContractCode;
use near_primitives::state::ValueRef;
use near_primitives::hash::{hash, CryptoHash};
use near_primitives::shard_layout::{ShardLayout, ShardUId};
use near_primitives::state::{FlatStateValue, ValueRef};
use near_primitives::state_record::is_contract_code_key;

impl Trie {
Expand All @@ -38,6 +45,145 @@ impl Trie {
Ok(trie_nodes)
}

/// Converts a trie path, given as a sequence of nibbles, into a flat state DB
/// key for `shard_uid`.
///
/// Nibbles are packed pairwise into bytes, high nibble first; a trailing
/// unpaired nibble gets a zero low nibble. The single-element path `[16]` is
/// special-cased and maps to the prefix of the next shard.
///
/// The result is always `Some`.
pub fn path_to_flat_state_key(&self, shard_uid: ShardUId, path: &[u8]) -> Option<Vec<u8>> {
    if let [16] = path {
        return Some(ShardUId::next_shard_prefix(&shard_uid.to_bytes()).to_vec());
    }
    // Padding the last nibble with zero is believed to be fine because of trie
    // properties, and it preserves lexicographic order.
    let packed: Vec<u8> = path
        .chunks(2)
        .map(|pair| {
            let low = if pair.len() == 2 { pair[1] } else { 0 };
            pair[0] * 16 + low
        })
        .collect();
    Some(encode_flat_state_db_key(shard_uid, &packed))
}

/// Loads the raw bytes addressed by `value_ref.hash` from this trie's storage.
///
/// A storage failure is returned unchanged; on success the result is always
/// `Some` with an owned copy of the bytes.
pub fn get_raw_bytes(&self, value_ref: &ValueRef) -> Result<Option<Vec<u8>>, StorageError> {
    let bytes = self.storage.retrieve_raw_bytes(&value_ref.hash)?;
    Ok(Some(bytes.to_vec()))
}

/// Visits the nodes of state part `part_id` and returns the nodes recorded as
/// visited by the underlying partial storage.
///
/// # Errors
///
/// Returns the `StorageError` produced while visiting the nodes of the part
/// instead of panicking on it.
///
/// # Panics
///
/// Panics if this trie is not backed by partial storage, which is a misuse of
/// this internal helper rather than a recoverable condition.
fn internal_get_state_part(&self, part_id: PartId) -> Result<PartialState, StorageError> {
    self.visit_nodes_for_state_part(part_id)?;
    let memory_storage = self
        .storage
        .as_partial_storage()
        .expect("internal_get_state_part requires a trie backed by partial storage");
    Ok(memory_storage.partial_state())
}

/// Builds state part `part_id` using flat storage for the values inside the
/// part and reports how long every phase took.
///
/// Phases (each one is timed separately into [`Stats`]):
/// 1. Record the trie nodes on the paths to both part boundaries.
/// 2. Read all flat state entries between the two boundary keys.
/// 3. Resolve every entry's value reference into raw value bytes.
/// 4. Insert the key-value pairs into an empty in-memory trie to generate the
///    intermediate nodes.
/// 5. Merge boundary nodes and generated nodes into one in-memory storage and
///    build a trie with this trie's root on top of it.
/// 6. Extract the state part from that final trie.
///
/// # Errors
///
/// Returns a `StorageError` if a boundary lookup, a value read, the in-memory
/// trie update or the final state part extraction fails.
///
/// # Panics
///
/// Panics if this trie is not backed by caching storage, if the recording trie
/// has no recorded storage, or if a flat state entry cannot be decoded as a
/// `FlatStateValue`.
pub fn get_trie_nodes_for_part_with_flat_storage(
    &self,
    // ?? we should be able to retrieve shard context here
    shard_layout: ShardLayout,
    shard_id: ShardId,
    part_id: PartId,
) -> Result<(PartialState, Stats), StorageError> {
    eprintln!("part {:?}", part_id);
    let store = self
        .storage
        .as_caching_storage()
        .expect("flat storage state parts require a trie backed by caching storage")
        .store
        .clone();
    let with_recording = self.recording_reads();

    let boundaries_read_start = Instant::now();
    let path_begin = with_recording.find_path_for_part_boundary(part_id.idx, part_id.total)?;
    let path_end =
        with_recording.find_path_for_part_boundary(part_id.idx + 1, part_id.total)?;
    // Make sure to touch the boundary node.
    with_recording.special_visit(part_id)?;
    let boundaries_read_duration = boundaries_read_start.elapsed();

    // TODO: we should also go through all the receipts if the queue is not empty.
    // ?? we already store everything in flat storage

    let recorded = with_recording
        .recorded_storage()
        .expect("a recording trie must expose its recorded storage");
    let trie_nodes = recorded.nodes;

    let shard_uid = ShardUId::from_shard_id_and_layout(shard_id, &shard_layout);
    let key_begin = self.path_to_flat_state_key(shard_uid, &path_begin);
    let key_end = self.path_to_flat_state_key(shard_uid, &path_end);

    let value_refs_read_start = Instant::now();
    let flat_state_iter = store_helper::iter_flat_state_entries(
        shard_layout,
        shard_id,
        &store,
        key_begin.as_ref(),
        key_end.as_ref(),
    );
    // Collected eagerly so that reading entries is measured separately from
    // resolving their values.
    let flat_state_value_refs: Vec<_> = flat_state_iter.collect();
    let value_refs_read_duration = value_refs_read_start.elapsed();

    let values_read_start = Instant::now();
    let with_values = flat_state_value_refs
        .iter()
        .map(|(k, v)| -> Result<_, StorageError> {
            let FlatStateValue::Ref(value_ref) = FlatStateValue::try_from_slice(&v)
                .expect("flat state entry must be a Borsh-encoded FlatStateValue");
            // Storage failures are reported to the caller instead of panicking.
            let raw_bytes = self
                .get_raw_bytes(&value_ref)?
                .expect("get_raw_bytes yields a value whenever the read succeeds");
            Ok((k.to_vec(), Some(raw_bytes)))
        })
        .collect::<Result<Vec<_>, StorageError>>()?;
    let values_read_duration = values_read_start.elapsed();

    let trie_updates_gen_start = Instant::now();
    // Now let's create a new trie with all these values included.
    let in_memory_trie =
        Trie::new(Box::new(TrieMemoryPartialStorage::default()), StateRoot::new(), None);
    // This will generate all the intermediate nodes.
    let trie_updates = in_memory_trie.update(with_values.into_iter())?;
    let trie_updates_gen_duration = trie_updates_gen_start.elapsed();

    let final_trie_gen_start = Instant::now();
    // Hashes from Trie reads to exclude for flat storage effectiveness computation.
    // Kept in a set so that the membership checks below are O(1) each.
    let hashes_from_trie: std::collections::HashSet<CryptoHash> =
        trie_nodes.0.iter().map(|entry| hash(entry)).collect();

    // Now let's create another storage with everything included.
    let mut all_nodes: HashMap<CryptoHash, Arc<[u8]>> = HashMap::new();
    // Adding nodes from the 'boundary'
    all_nodes.extend(trie_nodes.0.iter().map(|entry| (hash(entry), entry.clone())));
    all_nodes.extend(
        trie_updates
            .insertions
            .iter()
            .map(|entry| (entry.hash().clone(), entry.payload().to_vec().into())),
    );

    let final_trie = Trie::new(
        Box::new(TrieMemoryPartialStorage {
            recorded_storage: all_nodes,
            visited_nodes: RefCell::default(),
        }),
        self.root,
        None,
    );
    let final_trie_gen_duration = final_trie_gen_start.elapsed();

    let internal_get_start = Instant::now();
    let result = final_trie.internal_get_state_part(part_id);
    if let Ok(partial_state) = &result {
        let total_num = partial_state.0.len();
        let from_flat_storage = partial_state
            .0
            .iter()
            .filter(|entry| !hashes_from_trie.contains(&hash(*entry)))
            .count();
        eprintln!("from FS = {}, total = {}", from_flat_storage, total_num);
    };
    let internal_get_duration = internal_get_start.elapsed();
    Ok((
        result?,
        Stats {
            boundaries_read_duration,
            value_refs_read_duration,
            values_read_duration,
            trie_updates_gen_duration,
            final_trie_gen_duration,
            internal_get_duration,
        },
    ))
}

/// Assume we lay out all trie nodes in dfs order visiting children after the parent.
/// We take all node sizes (memory_usage_direct()) and take all nodes intersecting with
/// [size_start, size_end) interval, also all nodes necessary to prove it and some
Expand All @@ -54,6 +200,11 @@ impl Trie {
tracing::debug!(
target: "state_parts",
num_nodes = nodes_list.len());
self.special_visit(part_id)
}

pub fn special_visit(&self, part_id: PartId) -> Result<(), StorageError> {
let path_end = self.find_path_for_part_boundary(part_id.idx + 1, part_id.total)?;

// Extra nodes for compatibility with the previous version of computing state parts
if part_id.idx + 1 != part_id.total {
Expand Down
22 changes: 22 additions & 0 deletions core/store/src/trie/trie_storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use lru::LruCache;
use near_o11y::log_assert;
use near_o11y::metrics::prometheus;
use near_o11y::metrics::prometheus::core::{GenericCounter, GenericGauge};
use near_primitives::challenge::PartialState;
use near_primitives::hash::CryptoHash;
use near_primitives::shard_layout::ShardUId;
use near_primitives::types::{ShardId, TrieCacheMode, TrieNodesCount};
Expand Down Expand Up @@ -349,6 +350,7 @@ impl TrieStorage for TrieRecordingStorage {

/// Storage for validating recorded partial storage.
/// visited_nodes are to validate that partial storage doesn't contain unnecessary nodes.
#[derive(Default)]
pub struct TrieMemoryPartialStorage {
pub(crate) recorded_storage: HashMap<CryptoHash, Arc<[u8]>>,
pub(crate) visited_nodes: RefCell<HashSet<CryptoHash>>,
Expand All @@ -372,6 +374,26 @@ impl TrieStorage for TrieMemoryPartialStorage {
}
}

impl TrieMemoryPartialStorage {
    /// Collects the recorded nodes that were actually visited into a
    /// [`PartialState`].
    ///
    /// Entries of `recorded_storage` whose hash is not present in
    /// `visited_nodes` are left out. The remaining values are sorted, so the
    /// output does not depend on hash map iteration order.
    pub fn partial_state(&self) -> PartialState {
        let visited = self.visited_nodes.borrow();
        let mut touched: Vec<_> = self
            .recorded_storage
            .iter()
            .filter(|(node_hash, _)| visited.contains(*node_hash))
            .map(|(_, value)| value.clone())
            .collect();

        touched.sort();
        PartialState(touched)
    }
}

/// Storage for reading State nodes and values from DB which caches reads.
pub struct TrieCachingStorage {
pub(crate) store: Store,
Expand Down
Loading

0 comments on commit e16c6a2

Please sign in to comment.