Skip to content

Commit

Permalink
[Runtime Epoch Split] (4/n) Make ShardTracker accessible from Runtime…
Browse files Browse the repository at this point in the history
…WithEpochManagerAdapter.
  • Loading branch information
robin-near committed Mar 21, 2023
1 parent 76551b0 commit d28ac14
Show file tree
Hide file tree
Showing 19 changed files with 149 additions and 65 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

79 changes: 56 additions & 23 deletions chain/chain/src/test_utils/kv_runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::sync::{Arc, RwLock, Weak};

use borsh::{BorshDeserialize, BorshSerialize};

use near_epoch_manager::shard_tracker::{ShardTracker, TrackedConfig};
use near_epoch_manager::{EpochManagerAdapter, RngSeed};
use near_primitives::sandbox::state_patch::SandboxStatePatch;
use near_primitives::state_part::PartId;
Expand Down Expand Up @@ -858,6 +859,45 @@ impl EpochManagerAdapter for KeyValueRuntime {
Ok(())
}
}

fn cares_about_shard_from_prev_block(
&self,
parent_hash: &CryptoHash,
account_id: &AccountId,
shard_id: ShardId,
) -> Result<bool, EpochError> {
// This `unwrap` here tests that in all code paths we check that the epoch exists before
// we check if we care about a shard. Please do not remove the unwrap, fix the logic of
// the calling function.
let epoch_valset = self.get_epoch_and_valset(*parent_hash).unwrap();
let chunk_producers = self.get_chunk_producers(epoch_valset.1, shard_id);
for validator in chunk_producers {
if validator.account_id() == account_id {
return Ok(true);
}
}
Ok(false)
}

fn cares_about_shard_next_epoch_from_prev_block(
&self,
parent_hash: &CryptoHash,
account_id: &AccountId,
shard_id: ShardId,
) -> Result<bool, EpochError> {
// This `unwrap` here tests that in all code paths we check that the epoch exists before
// we check if we care about a shard. Please do not remove the unwrap, fix the logic of
// the calling function.
let epoch_valset = self.get_epoch_and_valset(*parent_hash).unwrap();
let chunk_producers = self
.get_chunk_producers((epoch_valset.1 + 1) % self.validators_by_valset.len(), shard_id);
for validator in chunk_producers {
if validator.account_id() == account_id {
return Ok(true);
}
}
Ok(false)
}
}

