Skip to content

Commit

Permalink
[resharding] Initialize flat storage during resharding (#9382)
Browse files Browse the repository at this point in the history
It turns out that we were completely ignoring flat storage during resharding and we didn't really have any tests to capture this.

Flat storage was not being written to when we were splitting a shard during resharding. This PR initializes the flat storage in flat storage manager.

Please look at #9418 and #9424 for more context.

Future work
- Clean up work to merge the different implementations of flat storage initialization during state sync and resharding
- Update the tests to better reflect catch up with should automatically handle updating the flat storage of the child shards. Current tests don't handle that and so we need to disable checking flat storage.
- Once this is in place, merge PR #9335
  • Loading branch information
Shreyan Gupta authored Aug 15, 2023
1 parent a456ee6 commit 9e5794d
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 27 deletions.
108 changes: 82 additions & 26 deletions chain/chain/src/resharding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,22 @@
/// build_state_for_split_shards_preprocessing and build_state_for_split_shards_postprocessing are handled
/// by the client_actor while the heavy resharding build_state_for_split_shards is done by SyncJobsActor
/// so as to not affect client.
use std::collections::{HashMap, HashSet};
use std::fmt::{Debug, Formatter};
use std::sync::Arc;

use itertools::Itertools;
use near_chain_primitives::error::Error;
use near_primitives::hash::CryptoHash;
use near_primitives::shard_layout::{account_id_to_shard_uid, ShardLayout};
use near_primitives::state_part::PartId;
use near_primitives::syncing::{get_num_state_parts, STATE_PART_MEMORY_LIMIT};
use near_primitives::types::chunk_extra::ChunkExtra;
use near_primitives::types::{AccountId, ShardId, StateRoot};
use near_store::flat::{
store_helper, BlockInfo, FlatStorageManager, FlatStorageReadyStatus, FlatStorageStatus,
};
use near_store::split_state::get_delayed_receipts;
use near_store::{ShardTries, ShardUId, Trie};

use near_store::{ShardTries, ShardUId, Store, Trie};
use std::collections::{HashMap, HashSet};
use std::fmt::{Debug, Formatter};
use std::sync::Arc;
use tracing::debug;

use crate::types::RuntimeAdapter;
Expand Down Expand Up @@ -81,23 +83,41 @@ fn apply_delayed_receipts<'a>(
Ok(new_state_roots)
}

// function to set up flat storage status to Ready after a resharding event
// TODO(resharding) : Consolidate this with setting up flat storage during state sync logic
fn set_flat_storage_state(
store: Store,
flat_storage_manager: &FlatStorageManager,
shard_uid: ShardUId,
block_info: BlockInfo,
) -> Result<(), Error> {
let mut store_update = store.store_update();
store_helper::set_flat_storage_status(
&mut store_update,
shard_uid,
FlatStorageStatus::Ready(FlatStorageReadyStatus { flat_head: block_info }),
);
store_update.commit()?;
flat_storage_manager.create_flat_storage_for_shard(shard_uid)?;
Ok(())
}

