Skip to content

Commit

Permalink
[Runtime Epoch Split] (2/n) Narrow EpochManagerAdapter error type to …
Browse files Browse the repository at this point in the history
…EpochError wherever possible. (#8767)

As a reminder of the overall goal, we want to split `RuntimeWithEpochManagerAdapter`, so that any code that needs the runtime will use an `Arc<RuntimeAdapter>`, and any code that uses the epoch manager will use the `EpochManagerHandle`.

We're doing this refactoring bottom-up, i.e. propagating `RuntimeWithEpochManagerAdapter` around at the top-level, but making some lower-level components use `Arc<RuntimeAdapter>` and/or `EpochManagerHandle`.

That means we need to be able to obtain an `Arc<RuntimeAdapter>` and an `EpochManagerHandle` from a `RuntimeWithEpochManagerAdapter`. However, this is not trivial at all:
 1. `KeyValueRuntime`, the implementation of `RuntimeWithEpochManagerAdapter` for testing, does not contain an `EpochManager` at all, so it's not possible to extract an `EpochManagerHandle` from it (which is essentially an arc mutex of `EpochManager`). That means instead of using `EpochManagerHandle`, we need to use `Arc<EpochManagerAdapter>` in the meantime.
 2. Extracting an `Arc<EpochManagerAdapter>` from a `Arc<RuntimeWithEpochManagerAdapter>` is not trivial. Even though `RuntimeWithEpochManagerAdapter` is a trait that extends `EpochManagerAdapter`, trait upcast is not allowed by Rust in general. So we need to resort to a workaround.

This PR addresses an issue arising from (1), that current code that expects an `EpochManagerHandle` uses functions from `EpochManager` which mostly returns `EpochError`. `EpochManagerAdapter`, on the other hand, returns mostly `near_chain_primitive::Error`, so changing to the adapter type breaks existing code. So, in this PR we are changing the error type returned from most of the `EpochManagerAdapter` functions to also be EpochError. This is a pretty painless transition, except in some cases we need to do a two-hop conversion from `EpochError` to chain `Error` and then to some other error (and two-hop conversions cannot be supported by `.into()` in general), so I've defined `.into_chain_error()` to help with the two-hop conversion cases.
  • Loading branch information
robin-near authored Mar 22, 2023
1 parent 6c1a516 commit dc8f61c
Show file tree
Hide file tree
Showing 10 changed files with 238 additions and 154 deletions.
10 changes: 10 additions & 0 deletions chain/chain-primitives/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,16 @@ impl From<EpochError> for Error {
}
}

pub trait EpochErrorResultToChainError<T> {
fn into_chain_error(self) -> Result<T, Error>;
}

impl<T> EpochErrorResultToChainError<T> for Result<T, EpochError> {
fn into_chain_error(self: Result<T, EpochError>) -> Result<T, Error> {
self.map_err(|err| err.into())
}
}

