Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: make RuntimeExt: Send #11634

Merged
merged 6 commits into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions chain/chain/src/flat_storage_creator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ use near_store::flat::{
use near_store::Store;
use near_store::{Trie, TrieDBStorage, TrieTraversalItem};
use std::collections::HashMap;
use std::rc::Rc;
use std::sync::atomic::AtomicU64;
use std::sync::Arc;
use tracing::{debug, info};
Expand Down Expand Up @@ -97,7 +96,7 @@ impl FlatStorageShardCreator {
result_sender: Sender<u64>,
) {
let trie_storage = TrieDBStorage::new(store.clone(), shard_uid);
let trie = Trie::new(Rc::new(trie_storage), state_root, None);
let trie = Trie::new(Arc::new(trie_storage), state_root, None);
let path_begin = trie.find_state_part_boundary(part_id.idx, part_id.total).unwrap();
let path_end = trie.find_state_part_boundary(part_id.idx + 1, part_id.total).unwrap();
let hex_path_begin = Self::nibbles_to_hex(&path_begin);
Expand Down Expand Up @@ -199,7 +198,7 @@ impl FlatStorageShardCreator {
let trie_storage = TrieDBStorage::new(store, shard_uid);
let state_root =
*chain_store.get_chunk_extra(&block_hash, &shard_uid)?.state_root();
let trie = Trie::new(Rc::new(trie_storage), state_root, None);
let trie = Trie::new(Arc::new(trie_storage), state_root, None);
let root_node = trie.retrieve_root_node().unwrap();
let num_state_parts =
root_node.memory_usage / STATE_PART_MEMORY_LIMIT.as_u64() + 1;
Expand Down
3 changes: 2 additions & 1 deletion core/primitives/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,7 @@ impl StateRootNode {
#[derive(
Debug,
Clone,
Copy,
Default,
Hash,
Eq,
Expand Down Expand Up @@ -1089,7 +1090,7 @@ pub enum TransactionOrReceiptId {

/// Provides information about current epoch validators.
/// Used to break dependency between epoch manager and runtime.
pub trait EpochInfoProvider {
pub trait EpochInfoProvider: Send + Sync {
/// Get current stake of a validator in the given epoch.
/// If the account is not a validator, returns `None`.
fn validator_stake(
Expand Down
9 changes: 4 additions & 5 deletions core/store/src/trie/accounting_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,18 @@ use near_primitives::errors::StorageError;
use near_primitives::hash::CryptoHash;
use near_primitives::shard_layout::ShardUId;
use std::collections::HashMap;
use std::rc::Rc;
use std::sync::Arc;

/// Switch that controls whether the `TrieAccountingCache` is enabled.
pub struct TrieAccountingCacheSwitch(Rc<std::cell::Cell<bool>>);
pub struct TrieAccountingCacheSwitch(Arc<std::sync::atomic::AtomicBool>);

impl TrieAccountingCacheSwitch {
pub fn set(&self, enabled: bool) {
self.0.set(enabled);
self.0.store(enabled, std::sync::atomic::Ordering::Relaxed);
}

pub fn enabled(&self) -> bool {
self.0.get()
self.0.load(std::sync::atomic::Ordering::Relaxed)
}
}

Expand Down Expand Up @@ -97,7 +96,7 @@ impl TrieAccountingCache {
}

pub fn enable_switch(&self) -> TrieAccountingCacheSwitch {
TrieAccountingCacheSwitch(Rc::clone(&self.enable.0))
TrieAccountingCacheSwitch(Arc::clone(&self.enable.0))
}

/// Retrieve raw bytes from the cache if it exists, otherwise retrieve it
Expand Down
9 changes: 4 additions & 5 deletions core/store/src/trie/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ use std::cell::RefCell;
use std::collections::{BTreeMap, HashSet};
use std::fmt::Write;
use std::hash::Hash;
use std::rc::Rc;
use std::str;
use std::sync::{Arc, RwLock, RwLockReadGuard};

Expand Down Expand Up @@ -336,7 +335,7 @@ impl std::fmt::Debug for TrieNode {
}

pub struct Trie {
storage: Rc<dyn TrieStorage>,
storage: Arc<dyn TrieStorage>,
memtries: Option<Arc<RwLock<MemTries>>>,
root: StateRoot,
/// If present, flat storage is used to look up keys (if asked for).
Expand Down Expand Up @@ -629,15 +628,15 @@ impl Trie {
/// By default, the accounting cache is not enabled. To enable or disable it
/// (only in this crate), call self.accounting_cache.borrow_mut().set_enabled().
pub fn new(
storage: Rc<dyn TrieStorage>,
storage: Arc<dyn TrieStorage>,
root: StateRoot,
flat_storage_chunk_view: Option<FlatStorageChunkView>,
) -> Self {
Self::new_with_memtries(storage, None, root, flat_storage_chunk_view)
}

pub fn new_with_memtries(
storage: Rc<dyn TrieStorage>,
storage: Arc<dyn TrieStorage>,
memtries: Option<Arc<RwLock<MemTries>>>,
root: StateRoot,
flat_storage_chunk_view: Option<FlatStorageChunkView>,
Expand Down Expand Up @@ -711,7 +710,7 @@ impl Trie {
) -> Self {
let PartialState::TrieValues(nodes) = partial_storage.nodes;
let recorded_storage = nodes.into_iter().map(|value| (hash(&value), value)).collect();
let storage = Rc::new(TrieMemoryPartialStorage::new(recorded_storage));
let storage = Arc::new(TrieMemoryPartialStorage::new(recorded_storage));
let mut trie = Self::new(storage, root, None);
trie.charge_gas_for_trie_node_access = !flat_storage_used;
trie
Expand Down
7 changes: 3 additions & 4 deletions core/store/src/trie/prefetching_trie_storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ use near_primitives::shard_layout::ShardUId;
use near_primitives::trie_key::TrieKey;
use near_primitives::types::{AccountId, ShardId, StateRoot};
use std::collections::HashMap;
use std::rc::Rc;
use std::sync::Arc;
use std::thread;

Expand Down Expand Up @@ -446,8 +445,8 @@ impl PrefetchApi {
})
}

pub fn make_storage(&self) -> Rc<dyn TrieStorage> {
Rc::new(TriePrefetchingStorage::new(
pub fn make_storage(&self) -> Arc<dyn TrieStorage> {
Arc::new(TriePrefetchingStorage::new(
self.store.clone(),
self.shard_uid,
self.shard_cache.clone(),
Expand Down Expand Up @@ -488,7 +487,7 @@ impl PrefetchApi {
// the clone only clones a few `Arc`s, so the performance
// hit is small.
let prefetcher_trie =
Trie::new(Rc::new(prefetcher_storage.clone()), trie_root, None);
Trie::new(Arc::new(prefetcher_storage.clone()), trie_root, None);
let storage_key = trie_key.to_vec();
metric_prefetch_sent.inc();
match prefetcher_trie.get(&storage_key) {
Expand Down
5 changes: 2 additions & 3 deletions core/store/src/trie/shard_tries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ use near_primitives::types::{
};
use rayon::prelude::{IntoParallelRefIterator, ParallelIterator};
use std::collections::HashMap;
use std::rc::Rc;
use std::sync::{Arc, Mutex, RwLock};
use tracing::info;

Expand Down Expand Up @@ -139,7 +138,7 @@ impl ShardTries {
.clone()
});

let storage = Rc::new(TrieCachingStorage::new(
let storage = Arc::new(TrieCachingStorage::new(
self.0.store.clone(),
cache,
shard_uid,
Expand All @@ -166,7 +165,7 @@ impl ShardTries {
) -> Result<Trie, StorageError> {
let (store, flat_storage_manager) = self.get_state_snapshot(block_hash)?;
let cache = self.get_trie_cache_for(shard_uid, true);
let storage = Rc::new(TrieCachingStorage::new(store, cache, shard_uid, true, None));
let storage = Arc::new(TrieCachingStorage::new(store, cache, shard_uid, true, None));
let flat_storage_chunk_view = flat_storage_manager.chunk_view(shard_uid, *block_hash);

Ok(Trie::new(storage, state_root, flat_storage_chunk_view))
Expand Down
7 changes: 3 additions & 4 deletions core/store/src/trie/state_parts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ use near_primitives::state_record::is_contract_code_key;
use near_primitives::types::{ShardId, StateRoot};
use near_vm_runner::ContractCode;
use std::collections::{HashMap, HashSet};
use std::rc::Rc;
use std::sync::Arc;

use super::TrieRefcountDeltaMap;
Expand Down Expand Up @@ -243,7 +242,7 @@ impl Trie {
.with_label_values(&[&shard_id.to_string()])
.start_timer();
let local_state_part_trie =
Trie::new(Rc::new(TrieMemoryPartialStorage::default()), StateRoot::new(), None);
Trie::new(Arc::new(TrieMemoryPartialStorage::default()), StateRoot::new(), None);
let local_state_part_nodes =
local_state_part_trie.update(all_state_part_items.into_iter())?.insertions;
let local_trie_creation_duration = local_trie_creation_timer.stop_and_record();
Expand All @@ -264,7 +263,7 @@ impl Trie {
.map(|entry| (*entry.hash(), entry.payload().to_vec().into())),
);
let final_trie =
Trie::new(Rc::new(TrieMemoryPartialStorage::new(all_nodes)), self.root, None);
Trie::new(Arc::new(TrieMemoryPartialStorage::new(all_nodes)), self.root, None);

final_trie.visit_nodes_for_state_part(part_id)?;
let final_trie_storage = final_trie.storage.as_partial_storage().unwrap();
Expand Down Expand Up @@ -434,7 +433,7 @@ impl Trie {
trie.visit_nodes_for_state_part(part_id)?;
let storage = trie.storage.as_partial_storage().unwrap();

if storage.visited_nodes.borrow().len() != num_nodes {
if storage.visited_nodes.read().expect("read visited_nodes").len() != num_nodes {
// As all nodes belonging to state part were visited, there is some
// unexpected data in downloaded state part.
return Err(StorageError::UnexpectedTrieValue);
Expand Down
9 changes: 4 additions & 5 deletions core/store/src/trie/trie_storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ use near_primitives::challenge::PartialState;
use near_primitives::hash::CryptoHash;
use near_primitives::shard_layout::ShardUId;
use near_primitives::types::ShardId;
use std::cell::RefCell;
use std::collections::{HashMap, HashSet, VecDeque};
use std::num::NonZeroUsize;
use std::sync::{Arc, Mutex};
Expand Down Expand Up @@ -282,7 +281,7 @@ impl TrieCache {
}
}

pub trait TrieStorage {
pub trait TrieStorage: Send + Sync {
/// Get bytes of a serialized `TrieNode`.
///
/// # Errors
Expand All @@ -307,7 +306,7 @@ pub trait TrieStorage {
#[derive(Default)]
pub struct TrieMemoryPartialStorage {
pub(crate) recorded_storage: HashMap<CryptoHash, Arc<[u8]>>,
pub(crate) visited_nodes: RefCell<HashSet<CryptoHash>>,
pub(crate) visited_nodes: std::sync::RwLock<HashSet<CryptoHash>>,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My gut feeling would have been Mutex to be a better match here, as it’s probably write-many-read-once. But I didn’t check all the uses and it’s likely unrelevant optimization anyway, so I’ll just leave that as an informational message if you want to think more about it.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My reasoning here was that RwLock is a direct match to RefCell with regards to its API surface (borrow => read/borrow_mut => write.) We may indeed be able to revert this eventually; we'll see.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SGTM :)

}

impl TrieStorage for TrieMemoryPartialStorage {
Expand All @@ -318,7 +317,7 @@ impl TrieStorage for TrieMemoryPartialStorage {
*hash,
));
if result.is_ok() {
self.visited_nodes.borrow_mut().insert(*hash);
self.visited_nodes.write().expect("write visited_nodes").insert(*hash);
}
result
}
Expand All @@ -334,7 +333,7 @@ impl TrieMemoryPartialStorage {
}

pub fn partial_state(&self) -> PartialState {
let touched_nodes = self.visited_nodes.borrow();
let touched_nodes = self.visited_nodes.read().expect("read visited_nodes");
let mut nodes: Vec<_> =
self.recorded_storage
.iter()
Expand Down
6 changes: 3 additions & 3 deletions core/store/src/trie/update.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use near_primitives::types::{
};
use near_vm_runner::ContractCode;
use std::collections::BTreeMap;
use std::rc::Rc;
use std::sync::Arc;

mod iterator;

Expand All @@ -20,11 +20,11 @@ mod iterator;
/// requesting and compiling contracts, as any contract code read and
/// compilation is a major bottleneck during chunk execution.
struct ContractStorage {
storage: Rc<dyn TrieStorage>,
storage: Arc<dyn TrieStorage>,
}

impl ContractStorage {
fn new(storage: Rc<dyn TrieStorage>) -> Self {
fn new(storage: Arc<dyn TrieStorage>) -> Self {
Self { storage }
}

Expand Down
9 changes: 3 additions & 6 deletions nearcore/src/metrics.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use std::rc::Rc;

use crate::NearConfig;
use actix_rt::ArbiterHandle;
use near_async::time::Duration;
use near_chain::{Block, ChainStore, ChainStoreAccess};
Expand All @@ -9,12 +8,10 @@ use near_o11y::metrics::{
try_create_int_gauge, try_create_int_gauge_vec, HistogramVec, IntCounterVec, IntGauge,
IntGaugeVec,
};

use near_primitives::{shard_layout::ShardLayout, state_record::StateRecord, trie_key};
use near_store::{ShardUId, Store, Trie, TrieDBStorage};
use once_cell::sync::Lazy;

use crate::NearConfig;
use std::sync::Arc;

pub(crate) static POSTPONED_RECEIPTS_COUNT: Lazy<IntGaugeVec> = Lazy::new(|| {
try_create_int_gauge_vec(
Expand Down Expand Up @@ -160,7 +157,7 @@ fn get_postponed_receipt_count_for_shard(
let chunk_extra = chain_store.get_chunk_extra(block.hash(), &shard_uid)?;
let state_root = chunk_extra.state_root();
let storage = TrieDBStorage::new(store.clone(), shard_uid);
let storage = Rc::new(storage);
let storage = Arc::new(storage);
let flat_storage_chunk_view = None;
let trie = Trie::new(storage, *state_root, flat_storage_chunk_view);
get_postponed_receipt_count_for_trie(trie)
Expand Down
2 changes: 1 addition & 1 deletion runtime/near-vm-runner/src/runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ pub(crate) type VMResult<T = VMOutcome> = Result<T, VMRunnerError>;
))]
pub fn run(
method_name: &str,
ext: &mut dyn External,
ext: &mut (dyn External + Send),
context: &VMContext,
wasm_config: Arc<Config>,
fees_config: Arc<RuntimeFeesConfig>,
Expand Down
16 changes: 8 additions & 8 deletions runtime/runtime/src/actions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ pub(crate) fn execute_function_call(
is_last_action: bool,
view_config: Option<ViewConfig>,
) -> Result<VMOutcome, RuntimeError> {
let account_id = runtime_ext.account_id();
let account_id = runtime_ext.account_id().clone();
tracing::debug!(target: "runtime", %account_id, "Calling the contract");
// Output data receipts are ignored if the function call is not the last action in the batch.
let output_data_receivers: Vec<_> = if is_last_action {
Expand Down Expand Up @@ -180,7 +180,7 @@ pub(crate) fn action_function_call(
action_hash: &CryptoHash,
config: &RuntimeConfig,
is_last_action: bool,
epoch_info_provider: &dyn EpochInfoProvider,
epoch_info_provider: &(dyn EpochInfoProvider),
) -> Result<(), RuntimeError> {
if account.amount().checked_add(function_call.deposit).is_none() {
return Err(StorageError::StorageInconsistentState(
Expand All @@ -193,12 +193,12 @@ pub(crate) fn action_function_call(
let mut runtime_ext = RuntimeExt::new(
state_update,
&mut receipt_manager,
account_id,
account,
action_hash,
&apply_state.epoch_id,
&apply_state.prev_block_hash,
&apply_state.block_hash,
account_id.clone(),
account.clone(),
*action_hash,
apply_state.epoch_id,
apply_state.prev_block_hash,
apply_state.block_hash,
epoch_info_provider,
apply_state.current_protocol_version,
);
Expand Down
Loading
Loading