impl Chain {
pub fn build_state_for_split_shards_preprocessing(
&self,
sync_hash: &CryptoHash,
shard_id: ShardId,
state_split_scheduler: &dyn Fn(StateSplitRequest),
) -> Result<(), Error> {
let (epoch_id, next_epoch_id) = {
let block_header = self.get_block_header(sync_hash)?;
(block_header.epoch_id().clone(), block_header.next_epoch_id().clone())
};
let shard_layout = self.epoch_manager.get_shard_layout(&epoch_id)?;
let next_epoch_shard_layout = self.epoch_manager.get_shard_layout(&next_epoch_id)?;
let block_header = self.get_block_header(sync_hash)?;
let shard_layout = self.epoch_manager.get_shard_layout(block_header.epoch_id())?;
let next_epoch_shard_layout =
self.epoch_manager.get_shard_layout(block_header.next_epoch_id())?;
assert_ne!(shard_layout, next_epoch_shard_layout);

let shard_uid = ShardUId::from_shard_id_and_layout(shard_id, &shard_layout);
let prev_hash = *self.get_block_header(sync_hash)?.prev_hash();
let prev_hash = block_header.prev_hash();
let state_root = *self.get_chunk_extra(&prev_hash, &shard_uid)?.state_root();
assert_ne!(shard_layout, next_epoch_shard_layout);

state_split_scheduler(StateSplitRequest {
runtime_adapter: self.runtime_adapter.clone(),
Expand Down Expand Up @@ -173,23 +193,57 @@ impl Chain {
state_roots,
&checked_account_id_to_shard_id,
)?;

Ok(state_roots)
}

pub fn build_state_for_split_shards_postprocessing(
&mut self,
sync_hash: &CryptoHash,
state_roots: Result<HashMap<ShardUId, StateRoot>, Error>,
state_roots: HashMap<ShardUId, StateRoot>,
) -> Result<(), Error> {
let prev_hash = *self.get_block_header(sync_hash)?.prev_hash();
let block_header = self.get_block_header(sync_hash)?;
let prev_hash = block_header.prev_hash();

let child_shard_uids = state_roots.keys().collect_vec();
self.initialize_flat_storage(&prev_hash, &child_shard_uids)?;

let mut chain_store_update = self.mut_store().store_update();
for (shard_uid, state_root) in state_roots? {
for (shard_uid, state_root) in state_roots {
// here we store the state roots in chunk_extra in the database for later use
let chunk_extra = ChunkExtra::new_with_only_state_root(&state_root);
chain_store_update.save_chunk_extra(&prev_hash, &shard_uid, chunk_extra);
debug!(target:"chain", "Finish building split state for shard {:?} {:?} {:?} ", shard_uid, prev_hash, state_root);
}
chain_store_update.commit()
chain_store_update.commit()?;

Ok(())
}

// Here we iterate over all the child shards and initialize flat storage for them by calling set_flat_storage_state
// Note that this function is called on the current_block which is the first block the next epoch.
// We set the flat_head as the prev_block as after resharding, the state written to flat storage corresponds to the
// state as of prev_block, and that's the convention that we follow.
fn initialize_flat_storage(
&self,
prev_hash: &CryptoHash,
child_shard_uids: &[&ShardUId],
) -> Result<(), Error> {
let prev_block_header = self.get_block_header(prev_hash)?;
let prev_block_info = BlockInfo {
hash: *prev_block_header.hash(),
prev_hash: *prev_block_header.prev_hash(),
height: prev_block_header.height(),
};

// create flat storage for child shards
if let Some(flat_storage_manager) = self.runtime_adapter.get_flat_storage_manager() {
for shard_uid in child_shard_uids {
let store = self.runtime_adapter.store().clone();
set_flat_storage_state(store, &flat_storage_manager, **shard_uid, prev_block_info)?;
}
}
Ok(())
}
}

Expand Down Expand Up @@ -398,25 +452,27 @@ mod tests {
state_roots: &HashMap<ShardUId, StateRoot>,
account_id_to_shard_id: &dyn Fn(&AccountId) -> ShardUId,
) {
// check that the 4 tries combined to the orig trie
let trie_items =
get_trie_nodes_except_delayed_receipts(tries, &ShardUId::single_shard(), state_root);
let trie_items_by_shard: HashMap<_, _> = state_roots
// Get trie items before resharding and split them by account shard
let trie_items_before_resharding =
get_trie_items_except_delayed_receipts(tries, &ShardUId::single_shard(), state_root);

let trie_items_after_resharding: HashMap<_, _> = state_roots
.iter()
.map(|(&shard_uid, state_root)| {
(shard_uid, get_trie_nodes_except_delayed_receipts(tries, &shard_uid, state_root))
(shard_uid, get_trie_items_except_delayed_receipts(tries, &shard_uid, state_root))
})
.collect();

let mut expected_trie_items_by_shard: HashMap<_, _> =
state_roots.iter().map(|(&shard_uid, _)| (shard_uid, vec![])).collect();
for item in trie_items {
for item in trie_items_before_resharding {
let account_id = parse_account_id_from_raw_key(&item.0).unwrap().unwrap();
let shard_uid: ShardUId = account_id_to_shard_id(&account_id);
expected_trie_items_by_shard.get_mut(&shard_uid).unwrap().push(item);
}
assert_eq!(trie_items_by_shard, expected_trie_items_by_shard);
assert_eq!(expected_trie_items_by_shard, trie_items_after_resharding);

// check that the new tries combined to the orig trie for delayed receipts
let receipts_from_split_states: HashMap<_, _> = state_roots
.iter()
.map(|(&shard_uid, state_root)| {
Expand All @@ -434,7 +490,7 @@ mod tests {
assert_eq!(expected_receipts_by_shard, receipts_from_split_states);
}

fn get_trie_nodes_except_delayed_receipts(
fn get_trie_items_except_delayed_receipts(
tries: &ShardTries,
shard_uid: &ShardUId,
state_root: &StateRoot,
Expand Down
2 changes: 1 addition & 1 deletion chain/client/src/sync/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1137,7 +1137,7 @@ impl StateSync {
let result = self.split_state_roots.remove(&shard_id);
let mut shard_sync_done = false;
if let Some(state_roots) = result {
chain.build_state_for_split_shards_postprocessing(&sync_hash, state_roots)?;
chain.build_state_for_split_shards_postprocessing(&sync_hash, state_roots?)?;
*shard_sync_download =
ShardSyncDownload { downloads: vec![], status: ShardSyncStatus::StateSyncDone };
shard_sync_done = true;
Expand Down

0 comments on commit 9e5794d

Please sign in to comment.