Skip to content
This repository has been archived by the owner on Aug 21, 2024. It is now read-only.

fix: state trait does not require &mut self #1325

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 81 additions & 69 deletions crates/blockifier/src/state/cached_state.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::cell::RefCell;
use std::collections::{HashMap, HashSet};
use std::sync::{Arc, Mutex, MutexGuard};

Expand Down Expand Up @@ -28,8 +29,9 @@ pub type ContractClassMapping = HashMap<ClassHash, ContractClass>;
pub struct CachedState<S: StateReader> {
pub state: S,
// Invariant: read/write access is managed by CachedState.
cache: StateCache,
class_hash_to_class: ContractClassMapping,
// Using interior mutability to update caches during `State`'s immutable getters.
cache: RefCell<StateCache>,
class_hash_to_class: RefCell<ContractClassMapping>,
// Invariant: managed by CachedState.
global_class_hash_to_class: GlobalContractCache,
/// A map from class hash to the set of PC values that were visited in the class.
Expand All @@ -40,8 +42,8 @@ impl<S: StateReader> CachedState<S> {
pub fn new(state: S, global_class_hash_to_class: GlobalContractCache) -> Self {
Self {
state,
cache: StateCache::default(),
class_hash_to_class: HashMap::default(),
cache: RefCell::new(StateCache::default()),
class_hash_to_class: RefCell::new(HashMap::default()),
global_class_hash_to_class,
visited_pcs: HashMap::default(),
}
Expand All @@ -60,20 +62,21 @@ impl<S: StateReader> CachedState<S> {
/// root); the state updates correspond to them.
pub fn get_actual_state_changes(&mut self) -> StateResult<StateChanges> {
self.update_initial_values_of_write_only_access()?;
let cache = self.cache.borrow();

Ok(StateChanges {
storage_updates: self.cache.get_storage_updates(),
nonce_updates: self.cache.get_nonce_updates(),
storage_updates: cache.get_storage_updates(),
nonce_updates: cache.get_nonce_updates(),
// Class hash updates (deployed contracts + replace_class syscall).
class_hash_updates: self.cache.get_class_hash_updates(),
class_hash_updates: cache.get_class_hash_updates(),
// Compiled class hash updates (declare Cairo 1 contract).
compiled_class_hash_updates: self.cache.get_compiled_class_hash_updates(),
compiled_class_hash_updates: cache.get_compiled_class_hash_updates(),
})
}

/// Drains contract-class cache collected during execution and updates the global cache.
pub fn move_classes_to_global_cache(&mut self) {
let contract_class_updates: Vec<_> = self.class_hash_to_class.drain().collect();
let contract_class_updates: Vec<_> = self.class_hash_to_class.get_mut().drain().collect();
for (key, value) in contract_class_updates {
self.global_class_hash_to_class().cache_set(key, value);
}
Expand All @@ -82,24 +85,25 @@ impl<S: StateReader> CachedState<S> {
// Locks the Mutex and unwraps the MutexGuard, thus exposing the internal cache
// store. The Guard will panic only if the Mutex panics during the lock operation, but
// this shouldn't happen in our flow.
// Note: `&mut` is used since the LRU cache updates internal counters on reads.
pub fn global_class_hash_to_class(&mut self) -> LockedContractClassCache<'_> {
pub fn global_class_hash_to_class(&self) -> LockedContractClassCache<'_> {
self.global_class_hash_to_class.lock()
}

pub fn update_cache(&mut self, cache_updates: StateCache) {
self.cache.nonce_writes.extend(cache_updates.nonce_writes);
self.cache.class_hash_writes.extend(cache_updates.class_hash_writes);
self.cache.storage_writes.extend(cache_updates.storage_writes);
self.cache.compiled_class_hash_writes.extend(cache_updates.compiled_class_hash_writes);
let mut cache = self.cache.borrow_mut();

cache.nonce_writes.extend(cache_updates.nonce_writes);
cache.class_hash_writes.extend(cache_updates.class_hash_writes);
cache.storage_writes.extend(cache_updates.storage_writes);
cache.compiled_class_hash_writes.extend(cache_updates.compiled_class_hash_writes);
}

pub fn update_contract_class_caches(
&mut self,
local_contract_cache_updates: ContractClassMapping,
global_contract_cache: GlobalContractCache,
) {
self.class_hash_to_class.extend(local_contract_cache_updates);
self.class_hash_to_class.get_mut().extend(local_contract_cache_updates);
self.global_class_hash_to_class = global_contract_cache;
}

Expand All @@ -115,31 +119,33 @@ impl<S: StateReader> CachedState<S> {
/// Same for class hash and nonce writes.
// TODO(Noa, 30/07/23): Consider adding DB getters in bulk (via a DB read transaction).
fn update_initial_values_of_write_only_access(&mut self) -> StateResult<()> {
let cache = &mut *self.cache.borrow_mut();

// Eliminate storage writes that are identical to the initial value (no change). Assumes
// that `set_storage_at` does not affect the state field.
for contract_storage_key in self.cache.storage_writes.keys() {
if !self.cache.storage_initial_values.contains_key(contract_storage_key) {
for contract_storage_key in cache.storage_writes.keys() {
if !cache.storage_initial_values.contains_key(contract_storage_key) {
// First access to this cell was write; cache initial value.
self.cache.storage_initial_values.insert(
cache.storage_initial_values.insert(
*contract_storage_key,
self.state.get_storage_at(contract_storage_key.0, contract_storage_key.1)?,
);
}
}

for contract_address in self.cache.class_hash_writes.keys() {
if !self.cache.class_hash_initial_values.contains_key(contract_address) {
for contract_address in cache.class_hash_writes.keys() {
if !cache.class_hash_initial_values.contains_key(contract_address) {
// First access to this cell was write; cache initial value.
self.cache
cache
.class_hash_initial_values
.insert(*contract_address, self.state.get_class_hash_at(*contract_address)?);
}
}

for contract_address in self.cache.nonce_writes.keys() {
if !self.cache.nonce_initial_values.contains_key(contract_address) {
for contract_address in cache.nonce_writes.keys() {
if !cache.nonce_initial_values.contains_key(contract_address) {
// First access to this cell was write; cache initial value.
self.cache
cache
.nonce_initial_values
.insert(*contract_address, self.state.get_nonce_at(*contract_address)?);
}
Expand All @@ -161,83 +167,89 @@ impl<S: StateReader> From<S> for CachedState<S> {

impl<S: StateReader> StateReader for CachedState<S> {
fn get_storage_at(
&mut self,
&self,
contract_address: ContractAddress,
key: StorageKey,
) -> StateResult<StarkFelt> {
if self.cache.get_storage_at(contract_address, key).is_none() {
let mut cache = self.cache.borrow_mut();

if cache.get_storage_at(contract_address, key).is_none() {
let storage_value = self.state.get_storage_at(contract_address, key)?;
self.cache.set_storage_initial_value(contract_address, key, storage_value);
cache.set_storage_initial_value(contract_address, key, storage_value);
}

let value = self.cache.get_storage_at(contract_address, key).unwrap_or_else(|| {
let value = cache.get_storage_at(contract_address, key).unwrap_or_else(|| {
panic!("Cannot retrieve '{contract_address:?}' and '{key:?}' from the cache.")
});
Ok(*value)
}

fn get_nonce_at(&mut self, contract_address: ContractAddress) -> StateResult<Nonce> {
if self.cache.get_nonce_at(contract_address).is_none() {
fn get_nonce_at(&self, contract_address: ContractAddress) -> StateResult<Nonce> {
let mut cache = self.cache.borrow_mut();

if cache.get_nonce_at(contract_address).is_none() {
let nonce = self.state.get_nonce_at(contract_address)?;
self.cache.set_nonce_initial_value(contract_address, nonce);
cache.set_nonce_initial_value(contract_address, nonce);
}

let nonce = self
.cache
let nonce = cache
.get_nonce_at(contract_address)
.unwrap_or_else(|| panic!("Cannot retrieve '{contract_address:?}' from the cache."));

Ok(*nonce)
}

fn get_class_hash_at(&mut self, contract_address: ContractAddress) -> StateResult<ClassHash> {
if self.cache.get_class_hash_at(contract_address).is_none() {
fn get_class_hash_at(&self, contract_address: ContractAddress) -> StateResult<ClassHash> {
let mut cache = self.cache.borrow_mut();

if cache.get_class_hash_at(contract_address).is_none() {
let class_hash = self.state.get_class_hash_at(contract_address)?;
self.cache.set_class_hash_initial_value(contract_address, class_hash);
cache.set_class_hash_initial_value(contract_address, class_hash);
}

let class_hash = self
.cache
let class_hash = cache
.get_class_hash_at(contract_address)
.unwrap_or_else(|| panic!("Cannot retrieve '{contract_address:?}' from the cache."));
Ok(*class_hash)
}

#[allow(clippy::map_entry)]
// Clippy solution don't work because it required two mutable ref to self
// Could probably be solved with interior mutability
fn get_compiled_contract_class(&mut self, class_hash: ClassHash) -> StateResult<ContractClass> {
if !self.class_hash_to_class.contains_key(&class_hash) {
fn get_compiled_contract_class(&self, class_hash: ClassHash) -> StateResult<ContractClass> {
let class_hash_to_class = &mut *self.class_hash_to_class.borrow_mut();

if let std::collections::hash_map::Entry::Vacant(vacant_entry) =
class_hash_to_class.entry(class_hash)
{
let contract_class = self.global_class_hash_to_class().cache_get(&class_hash).cloned();

match contract_class {
Some(contract_class_from_global_cache) => {
self.class_hash_to_class.insert(class_hash, contract_class_from_global_cache);
vacant_entry.insert(contract_class_from_global_cache);
}
None => {
let contract_class_from_db =
self.state.get_compiled_contract_class(class_hash)?;
self.class_hash_to_class.insert(class_hash, contract_class_from_db);
vacant_entry.insert(contract_class_from_db);
}
}
}

let contract_class = self
.class_hash_to_class
let contract_class = class_hash_to_class
.get(&class_hash)
.cloned()
.expect("The class hash must appear in the cache.");

Ok(contract_class)
}

fn get_compiled_class_hash(&mut self, class_hash: ClassHash) -> StateResult<CompiledClassHash> {
if self.cache.get_compiled_class_hash(class_hash).is_none() {
fn get_compiled_class_hash(&self, class_hash: ClassHash) -> StateResult<CompiledClassHash> {
let mut cache = self.cache.borrow_mut();

if cache.get_compiled_class_hash(class_hash).is_none() {
let compiled_class_hash = self.state.get_compiled_class_hash(class_hash)?;
self.cache.set_compiled_class_hash_initial_value(class_hash, compiled_class_hash);
cache.set_compiled_class_hash_initial_value(class_hash, compiled_class_hash);
}

let compiled_class_hash = self
.cache
let compiled_class_hash = cache
.get_compiled_class_hash(class_hash)
.unwrap_or_else(|| panic!("Cannot retrieve '{class_hash:?}' from the cache."));
Ok(*compiled_class_hash)
Expand All @@ -251,7 +263,7 @@ impl<S: StateReader> State for CachedState<S> {
key: StorageKey,
value: StarkFelt,
) -> StateResult<()> {
self.cache.set_storage_value(contract_address, key, value);
self.cache.get_mut().set_storage_value(contract_address, key, value);

Ok(())
}
Expand All @@ -264,7 +276,7 @@ impl<S: StateReader> State for CachedState<S> {
usize::try_from(current_nonce.0)?.try_into().expect("Failed to convert usize to u64.");
let next_nonce_val = 1_u64 + current_nonce_as_u64;
let next_nonce = Nonce(StarkFelt::from(next_nonce_val));
self.cache.set_nonce_value(contract_address, next_nonce);
self.cache.get_mut().set_nonce_value(contract_address, next_nonce);

Ok(())
}
Expand All @@ -278,7 +290,7 @@ impl<S: StateReader> State for CachedState<S> {
return Err(StateError::OutOfRangeContractAddress);
}

self.cache.set_class_hash_write(contract_address, class_hash);
self.cache.get_mut().set_class_hash_write(contract_address, class_hash);
Ok(())
}

Expand All @@ -287,7 +299,7 @@ impl<S: StateReader> State for CachedState<S> {
class_hash: ClassHash,
contract_class: ContractClass,
) -> StateResult<()> {
self.class_hash_to_class.insert(class_hash, contract_class);
self.class_hash_to_class.get_mut().insert(class_hash, contract_class);
Ok(())
}

Expand All @@ -296,7 +308,7 @@ impl<S: StateReader> State for CachedState<S> {
class_hash: ClassHash,
compiled_class_hash: CompiledClassHash,
) -> StateResult<()> {
self.cache.set_compiled_class_hash_write(class_hash, compiled_class_hash);
self.cache.get_mut().set_compiled_class_hash_write(class_hash, compiled_class_hash);
Ok(())
}

Expand All @@ -307,7 +319,7 @@ impl<S: StateReader> State for CachedState<S> {
self.update_initial_values_of_write_only_access()
.unwrap_or_else(|_| panic!("Cannot convert stateDiff to CommitmentStateDiff."));

let state_cache = &self.cache;
let state_cache = self.cache.borrow();
let class_hash_updates = state_cache.get_class_hash_updates();
let storage_diffs = state_cache.get_storage_updates();
let nonces = state_cache.get_nonce_updates();
Expand Down Expand Up @@ -501,26 +513,26 @@ impl<'a, S: State + ?Sized> MutRefState<'a, S> {
/// Proxies inner object to expose `State` functionality.
impl<'a, S: State + ?Sized> StateReader for MutRefState<'a, S> {
fn get_storage_at(
&mut self,
&self,
contract_address: ContractAddress,
key: StorageKey,
) -> StateResult<StarkFelt> {
self.0.get_storage_at(contract_address, key)
}

fn get_nonce_at(&mut self, contract_address: ContractAddress) -> StateResult<Nonce> {
fn get_nonce_at(&self, contract_address: ContractAddress) -> StateResult<Nonce> {
self.0.get_nonce_at(contract_address)
}

fn get_class_hash_at(&mut self, contract_address: ContractAddress) -> StateResult<ClassHash> {
fn get_class_hash_at(&self, contract_address: ContractAddress) -> StateResult<ClassHash> {
self.0.get_class_hash_at(contract_address)
}

fn get_compiled_contract_class(&mut self, class_hash: ClassHash) -> StateResult<ContractClass> {
fn get_compiled_contract_class(&self, class_hash: ClassHash) -> StateResult<ContractClass> {
self.0.get_compiled_contract_class(class_hash)
}

fn get_compiled_class_hash(&mut self, class_hash: ClassHash) -> StateResult<CompiledClassHash> {
fn get_compiled_class_hash(&self, class_hash: ClassHash) -> StateResult<CompiledClassHash> {
self.0.get_compiled_class_hash(class_hash)
}
}
Expand Down Expand Up @@ -591,8 +603,8 @@ impl<'a, S: StateReader> TransactionalState<'a, S> {
..
} = self;
StagedTransactionalState {
cache,
class_hash_to_class,
cache: cache.into_inner(),
class_hash_to_class: class_hash_to_class.into_inner(),
global_class_hash_to_class,
tx_executed_class_hashes,
tx_visited_storage_entries,
Expand All @@ -604,10 +616,10 @@ impl<'a, S: StateReader> TransactionalState<'a, S> {
/// Commits changes in the child (wrapping) state to its parent.
pub fn commit(self) {
let state = self.state.0;
let child_cache = self.cache;
let child_cache = self.cache.into_inner();
state.update_cache(child_cache);
state.update_contract_class_caches(
self.class_hash_to_class,
self.class_hash_to_class.into_inner(),
self.global_class_hash_to_class,
);
state.update_visited_pcs_cache(&self.visited_pcs);
Expand Down Expand Up @@ -810,7 +822,7 @@ pub const GLOBAL_CONTRACT_CACHE_SIZE_FOR_TEST: usize = 100;
impl GlobalContractCache {
/// Locks the cache for atomic access. Although conceptually shared, writing to this cache is
/// only possible for one writer at a time.
pub fn lock(&mut self) -> LockedContractClassCache<'_> {
pub fn lock(&self) -> LockedContractClassCache<'_> {
self.0.lock().expect("Global contract cache is poisoned.")
}

Expand Down
Loading
Loading