Enable in-memory trie when state sync (#10820)
**Issue**: #10564

**Summary**
Adds logic for loading and unloading in-memory tries that works correctly with state sync.
Enables the in-memory trie together with single-shard tracking.

**Changes**
- Add an optional `state_root` parameter to the memtrie loading logic; it is
needed when the state root cannot be read from the chunk extra.
- Add `load_mem_tries_for_tracked_shards` config parameter.
- Add methods for loading and unloading in-memory tries (see the sketch after this list).
- Remove obsolete tries from memory before each new state sync.
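
For orientation, here is a minimal, self-contained sketch of the intended call order, using simplified types (`u64` shard ids in place of `ShardUId`, a byte array in place of `StateRoot`, no error handling). The trait condenses the `RuntimeAdapter` additions in `chain/chain/src/types.rs` below; the `MemTries` struct is purely illustrative and not part of this change.

```rust
use std::collections::HashSet;

/// Condensed, simplified view of the new `RuntimeAdapter` memtrie methods.
trait MemTrieHooks {
    /// Startup: load memtries only for the shards this node tracks.
    fn load_mem_tries_on_startup(&mut self, tracked_shards: &[u64]);
    /// Catchup: `ChunkExtra` is unavailable, so the state root is explicit.
    fn load_mem_trie_on_catchup(&mut self, shard: u64, state_root: [u8; 32]);
    /// Before a new state sync: drop memtries for shards not in the list.
    fn retain_mem_tries(&mut self, shards_to_keep: &[u64]);
    /// Before applying state parts: the shard's old memtrie is obsolete.
    fn unload_mem_trie(&mut self, shard: u64);
}

/// Toy state: the set of shards with a loaded memtrie (state roots elided).
struct MemTries(HashSet<u64>);

impl MemTrieHooks for MemTries {
    fn load_mem_tries_on_startup(&mut self, tracked_shards: &[u64]) {
        self.0.extend(tracked_shards);
    }
    fn load_mem_trie_on_catchup(&mut self, shard: u64, _state_root: [u8; 32]) {
        // Mirrors the debug_assert in the real code: state sync should never
        // find this shard's memtrie already loaded.
        debug_assert!(!self.0.contains(&shard));
        self.0.insert(shard);
    }
    fn retain_mem_tries(&mut self, shards_to_keep: &[u64]) {
        self.0.retain(|s| shards_to_keep.contains(s));
    }
    fn unload_mem_trie(&mut self, shard: u64) {
        self.0.remove(&shard);
    }
}

fn main() {
    let mut tries = MemTries(HashSet::new());
    tries.load_mem_tries_on_startup(&[0]); // node starts, tracking shard 0
    tries.retain_mem_tries(&[0, 1]); // new epoch: shard 1 will be tracked too
    tries.unload_mem_trie(1); // just before applying state parts for shard 1
    tries.load_mem_trie_on_catchup(1, [0u8; 32]); // flat storage is ready
    assert_eq!(tries.0.len(), 2);
}
```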

**Follow-up tasks**
- Add shard assignment shuffling every epoch for StatelessNet:
#10845.
- Make sure the logic works well with resharding and that state is fully
garbage-collected; add integration tests:
#10844.
staffik authored Mar 27, 2024
1 parent e8635c6 commit d4d1b82
Showing 24 changed files with 204 additions and 58 deletions.
17 changes: 16 additions & 1 deletion chain/chain/src/chain.rs
@@ -390,6 +390,7 @@ impl Chain {
chain_config: ChainConfig,
snapshot_callbacks: Option<SnapshotCallbacks>,
apply_chunks_spawner: Arc<dyn AsyncComputationSpawner>,
+    validator_account_id: Option<&AccountId>,
) -> Result<Chain, Error> {
// Get runtime initial state and create genesis block out of it.
let state_roots = get_genesis_state_roots(runtime_adapter.store())?
@@ -508,7 +509,19 @@ impl Chain {
let tip = chain_store.head()?;
let shard_uids: Vec<_> =
epoch_manager.get_shard_layout(&tip.epoch_id)?.shard_uids().collect();
-runtime_adapter.load_mem_tries_on_startup(&shard_uids)?;
+let tracked_shards: Vec<_> = shard_uids
+    .iter()
+    .filter(|shard_uid| {
+        shard_tracker.care_about_shard(
+            validator_account_id,
+            &tip.prev_block_hash,
+            shard_uid.shard_id(),
+            true,
+        )
+    })
+    .cloned()
+    .collect();
+runtime_adapter.load_mem_tries_on_startup(&tracked_shards)?;

info!(target: "chain", "Init: header head @ #{} {}; block head @ #{} {}",
header_head.height, header_head.last_block_hash,
@@ -2770,6 +2783,8 @@ impl Chain {
);
store_update.commit()?;
flat_storage_manager.create_flat_storage_for_shard(shard_uid).unwrap();
+// Flat storage is ready, load memtrie if it is enabled.
+self.runtime_adapter.load_mem_trie_on_catchup(&shard_uid, &chunk.prev_state_root())?;
}

let mut height = shard_state_header.chunk_height_included();
1 change: 1 addition & 0 deletions chain/chain/src/resharding.rs
@@ -434,6 +434,7 @@ impl Chain {

let child_shard_uids = state_roots.keys().cloned().collect_vec();
self.initialize_flat_storage(&prev_hash, &child_shard_uids)?;
+// TODO(resharding) #10844 Load in-memory trie if needed.

let mut chain_store_update = self.mut_chain_store().store_update();
for (shard_uid, state_root) in state_roots {
26 changes: 24 additions & 2 deletions chain/chain/src/runtime/mod.rs
@@ -1243,8 +1243,30 @@ impl RuntimeAdapter for NightshadeRuntime {
Ok(epoch_manager.will_shard_layout_change(parent_hash)?)
}

-fn load_mem_tries_on_startup(&self, shard_uids: &[ShardUId]) -> Result<(), StorageError> {
-    self.tries.load_mem_tries_for_enabled_shards(shard_uids)
+fn load_mem_tries_on_startup(&self, tracked_shards: &[ShardUId]) -> Result<(), StorageError> {
+    self.tries.load_mem_tries_for_enabled_shards(tracked_shards)
}
+
+fn load_mem_trie_on_catchup(
+    &self,
+    shard_uid: &ShardUId,
+    state_root: &StateRoot,
+) -> Result<(), StorageError> {
+    if !self.get_tries().trie_config().load_mem_tries_for_tracked_shards {
+        return Ok(());
+    }
+    // It should not happen that memtrie is already loaded for a shard
+    // for which we just did state sync.
+    debug_assert!(!self.tries.is_mem_trie_loaded(shard_uid));
+    self.tries.load_mem_trie(shard_uid, Some(*state_root))
+}
+
+fn retain_mem_tries(&self, shard_uids: &[ShardUId]) {
+    self.tries.retain_mem_tries(shard_uids)
+}
+
+fn unload_mem_trie(&self, shard_uid: &ShardUId) {
+    self.tries.unload_mem_trie(shard_uid)
+}
}

1 change: 1 addition & 0 deletions chain/chain/src/runtime/tests.rs
@@ -1551,6 +1551,7 @@ fn get_test_env_with_chain_and_pool() -> (TestEnv, Chain, TransactionPool) {
ChainConfig::test(),
None,
Arc::new(RayonAsyncComputationSpawner),
+None,
)
.unwrap();

1 change: 1 addition & 0 deletions chain/chain/src/store_validator.rs
@@ -418,6 +418,7 @@ mod tests {
ChainConfig::test(),
None,
Arc::new(RayonAsyncComputationSpawner),
+None,
)
.unwrap();
(
2 changes: 2 additions & 0 deletions chain/chain/src/test_utils.rs
@@ -75,6 +75,7 @@ pub fn get_chain_with_epoch_length_and_num_shards(
ChainConfig::test(),
None,
Arc::new(RayonAsyncComputationSpawner),
+None,
)
.unwrap()
}
@@ -164,6 +165,7 @@ pub fn setup_with_tx_validity_period(
ChainConfig::test(),
None,
Arc::new(RayonAsyncComputationSpawner),
+None,
)
.unwrap();

14 changes: 13 additions & 1 deletion chain/chain/src/test_utils/kv_runtime.rs
@@ -1462,7 +1462,19 @@ impl RuntimeAdapter for KeyValueRuntime {
Ok(vec![])
}

-fn load_mem_tries_on_startup(&self, _shard_uids: &[ShardUId]) -> Result<(), StorageError> {
+fn load_mem_tries_on_startup(&self, _tracked_shards: &[ShardUId]) -> Result<(), StorageError> {
Ok(())
}

+fn load_mem_trie_on_catchup(
+    &self,
+    _shard_uid: &ShardUId,
+    _state_root: &StateRoot,
+) -> Result<(), StorageError> {
+    Ok(())
+}
+
+fn retain_mem_tries(&self, _shard_uids: &[ShardUId]) {}
+
+fn unload_mem_trie(&self, _shard_uid: &ShardUId) {}
}
17 changes: 16 additions & 1 deletion chain/chain/src/types.rs
@@ -512,7 +512,22 @@ pub trait RuntimeAdapter: Send + Sync {
/// Loads in-memory tries upon startup. The given shard_uids are possible candidates to load,
/// but which exact shards to load depends on configuration. This may only be called when flat
/// storage is ready.
-fn load_mem_tries_on_startup(&self, shard_uids: &[ShardUId]) -> Result<(), StorageError>;
+fn load_mem_tries_on_startup(&self, tracked_shards: &[ShardUId]) -> Result<(), StorageError>;
+
+/// Loads in-memory trie upon catchup, if it is enabled.
+/// Requires state root because `ChunkExtra` is not available at the time mem-trie is being loaded.
+fn load_mem_trie_on_catchup(
+    &self,
+    shard_uid: &ShardUId,
+    state_root: &StateRoot,
+) -> Result<(), StorageError>;
+
+/// Retains in-memory tries for given shards, i.e. unload tries from memory for shards that are NOT
+/// in the given list. Should be called to unload obsolete tries from memory.
+fn retain_mem_tries(&self, shard_uids: &[ShardUId]);
+
+/// Unload trie from memory for given shard.
+fn unload_mem_trie(&self, shard_uid: &ShardUId);
}

/// The last known / checked height and time when we have processed it.
27 changes: 26 additions & 1 deletion chain/chunks/src/logic.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
-use near_chain::{types::EpochManagerAdapter, validate::validate_chunk_proofs, Chain, ChainStore};
+use near_chain::{
+    types::EpochManagerAdapter, validate::validate_chunk_proofs, BlockHeader, Chain, ChainStore,
+};
use near_chunks_primitives::Error;
use near_epoch_manager::shard_tracker::ShardTracker;
use near_primitives::{
@@ -46,6 +48,29 @@ pub fn cares_about_shard_this_or_next_epoch(
|| shard_tracker.will_care_about_shard(account_id, parent_hash, shard_id, is_me)
}

+pub fn get_shards_cares_about_this_or_next_epoch(
+    account_id: Option<&AccountId>,
+    is_me: bool,
+    block_header: &BlockHeader,
+    shard_tracker: &ShardTracker,
+    epoch_manager: &dyn EpochManagerAdapter,
+) -> Vec<ShardId> {
+    epoch_manager
+        .shard_ids(&block_header.epoch_id())
+        .unwrap()
+        .into_iter()
+        .filter(|&shard_id| {
+            cares_about_shard_this_or_next_epoch(
+                account_id,
+                block_header.prev_hash(),
+                shard_id,
+                is_me,
+                shard_tracker,
+            )
+        })
+        .collect()
+}

pub fn chunk_needs_to_be_fetched_from_archival(
chunk_prev_block_hash: &CryptoHash,
header_head: &CryptoHash,
23 changes: 21 additions & 2 deletions chain/client/src/client.rs
@@ -43,7 +43,8 @@ use near_chain_configs::{ClientConfig, LogSummaryStyle, UpdateableClientConfig};
use near_chunks::adapter::ShardsManagerRequestFromClient;
use near_chunks::client::ShardedTransactionPool;
use near_chunks::logic::{
-    cares_about_shard_this_or_next_epoch, decode_encoded_chunk, persist_chunk,
+    cares_about_shard_this_or_next_epoch, decode_encoded_chunk,
+    get_shards_cares_about_this_or_next_epoch, persist_chunk,
};
use near_chunks::ShardsManager;
use near_client_primitives::debug::ChunkProduction;
@@ -275,6 +276,7 @@ impl Client {
chain_config.clone(),
snapshot_callbacks,
async_computation_spawner.clone(),
+    validator_signer.as_ref().map(|x| x.validator_id()),
)?;
// Create flat storage or initiate migration to flat storage.
let flat_storage_creator = FlatStorageCreator::new(
@@ -2333,13 +2335,15 @@ impl Client {
let _span = debug_span!(target: "sync", "run_catchup").entered();
let mut notify_state_sync = false;
let me = &self.validator_signer.as_ref().map(|x| x.validator_id().clone());

for (sync_hash, state_sync_info) in self.chain.chain_store().iterate_state_sync_infos()? {
assert_eq!(sync_hash, state_sync_info.epoch_tail_hash);
let network_adapter = self.network_adapter.clone();

let shards_to_split = self.get_shards_to_split(sync_hash, &state_sync_info, me)?;
let state_sync_timeout = self.config.state_sync_timeout;
-let epoch_id = self.chain.get_block(&sync_hash)?.header().epoch_id().clone();
+let block_header = self.chain.get_block(&sync_hash)?.header().clone();
+let epoch_id = block_header.epoch_id();

let (state_sync, shards_to_split, blocks_catch_up_state) =
self.catchup_state_syncs.entry(sync_hash).or_insert_with(|| {
@@ -2371,6 +2375,21 @@
.epoch_manager
.get_shard_layout(&epoch_id)
.expect("Cannot get shard layout");

+// Make sure mem-tries for shards we do not care about are unloaded before we start a new state sync.
+let shards_cares_this_or_next_epoch = get_shards_cares_about_this_or_next_epoch(
+    me.as_ref(),
+    true,
+    &block_header,
+    &self.shard_tracker,
+    self.epoch_manager.as_ref(),
+);
+let shard_uids: Vec<_> = shards_cares_this_or_next_epoch
+    .iter()
+    .map(|id| self.epoch_manager.shard_id_to_uid(*id, &epoch_id).unwrap())
+    .collect();
+self.runtime_adapter.retain_mem_tries(&shard_uids);

for &shard_id in &tracking_shards {
let shard_uid = ShardUId::from_shard_id_and_layout(shard_id, &shard_layout);
match self.state_sync_adapter.clone().read() {
25 changes: 8 additions & 17 deletions chain/client/src/client_actions.rs
@@ -33,7 +33,7 @@ use near_chain::{
use near_chain_configs::{ClientConfig, LogSummaryStyle};
use near_chain_primitives::error::EpochErrorResultToChainError;
use near_chunks::client::ShardsManagerResponse;
-use near_chunks::logic::cares_about_shard_this_or_next_epoch;
+use near_chunks::logic::get_shards_cares_about_this_or_next_epoch;
use near_client_primitives::types::{
Error, GetClientConfig, GetClientConfigError, GetNetworkInfo, NetworkInfoResponse,
StateSyncStatus, Status, StatusError, StatusSyncInfo, SyncStatus,
@@ -1543,22 +1543,13 @@ impl ClientActions {
unwrap_and_report!(self.client.chain.get_block_header(&sync_hash));
let prev_hash = *block_header.prev_hash();
let epoch_id = block_header.epoch_id().clone();
-let shards_to_sync: Vec<_> = self
-    .client
-    .epoch_manager
-    .shard_ids(&epoch_id)
-    .unwrap()
-    .into_iter()
-    .filter(|&shard_id| {
-        cares_about_shard_this_or_next_epoch(
-            me.as_ref(),
-            &prev_hash,
-            shard_id,
-            true,
-            &self.client.shard_tracker,
-        )
-    })
-    .collect();
+let shards_to_sync = get_shards_cares_about_this_or_next_epoch(
+    me.as_ref(),
+    true,
+    &block_header,
+    &self.client.shard_tracker,
+    self.client.epoch_manager.as_ref(),
+);

let use_colour =
matches!(self.client.config.log_summary_style, LogSummaryStyle::Colored);
1 change: 1 addition & 0 deletions chain/client/src/info.rs
@@ -924,6 +924,7 @@ mod tests {
ChainConfig::test(),
None,
Arc::new(RayonAsyncComputationSpawner),
+None,
)
.unwrap();

3 changes: 3 additions & 0 deletions chain/client/src/sync_jobs_actions.rs
@@ -81,6 +81,9 @@ impl SyncJobsActions {
}

pub fn handle_apply_state_parts_request(&mut self, msg: ApplyStatePartsRequest) {
+// Unload mem-trie (in case it is still loaded) before we apply state parts.
+msg.runtime_adapter.unload_mem_trie(&msg.shard_uid);

let shard_id = msg.shard_uid.shard_id as ShardId;
match self.clear_flat_state(&msg) {
Err(err) => {
3 changes: 3 additions & 0 deletions chain/client/src/test_utils/setup.rs
@@ -121,6 +121,7 @@ pub fn setup(
},
None,
Arc::new(RayonAsyncComputationSpawner),
+None,
)
.unwrap();
let genesis_block = chain.get_block(&chain.genesis().hash().clone()).unwrap();
@@ -259,6 +260,7 @@ pub fn setup_only_view(
},
None,
Arc::new(RayonAsyncComputationSpawner),
+None,
)
.unwrap();

@@ -1030,6 +1032,7 @@ pub fn setup_synchronous_shards_manager(
}, // irrelevant
None,
Arc::new(RayonAsyncComputationSpawner),
+None,
)
.unwrap();
let chain_head = chain.head().unwrap();
6 changes: 3 additions & 3 deletions core/store/src/config.rs
@@ -63,8 +63,8 @@ pub struct StoreConfig {
/// TODO(#9511): This does not automatically survive resharding. We may need to figure out a
/// strategy for that.
pub load_mem_tries_for_shards: Vec<ShardUId>,
-/// If true, load mem tries for all shards; this has priority over `load_mem_tries_for_shards`.
-pub load_mem_tries_for_all_shards: bool,
+/// If true, load mem trie for each shard being tracked; this has priority over `load_mem_tries_for_shards`.
+pub load_mem_tries_for_tracked_shards: bool,

/// Path where to create RocksDB checkpoints during database migrations or
/// `false` to disable that feature.
@@ -258,7 +258,7 @@ impl Default for StoreConfig {
// It will speed up processing of shards where it is enabled, but
// requires more RAM and takes several minutes on startup.
load_mem_tries_for_shards: Default::default(),
-load_mem_tries_for_all_shards: false,
+load_mem_tries_for_tracked_shards: false,

migration_snapshot: Default::default(),

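For operators, a hedged sketch of how the new flag could be enabled programmatically; `near_store::StoreConfig` is the assumed public path for the struct above, and in practice the same field would typically be set under the `store` section of a node's `config.json`:

```rust
use near_store::StoreConfig; // assumed re-export of the struct defined above

fn main() {
    let store_config = StoreConfig {
        // Load a memtrie for every shard this node tracks; per the doc
        // comment above, this takes priority over the explicit
        // `load_mem_tries_for_shards` list.
        load_mem_tries_for_tracked_shards: true,
        ..Default::default()
    };
    assert!(store_config.load_mem_tries_for_tracked_shards);
}
```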
2 changes: 1 addition & 1 deletion core/store/src/test_utils.rs
@@ -113,7 +113,7 @@ impl TestTriesBuilder {
let tries = ShardTries::new(
store,
TrieConfig {
-load_mem_tries_for_all_shards: self.enable_in_memory_tries,
+load_mem_tries_for_tracked_shards: self.enable_in_memory_tries,
..Default::default()
},
&shard_uids,
5 changes: 3 additions & 2 deletions core/store/src/trie/config.rs
@@ -33,7 +33,8 @@ pub struct TrieConfig {
pub sweat_prefetch_senders: Vec<AccountId>,
/// List of shards we will load into memory.
pub load_mem_tries_for_shards: Vec<ShardUId>,
-pub load_mem_tries_for_all_shards: bool,
+/// Whether mem-trie should be loaded for each tracked shard.
+pub load_mem_tries_for_tracked_shards: bool,
}

impl TrieConfig {
@@ -58,7 +59,7 @@ impl TrieConfig {
}
}
this.load_mem_tries_for_shards = config.load_mem_tries_for_shards.clone();
-this.load_mem_tries_for_all_shards = config.load_mem_tries_for_all_shards;
+this.load_mem_tries_for_tracked_shards = config.load_mem_tries_for_tracked_shards;

this
}
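A sketch of the per-shard selection rule the two `TrieConfig` fields imply, under assumed semantics of `load_mem_tries_for_enabled_shards` that match the "has priority" doc comment in `core/store/src/config.rs` above; the function name and the `u64` shard ids (in place of `ShardUId`) are illustrative only:

```rust
/// Simplified stand-in for the per-shard decision made at load time.
fn should_load_mem_trie(
    shard: u64,
    load_mem_tries_for_tracked_shards: bool, // new flag: all tracked shards
    load_mem_tries_for_shards: &[u64],       // explicit per-shard list
) -> bool {
    // The flag has priority: when set, every tracked (candidate) shard gets
    // a memtrie; otherwise fall back to the explicit list.
    load_mem_tries_for_tracked_shards || load_mem_tries_for_shards.contains(&shard)
}

fn main() {
    assert!(should_load_mem_trie(3, true, &[]));
    assert!(should_load_mem_trie(3, false, &[3]));
    assert!(!should_load_mem_trie(3, false, &[0]));
}
```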
(Diffs for the 7 remaining changed files are not shown.)