impl From<ShardLayoutError> for Error {
fn from(error: ShardLayoutError) -> Self {
match error {
Expand Down
4 changes: 2 additions & 2 deletions chain/chain/src/chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3547,7 +3547,7 @@ impl Chain {
) -> Result<AccountId, Error> {
let head = self.head()?;
let target_height = head.height + horizon - 1;
self.runtime_adapter.get_chunk_producer(epoch_id, target_height, shard_id)
Ok(self.runtime_adapter.get_chunk_producer(epoch_id, target_height, shard_id)?)
}

/// Find a validator that is responsible for a given shard to forward requests to
Expand Down Expand Up @@ -4503,7 +4503,7 @@ impl Chain {
{
let prev_hash = *sync_block.header().prev_hash();
// If sync_hash is not on the Epoch boundary, it's malicious behavior
self.runtime_adapter.is_next_block_epoch_start(&prev_hash)
Ok(self.runtime_adapter.is_next_block_epoch_start(&prev_hash)?)
} else {
Ok(false) // invalid Epoch of sync_hash, possible malicious behavior
}
Expand Down
103 changes: 58 additions & 45 deletions chain/chain/src/test_utils/kv_runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ impl KeyValueRuntime {
self.tracks_all_shards = tracks_all_shards;
}

fn get_block_header(&self, hash: &CryptoHash) -> Result<Option<BlockHeader>, Error> {
fn get_block_header(&self, hash: &CryptoHash) -> Result<Option<BlockHeader>, EpochError> {
let mut headers_cache = self.headers_cache.write().unwrap();
if headers_cache.get(hash).is_some() {
return Ok(Some(headers_cache.get(hash).unwrap().clone()));
Expand All @@ -250,13 +250,13 @@ impl KeyValueRuntime {
fn get_epoch_and_valset(
&self,
prev_hash: CryptoHash,
) -> Result<(EpochId, usize, EpochId), Error> {
) -> Result<(EpochId, usize, EpochId), EpochError> {
if prev_hash == CryptoHash::default() {
return Ok((EpochId(prev_hash), 0, EpochId(prev_hash)));
}
let prev_block_header = self
.get_block_header(&prev_hash)?
.ok_or_else(|| Error::DBNotFoundErr(prev_hash.to_string()))?;
.ok_or_else(|| EpochError::MissingBlock(prev_hash))?;

let mut hash_to_epoch = self.hash_to_epoch.write().unwrap();
let mut hash_to_next_epoch_approvals_req =
Expand Down Expand Up @@ -326,23 +326,23 @@ impl KeyValueRuntime {
self.validators_by_valset[valset].chunk_producers[shard_id as usize].clone()
}

fn get_valset_for_epoch(&self, epoch_id: &EpochId) -> Result<usize, Error> {
fn get_valset_for_epoch(&self, epoch_id: &EpochId) -> Result<usize, EpochError> {
// conveniently here if the prev_hash is passed mistakenly instead of the epoch_hash,
// the `unwrap` will trigger
Ok(*self
.hash_to_valset
.read()
.unwrap()
.get(epoch_id)
.ok_or_else(|| Error::EpochOutOfBounds(epoch_id.clone()))? as usize
.ok_or_else(|| EpochError::EpochOutOfBounds(epoch_id.clone()))? as usize
% self.validators_by_valset.len())
}

pub fn get_chunk_only_producers_for_shard(
&self,
epoch_id: &EpochId,
shard_id: ShardId,
) -> Result<Vec<&ValidatorStake>, Error> {
) -> Result<Vec<&ValidatorStake>, EpochError> {
let valset = self.get_valset_for_epoch(epoch_id)?;
let block_producers = &self.validators_by_valset[valset].block_producers;
let chunk_producers = &self.validators_by_valset[valset].chunk_producers[shard_id as usize];
Expand Down Expand Up @@ -377,7 +377,7 @@ impl EpochManagerAdapter for KeyValueRuntime {
self.hash_to_valset.write().unwrap().contains_key(epoch_id)
}

fn num_shards(&self, _epoch_id: &EpochId) -> Result<ShardId, Error> {
fn num_shards(&self, _epoch_id: &EpochId) -> Result<ShardId, EpochError> {
Ok(self.num_shards)
}

Expand All @@ -395,7 +395,7 @@ impl EpochManagerAdapter for KeyValueRuntime {
}
}

fn get_part_owner(&self, epoch_id: &EpochId, part_id: u64) -> Result<AccountId, Error> {
fn get_part_owner(&self, epoch_id: &EpochId, part_id: u64) -> Result<AccountId, EpochError> {
let validators =
&self.get_epoch_block_producers_ordered(epoch_id, &CryptoHash::default())?;
// if we don't use data_parts and total_parts as part of the formula here, the part owner
Expand All @@ -408,19 +408,23 @@ impl EpochManagerAdapter for KeyValueRuntime {
&self,
account_id: &AccountId,
_epoch_id: &EpochId,
) -> Result<ShardId, Error> {
) -> Result<ShardId, EpochError> {
Ok(account_id_to_shard_id(account_id, self.num_shards))
}

fn shard_id_to_uid(&self, shard_id: ShardId, _epoch_id: &EpochId) -> Result<ShardUId, Error> {
fn shard_id_to_uid(
&self,
shard_id: ShardId,
_epoch_id: &EpochId,
) -> Result<ShardUId, EpochError> {
Ok(ShardUId { version: 0, shard_id: shard_id as u32 })
}

fn get_block_info(&self, _hash: &CryptoHash) -> Result<Arc<BlockInfo>, Error> {
fn get_block_info(&self, _hash: &CryptoHash) -> Result<Arc<BlockInfo>, EpochError> {
Ok(Default::default())
}

fn get_epoch_config(&self, _epoch_id: &EpochId) -> Result<EpochConfig, Error> {
fn get_epoch_config(&self, _epoch_id: &EpochId) -> Result<EpochConfig, EpochError> {
Ok(EpochConfig {
epoch_length: 10,
num_block_producer_seats: 2,
Expand All @@ -447,7 +451,7 @@ impl EpochManagerAdapter for KeyValueRuntime {
/// - block producers
/// - chunk producers
/// All the other fields have a hardcoded value or left empty.
fn get_epoch_info(&self, _epoch_id: &EpochId) -> Result<Arc<EpochInfo>, Error> {
fn get_epoch_info(&self, _epoch_id: &EpochId) -> Result<Arc<EpochInfo>, EpochError> {
let validators = self.validators.iter().map(|(_, stake)| stake.clone()).collect();
let mut validator_to_index = HashMap::new();
for (i, (account_id, _)) in self.validators.iter().enumerate() {
Expand Down Expand Up @@ -486,41 +490,44 @@ impl EpochManagerAdapter for KeyValueRuntime {
)))
}

fn get_shard_layout(&self, _epoch_id: &EpochId) -> Result<ShardLayout, Error> {
fn get_shard_layout(&self, _epoch_id: &EpochId) -> Result<ShardLayout, EpochError> {
Ok(ShardLayout::v0(self.num_shards, 0))
}

fn get_shard_config(&self, _epoch_id: &EpochId) -> Result<ShardConfig, Error> {
fn get_shard_config(&self, _epoch_id: &EpochId) -> Result<ShardConfig, EpochError> {
panic!("get_shard_config not implemented for KeyValueRuntime");
}

fn is_next_block_epoch_start(&self, parent_hash: &CryptoHash) -> Result<bool, Error> {
fn is_next_block_epoch_start(&self, parent_hash: &CryptoHash) -> Result<bool, EpochError> {
if parent_hash == &CryptoHash::default() {
return Ok(true);
}
let prev_block_header = self.get_block_header(parent_hash)?.ok_or_else(|| {
Error::Other(format!("Missing block {} when computing the epoch", parent_hash))
})?;
let prev_block_header = self
.get_block_header(parent_hash)?
.ok_or_else(|| EpochError::MissingBlock(*parent_hash))?;
let prev_prev_hash = *prev_block_header.prev_hash();
Ok(self.get_epoch_and_valset(*parent_hash)?.0
!= self.get_epoch_and_valset(prev_prev_hash)?.0)
}

fn get_epoch_id_from_prev_block(&self, parent_hash: &CryptoHash) -> Result<EpochId, Error> {
fn get_epoch_id_from_prev_block(
&self,
parent_hash: &CryptoHash,
) -> Result<EpochId, EpochError> {
Ok(self.get_epoch_and_valset(*parent_hash)?.0)
}

fn get_epoch_height_from_prev_block(
&self,
_prev_block_hash: &CryptoHash,
) -> Result<EpochHeight, Error> {
) -> Result<EpochHeight, EpochError> {
Ok(0)
}

fn get_next_epoch_id_from_prev_block(
&self,
parent_hash: &CryptoHash,
) -> Result<EpochId, Error> {
) -> Result<EpochId, EpochError> {
Ok(self.get_epoch_and_valset(*parent_hash)?.2)
}

Expand All @@ -535,11 +542,11 @@ impl EpochManagerAdapter for KeyValueRuntime {
fn get_shard_layout_from_prev_block(
&self,
_parent_hash: &CryptoHash,
) -> Result<ShardLayout, Error> {
) -> Result<ShardLayout, EpochError> {
Ok(ShardLayout::v0(self.num_shards, 0))
}

fn get_epoch_id(&self, block_hash: &CryptoHash) -> Result<EpochId, Error> {
fn get_epoch_id(&self, block_hash: &CryptoHash) -> Result<EpochId, EpochError> {
let (epoch_id, _, _) = self.get_epoch_and_valset(*block_hash)?;
Ok(epoch_id)
}
Expand All @@ -548,17 +555,17 @@ impl EpochManagerAdapter for KeyValueRuntime {
&self,
epoch_id: &EpochId,
other_epoch_id: &EpochId,
) -> Result<Ordering, Error> {
) -> Result<Ordering, EpochError> {
if epoch_id.0 == other_epoch_id.0 {
return Ok(Ordering::Equal);
}
match (self.get_valset_for_epoch(epoch_id), self.get_valset_for_epoch(other_epoch_id)) {
(Ok(index1), Ok(index2)) => Ok(index1.cmp(&index2)),
_ => Err(Error::EpochOutOfBounds(epoch_id.clone())),
_ => Err(EpochError::EpochOutOfBounds(epoch_id.clone())),
}
}

fn get_epoch_start_height(&self, block_hash: &CryptoHash) -> Result<BlockHeight, Error> {
fn get_epoch_start_height(&self, block_hash: &CryptoHash) -> Result<BlockHeight, EpochError> {
let epoch_id = self.get_epoch_id(block_hash)?;
match self.get_block_header(&epoch_id.0)? {
Some(block_header) => Ok(block_header.height()),
Expand All @@ -569,12 +576,12 @@ impl EpochManagerAdapter for KeyValueRuntime {
fn get_prev_epoch_id_from_prev_block(
&self,
prev_block_hash: &CryptoHash,
) -> Result<EpochId, Error> {
) -> Result<EpochId, EpochError> {
let mut candidate_hash = *prev_block_hash;
loop {
let header = self
.get_block_header(&candidate_hash)?
.ok_or_else(|| Error::DBNotFoundErr(candidate_hash.to_string()))?;
.ok_or_else(|| EpochError::MissingBlock(candidate_hash))?;
candidate_hash = *header.prev_hash();
if self.is_next_block_epoch_start(&candidate_hash)? {
break Ok(self.get_epoch_and_valset(candidate_hash)?.0);
Expand All @@ -593,15 +600,15 @@ impl EpochManagerAdapter for KeyValueRuntime {
&self,
epoch_id: &EpochId,
_last_known_block_hash: &CryptoHash,
) -> Result<Vec<(ValidatorStake, bool)>, Error> {
) -> Result<Vec<(ValidatorStake, bool)>, EpochError> {
let validators = self.get_block_producers(self.get_valset_for_epoch(epoch_id)?);
Ok(validators.iter().map(|x| (x.clone(), false)).collect())
}

fn get_epoch_block_approvers_ordered(
&self,
parent_hash: &CryptoHash,
) -> Result<Vec<(ApprovalStake, bool)>, Error> {
) -> Result<Vec<(ApprovalStake, bool)>, EpochError> {
let (_cur_epoch, cur_valset, next_epoch) = self.get_epoch_and_valset(*parent_hash)?;
let mut validators = self
.get_block_producers(cur_valset)
Expand All @@ -623,7 +630,10 @@ impl EpochManagerAdapter for KeyValueRuntime {
Ok(validators)
}

fn get_epoch_chunk_producers(&self, _epoch_id: &EpochId) -> Result<Vec<ValidatorStake>, Error> {
fn get_epoch_chunk_producers(
&self,
_epoch_id: &EpochId,
) -> Result<Vec<ValidatorStake>, EpochError> {
tracing::warn!("not implemented, returning a dummy value");
Ok(vec![])
}
Expand All @@ -632,7 +642,7 @@ impl EpochManagerAdapter for KeyValueRuntime {
&self,
epoch_id: &EpochId,
height: BlockHeight,
) -> Result<AccountId, Error> {
) -> Result<AccountId, EpochError> {
let validators = self.get_block_producers(self.get_valset_for_epoch(epoch_id)?);
Ok(validators[(height as usize) % validators.len()].account_id().clone())
}
Expand All @@ -642,7 +652,7 @@ impl EpochManagerAdapter for KeyValueRuntime {
epoch_id: &EpochId,
height: BlockHeight,
shard_id: ShardId,
) -> Result<AccountId, Error> {
) -> Result<AccountId, EpochError> {
let valset = self.get_valset_for_epoch(epoch_id)?;
let chunk_producers = self.get_chunk_producers(valset, shard_id);
let index = (shard_id + height + 1) as usize % chunk_producers.len();
Expand All @@ -654,7 +664,7 @@ impl EpochManagerAdapter for KeyValueRuntime {
epoch_id: &EpochId,
_last_known_block_hash: &CryptoHash,
account_id: &AccountId,
) -> Result<(ValidatorStake, bool), Error> {
) -> Result<(ValidatorStake, bool), EpochError> {
let validators = &self.validators_by_valset[self.get_valset_for_epoch(epoch_id)?];
for validator_stake in validators.block_producers.iter() {
if validator_stake.account_id() == account_id {
Expand All @@ -666,22 +676,22 @@ impl EpochManagerAdapter for KeyValueRuntime {
return Ok((validator_stake.clone(), false));
}
}
Err(Error::NotAValidator)
Err(EpochError::NotAValidator(account_id.clone(), epoch_id.clone()))
}

fn get_fisherman_by_account_id(
&self,
_epoch_id: &EpochId,
epoch_id: &EpochId,
_last_known_block_hash: &CryptoHash,
_account_id: &AccountId,
) -> Result<(ValidatorStake, bool), Error> {
Err(Error::NotAValidator)
account_id: &AccountId,
) -> Result<(ValidatorStake, bool), EpochError> {
Err(EpochError::NotAValidator(account_id.clone(), epoch_id.clone()))
}

fn get_validator_info(
&self,
_epoch_id: ValidatorInfoIdentifier,
) -> Result<EpochValidatorInfo, Error> {
) -> Result<EpochValidatorInfo, EpochError> {
Ok(EpochValidatorInfo {
current_validators: vec![],
next_validators: vec![],
Expand All @@ -694,11 +704,14 @@ impl EpochManagerAdapter for KeyValueRuntime {
})
}

fn get_epoch_minted_amount(&self, _epoch_id: &EpochId) -> Result<Balance, Error> {
fn get_epoch_minted_amount(&self, _epoch_id: &EpochId) -> Result<Balance, EpochError> {
Ok(0)
}

fn get_epoch_protocol_version(&self, _epoch_id: &EpochId) -> Result<ProtocolVersion, Error> {
fn get_epoch_protocol_version(
&self,
_epoch_id: &EpochId,
) -> Result<ProtocolVersion, EpochError> {
Ok(PROTOCOL_VERSION)
}

Expand All @@ -716,7 +729,7 @@ impl EpochManagerAdapter for KeyValueRuntime {
Arc<EpochInfo>,
Arc<EpochInfo>,
),
Error,
EpochError,
> {
Ok(Default::default())
}
Expand All @@ -732,7 +745,7 @@ impl EpochManagerAdapter for KeyValueRuntime {
_epoch_info: EpochInfo,
_next_epoch_id: &EpochId,
_next_epoch_info: EpochInfo,
) -> Result<(), Error> {
) -> Result<(), EpochError> {
Ok(())
}

Expand Down
8 changes: 8 additions & 0 deletions chain/chunks-primitives/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use std::fmt;

use near_primitives::errors::EpochError;

#[derive(Debug)]
pub enum Error {
InvalidPartMessage,
Expand Down Expand Up @@ -35,3 +37,9 @@ impl From<near_chain_primitives::Error> for Error {
Error::ChainError(err)
}
}

impl From<EpochError> for Error {
fn from(err: EpochError) -> Self {
Error::ChainError(err.into())
}
}
6 changes: 6 additions & 0 deletions chain/client-primitives/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ pub enum Error {
Other(String),
}

impl From<near_primitives::errors::EpochError> for Error {
fn from(err: near_primitives::errors::EpochError) -> Self {
Error::Chain(err.into())
}
}

#[derive(Clone, Debug, serde::Serialize, PartialEq)]
pub enum AccountOrPeerIdOrHash {
AccountId(AccountId),
Expand Down
Loading

0 comments on commit dc8f61c

Please sign in to comment.