impl RuntimeAdapter for KeyValueRuntime {
Expand Down Expand Up @@ -936,19 +976,12 @@ impl RuntimeAdapter for KeyValueRuntime {
if self.tracks_all_shards {
return true;
}
// This `unwrap` here tests that in all code paths we check that the epoch exists before
// we check if we care about a shard. Please do not remove the unwrap, fix the logic of
// the calling function.
let epoch_valset = self.get_epoch_and_valset(*parent_hash).unwrap();
let chunk_producers = self.get_chunk_producers(epoch_valset.1, shard_id);
if let Some(account_id) = account_id {
for validator in chunk_producers {
if validator.account_id() == account_id {
return true;
}
}
self.cares_about_shard_from_prev_block(parent_hash, account_id, shard_id)
.unwrap_or(false)
} else {
false
}
false
}

fn will_care_about_shard(
Expand All @@ -961,20 +994,12 @@ impl RuntimeAdapter for KeyValueRuntime {
if self.tracks_all_shards {
return true;
}
// This `unwrap` here tests that in all code paths we check that the epoch exists before
// we check if we care about a shard. Please do not remove the unwrap, fix the logic of
// the calling function.
let epoch_valset = self.get_epoch_and_valset(*parent_hash).unwrap();
let chunk_producers = self
.get_chunk_producers((epoch_valset.1 + 1) % self.validators_by_valset.len(), shard_id);
if let Some(account_id) = account_id {
for validator in chunk_producers {
if validator.account_id() == account_id {
return true;
}
}
self.cares_about_shard_next_epoch_from_prev_block(parent_hash, account_id, shard_id)
.unwrap_or(false)
} else {
false
}
false
}

fn validate_tx(
Expand Down Expand Up @@ -1433,4 +1458,12 @@ impl RuntimeWithEpochManagerAdapter for KeyValueRuntime {
fn epoch_manager_adapter_arc(&self) -> Arc<dyn EpochManagerAdapter> {
self.myself.upgrade().unwrap()
}
fn shard_tracker(&self) -> ShardTracker {
let config = if self.tracks_all_shards {
TrackedConfig::AllShards
} else {
TrackedConfig::new_empty()
};
ShardTracker::new(config, self.epoch_manager_adapter_arc())
}
}
2 changes: 2 additions & 0 deletions chain/chain/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::sync::Arc;
use borsh::{BorshDeserialize, BorshSerialize};
use chrono::DateTime;
use chrono::Utc;
use near_epoch_manager::shard_tracker::ShardTracker;
use near_primitives::sandbox::state_patch::SandboxStatePatch;
use num_rational::Rational32;

Expand Down Expand Up @@ -579,6 +580,7 @@ pub trait RuntimeAdapter: Send + Sync {
pub trait RuntimeWithEpochManagerAdapter: RuntimeAdapter + EpochManagerAdapter {
fn epoch_manager_adapter(&self) -> &dyn EpochManagerAdapter;
fn epoch_manager_adapter_arc(&self) -> Arc<dyn EpochManagerAdapter>;
fn shard_tracker(&self) -> ShardTracker;
}

/// The last known / checked height and time when we have processed it.
Expand Down
38 changes: 38 additions & 0 deletions chain/epoch-manager/src/adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,20 @@ pub trait EpochManagerAdapter: Send + Sync {
block_height: BlockHeight,
approvals: &[Option<Signature>],
) -> Result<(), Error>;

fn cares_about_shard_from_prev_block(
&self,
parent_hash: &CryptoHash,
account_id: &AccountId,
shard_id: ShardId,
) -> Result<bool, EpochError>;

fn cares_about_shard_next_epoch_from_prev_block(
&self,
parent_hash: &CryptoHash,
account_id: &AccountId,
shard_id: ShardId,
) -> Result<bool, EpochError>;
}

/// A technical plumbing trait to conveniently implement [`EpochManagerAdapter`]
Expand Down Expand Up @@ -896,4 +910,28 @@ impl<T: HasEpochMangerHandle + Send + Sync> EpochManagerAdapter for T {
Ok(())
}
}

fn cares_about_shard_from_prev_block(
&self,
parent_hash: &CryptoHash,
account_id: &AccountId,
shard_id: ShardId,
) -> Result<bool, EpochError> {
let epoch_manager = self.read();
epoch_manager.cares_about_shard_from_prev_block(parent_hash, account_id, shard_id)
}

fn cares_about_shard_next_epoch_from_prev_block(
&self,
parent_hash: &CryptoHash,
account_id: &AccountId,
shard_id: ShardId,
) -> Result<bool, EpochError> {
let epoch_manager = self.read();
epoch_manager.cares_about_shard_next_epoch_from_prev_block(
parent_hash,
account_id,
shard_id,
)
}
}
1 change: 1 addition & 0 deletions chain/epoch-manager/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ mod adapter;
mod proposals;
mod reward_calculator;
mod shard_assignment;
pub mod shard_tracker;
pub mod test_utils;
#[cfg(test)]
mod tests;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
use crate::append_only_map::AppendOnlyMap;
use std::sync::Arc;

use crate::EpochManagerAdapter;
use near_cache::SyncLruCache;
use near_chain_configs::ClientConfig;
use near_epoch_manager::EpochManagerHandle;
use near_primitives::errors::EpochError;
use near_primitives::hash::CryptoHash;
use near_primitives::shard_layout::account_id_to_shard_id;
use near_primitives::types::{AccountId, EpochId, ShardId};

#[derive(Clone)]
pub enum TrackedConfig {
Accounts(Vec<AccountId>),
AllShards,
Expand All @@ -31,17 +34,21 @@ type BitMask = Vec<bool>;
/// Tracker that tracks shard ids and accounts. Right now, it only supports two modes
/// TrackedConfig::Accounts(accounts): track the shards where `accounts` belong to
/// TrackedConfig::AllShards: track all shards
#[derive(Clone)]
pub struct ShardTracker {
tracked_config: TrackedConfig,
/// Stores shard tracking information by epoch, only useful if TrackedState == Accounts
tracking_shards: AppendOnlyMap<EpochId, BitMask>,
/// Epoch manager that for given block hash computes the epoch id.
epoch_manager: EpochManagerHandle,
tracking_shards_cache: Arc<SyncLruCache<EpochId, BitMask>>,
epoch_manager: Arc<dyn EpochManagerAdapter>,
}

impl ShardTracker {
pub fn new(tracked_config: TrackedConfig, epoch_manager: EpochManagerHandle) -> Self {
ShardTracker { tracked_config, tracking_shards: AppendOnlyMap::new(), epoch_manager }
pub fn new(tracked_config: TrackedConfig, epoch_manager: Arc<dyn EpochManagerAdapter>) -> Self {
ShardTracker {
tracked_config,
tracking_shards_cache: Arc::new(SyncLruCache::new(1024)),
epoch_manager,
}
}

fn tracks_shard_at_epoch(
Expand All @@ -51,9 +58,8 @@ impl ShardTracker {
) -> Result<bool, EpochError> {
match &self.tracked_config {
TrackedConfig::Accounts(tracked_accounts) => {
let epoch_manager = self.epoch_manager.read();
let shard_layout = epoch_manager.get_shard_layout(epoch_id)?;
let tracking_mask = self.tracking_shards.get_or_insert(epoch_id, || {
let shard_layout = self.epoch_manager.get_shard_layout(epoch_id)?;
let tracking_mask = self.tracking_shards_cache.get_or_put(epoch_id.clone(), |_| {
let mut tracking_mask = vec![false; shard_layout.num_shards() as usize];
for account_id in tracked_accounts {
let shard_id = account_id_to_shard_id(account_id, &shard_layout);
Expand All @@ -68,10 +74,7 @@ impl ShardTracker {
}

fn tracks_shard(&self, shard_id: ShardId, prev_hash: &CryptoHash) -> Result<bool, EpochError> {
let epoch_id = {
let epoch_manager = self.epoch_manager.read();
epoch_manager.get_epoch_id_from_prev_block(prev_hash)?
};
let epoch_id = self.epoch_manager.get_epoch_id_from_prev_block(prev_hash)?;
self.tracks_shard_at_epoch(shard_id, &epoch_id)
}

Expand All @@ -85,12 +88,10 @@ impl ShardTracker {
// TODO: fix these unwrap_or here and handle error correctly. The current behavior masks potential errors and bugs
// https://github.com/near/nearcore/issues/4936
if let Some(account_id) = account_id {
let account_cares_about_shard = {
let epoch_manager = self.epoch_manager.read();
epoch_manager
.cares_about_shard_from_prev_block(parent_hash, account_id, shard_id)
.unwrap_or(false)
};
let account_cares_about_shard = self
.epoch_manager
.cares_about_shard_from_prev_block(parent_hash, account_id, shard_id)
.unwrap_or(false);
if !is_me {
return account_cares_about_shard;
} else if account_cares_about_shard {
Expand All @@ -113,8 +114,7 @@ impl ShardTracker {
) -> bool {
if let Some(account_id) = account_id {
let account_cares_about_shard = {
let epoch_manager = self.epoch_manager.read();
epoch_manager
self.epoch_manager
.cares_about_shard_next_epoch_from_prev_block(parent_hash, account_id, shard_id)
.unwrap_or(false)
};
Expand All @@ -133,9 +133,9 @@ impl ShardTracker {
mod tests {
use super::{account_id_to_shard_id, ShardTracker};
use crate::shard_tracker::TrackedConfig;
use crate::test_utils::hash_range;
use crate::{EpochManager, EpochManagerHandle, RewardCalculator};
use near_crypto::{KeyType, PublicKey};
use near_epoch_manager::test_utils::hash_range;
use near_epoch_manager::{EpochManager, EpochManagerHandle, RewardCalculator};
use near_primitives::epoch_manager::block_info::BlockInfo;
use near_primitives::epoch_manager::{AllEpochConfig, EpochConfig};
use near_primitives::hash::CryptoHash;
Expand All @@ -147,6 +147,7 @@ mod tests {
use near_store::test_utils::create_test_store;
use num_rational::Ratio;
use std::collections::HashSet;
use std::sync::Arc;

const DEFAULT_TOTAL_SUPPLY: u128 = 1_000_000_000_000;

Expand Down Expand Up @@ -184,7 +185,7 @@ mod tests {
num_seconds_per_year: 1000000,
};
EpochManager::new(
store,
store.into(),
AllEpochConfig::new(use_production_config, initial_epoch_config),
genesis_protocol_version,
reward_calculator,
Expand Down Expand Up @@ -255,7 +256,8 @@ mod tests {
let epoch_manager = get_epoch_manager(PROTOCOL_VERSION, num_shards, false);
let shard_layout = epoch_manager.read().get_shard_layout(&EpochId::default()).unwrap();
let tracked_accounts = vec!["test1".parse().unwrap(), "test2".parse().unwrap()];
let tracker = ShardTracker::new(TrackedConfig::Accounts(tracked_accounts), epoch_manager);
let tracker =
ShardTracker::new(TrackedConfig::Accounts(tracked_accounts), Arc::new(epoch_manager));
let mut total_tracked_shards = HashSet::new();
total_tracked_shards
.insert(account_id_to_shard_id(&"test1".parse().unwrap(), &shard_layout));
Expand All @@ -276,7 +278,7 @@ mod tests {
fn test_track_all_shards() {
let num_shards = 4;
let epoch_manager = get_epoch_manager(PROTOCOL_VERSION, num_shards, false);
let tracker = ShardTracker::new(TrackedConfig::AllShards, epoch_manager);
let tracker = ShardTracker::new(TrackedConfig::AllShards, Arc::new(epoch_manager));
let total_tracked_shards: HashSet<_> = (0..num_shards).collect();

assert_eq!(
Expand All @@ -296,7 +298,7 @@ mod tests {
let tracked_accounts = vec!["near".parse().unwrap(), "zoo".parse().unwrap()];
let tracker = ShardTracker::new(
TrackedConfig::Accounts(tracked_accounts.clone()),
epoch_manager.clone(),
Arc::new(epoch_manager.clone()),
);

let h = hash_range(8);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use near_chunks::metrics::PARTIAL_ENCODED_CHUNK_FORWARD_CACHED_WITHOUT_HEADER;
use near_client::test_utils::{create_chunk_with_transactions, TestEnv};
use near_client::ProcessTxResponse;
use near_crypto::{InMemorySigner, KeyType, Signer};
use near_epoch_manager::shard_tracker::TrackedConfig;
use near_network::shards_manager::ShardsManagerRequestFromNetwork;
use near_network::types::{NetworkRequests, PeerManagerMessageRequest};
use near_o11y::testonly::init_test_logger;
Expand All @@ -25,7 +26,7 @@ use near_primitives::version::{ProtocolFeature, ProtocolVersion};
use near_primitives::views::FinalExecutionStatus;
use near_store::test_utils::create_test_store;
use nearcore::config::GenesisExt;
use nearcore::{TrackedConfig, NEAR_BASE};
use nearcore::NEAR_BASE;
use rand::seq::SliceRandom;
use rand::{thread_rng, Rng};
use std::collections::HashSet;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@ use near_chain_configs::Genesis;
use near_client::test_utils::TestEnv;
use near_client::ProcessTxResponse;
use near_crypto::{InMemorySigner, KeyType, Signer};
use near_epoch_manager::shard_tracker::TrackedConfig;
use near_primitives::account::{AccessKey, AccessKeyPermission, FunctionCallPermission};
use near_primitives::errors::{ActionsValidationError, InvalidTxError};
use near_primitives::hash::CryptoHash;
use near_primitives::runtime::config_store::RuntimeConfigStore;
use near_primitives::transaction::{Action, AddKeyAction, Transaction};
use near_store::test_utils::create_test_store;
use nearcore::config::GenesisExt;
use nearcore::TrackedConfig;
use std::path::Path;
use std::sync::Arc;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use near_async::messaging::CanSend;
use near_chain::{ChainGenesis, Provenance, RuntimeWithEpochManagerAdapter};
use near_chain_configs::Genesis;
use near_client::test_utils::TestEnv;
use near_epoch_manager::shard_tracker::TrackedConfig;
use near_network::{
shards_manager::ShardsManagerRequestFromNetwork,
types::{NetworkRequests, PeerManagerMessageRequest},
Expand All @@ -15,7 +16,7 @@ use near_primitives::{
types::{AccountId, EpochId, ShardId},
};
use near_store::test_utils::create_test_store;
use nearcore::{config::GenesisExt, TrackedConfig};
use nearcore::config::GenesisExt;
use tracing::log::debug;

struct AdversarialBehaviorTestData {
Expand Down
Loading

0 comments on commit d28ac14

Please sign in to comment.