Skip to content
This repository has been archived by the owner on Aug 21, 2024. It is now read-only.

Commit

Permalink
fix: State trait does not require &mut self
Browse files Browse the repository at this point in the history
fix: CachedState use interior mutability
  • Loading branch information
tdelabro committed Jan 10, 2024
1 parent d6dd88b commit 014b4a8
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 103 deletions.
148 changes: 81 additions & 67 deletions crates/blockifier/src/state/cached_state.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::cell::RefCell;
use std::collections::{HashMap, HashSet};
use std::sync::{Arc, Mutex, MutexGuard};

Expand Down Expand Up @@ -28,8 +29,8 @@ pub type ContractClassMapping = HashMap<ClassHash, ContractClass>;
pub struct CachedState<S: StateReader> {
pub state: S,
// Invariant: read/write access is managed by CachedState.
cache: StateCache,
class_hash_to_class: ContractClassMapping,
cache: RefCell<StateCache>,
class_hash_to_class: RefCell<ContractClassMapping>,
// Invariant: managed by CachedState.
global_class_hash_to_class: GlobalContractCache,
}
Expand All @@ -38,8 +39,8 @@ impl<S: StateReader> CachedState<S> {
pub fn new(state: S, global_class_hash_to_class: GlobalContractCache) -> Self {
Self {
state,
cache: StateCache::default(),
class_hash_to_class: HashMap::default(),
cache: RefCell::new(StateCache::default()),
class_hash_to_class: RefCell::new(HashMap::default()),
global_class_hash_to_class,
}
}
Expand All @@ -61,22 +62,23 @@ impl<S: StateReader> CachedState<S> {
sender_address: Option<ContractAddress>,
) -> StateResult<StateChanges> {
self.update_initial_values_of_write_only_access()?;
let cache = self.cache.borrow();

// Storage Update.
let storage_updates = &mut self.cache.get_storage_updates();
let storage_updates = &mut cache.get_storage_updates();
let mut modified_contracts: HashSet<ContractAddress> =
storage_updates.keys().map(|address_key_pair| address_key_pair.0).collect();

// Class hash Update (deployed contracts + replace_class syscall).
let class_hash_updates = &self.cache.get_class_hash_updates();
let class_hash_updates = &cache.get_class_hash_updates();
modified_contracts.extend(class_hash_updates.keys());

// Nonce updates.
let nonce_updates = &self.cache.get_nonce_updates();
let nonce_updates = &cache.get_nonce_updates();
modified_contracts.extend(nonce_updates.keys());

// Compiled class hash updates (declare Cairo 1 contract).
let compiled_class_hash_updates = &self.cache.get_compiled_class_hash_updates();
let compiled_class_hash_updates = &cache.get_compiled_class_hash_updates();

// For account transactions, we need to compute the transaction fee before we can execute
// the fee transfer, and the fee should cover the state changes that happen in the
Expand Down Expand Up @@ -105,7 +107,7 @@ impl<S: StateReader> CachedState<S> {

/// Drains contract-class cache collected during execution and updates the global cache.
pub fn move_classes_to_global_cache(&mut self) {
let contract_class_updates: Vec<_> = self.class_hash_to_class.drain().collect();
let contract_class_updates: Vec<_> = self.class_hash_to_class.get_mut().drain().collect();
for (key, value) in contract_class_updates {
self.global_class_hash_to_class().cache_set(key, value);
}
Expand All @@ -114,24 +116,25 @@ impl<S: StateReader> CachedState<S> {
// Locks the Mutex and unwraps the MutexGuard, thus exposing the internal cache
// store. The Guard will panic only if the Mutex panics during the lock operation, but
// this shouldn't happen in our flow.
// Note: `&mut` is used since the LRU cache updates internal counters on reads.
pub fn global_class_hash_to_class(&mut self) -> LockedContractClassCache<'_> {
pub fn global_class_hash_to_class(&self) -> LockedContractClassCache<'_> {
self.global_class_hash_to_class.lock()
}

pub fn update_cache(&mut self, cache_updates: StateCache) {
self.cache.nonce_writes.extend(cache_updates.nonce_writes);
self.cache.class_hash_writes.extend(cache_updates.class_hash_writes);
self.cache.storage_writes.extend(cache_updates.storage_writes);
self.cache.compiled_class_hash_writes.extend(cache_updates.compiled_class_hash_writes);
let mut cache = self.cache.borrow_mut();

cache.nonce_writes.extend(cache_updates.nonce_writes);
cache.class_hash_writes.extend(cache_updates.class_hash_writes);
cache.storage_writes.extend(cache_updates.storage_writes);
cache.compiled_class_hash_writes.extend(cache_updates.compiled_class_hash_writes);
}

pub fn update_contract_class_caches(
&mut self,
local_contract_cache_updates: ContractClassMapping,
global_contract_cache: GlobalContractCache,
) {
self.class_hash_to_class.extend(local_contract_cache_updates);
self.class_hash_to_class.get_mut().extend(local_contract_cache_updates);
self.global_class_hash_to_class = global_contract_cache;
}

Expand All @@ -141,31 +144,33 @@ impl<S: StateReader> CachedState<S> {
/// Same for class hash and nonce writes.
// TODO(Noa, 30/07/23): Consider adding DB getters in bulk (via a DB read transaction).
fn update_initial_values_of_write_only_access(&mut self) -> StateResult<()> {
let cache = &mut *self.cache.borrow_mut();

// Eliminate storage writes that are identical to the initial value (no change). Assumes
// that `set_storage_at` does not affect the state field.
for contract_storage_key in self.cache.storage_writes.keys() {
if !self.cache.storage_initial_values.contains_key(contract_storage_key) {
for contract_storage_key in cache.storage_writes.keys() {
if !cache.storage_initial_values.contains_key(contract_storage_key) {
// First access to this cell was write; cache initial value.
self.cache.storage_initial_values.insert(
cache.storage_initial_values.insert(
*contract_storage_key,
self.state.get_storage_at(contract_storage_key.0, contract_storage_key.1)?,
);
}
}

for contract_address in self.cache.class_hash_writes.keys() {
if !self.cache.class_hash_initial_values.contains_key(contract_address) {
for contract_address in cache.class_hash_writes.keys() {
if !cache.class_hash_initial_values.contains_key(contract_address) {
// First access to this cell was write; cache initial value.
self.cache
cache
.class_hash_initial_values
.insert(*contract_address, self.state.get_class_hash_at(*contract_address)?);
}
}

for contract_address in self.cache.nonce_writes.keys() {
if !self.cache.nonce_initial_values.contains_key(contract_address) {
for contract_address in cache.nonce_writes.keys() {
if !cache.nonce_initial_values.contains_key(contract_address) {
// First access to this cell was write; cache initial value.
self.cache
cache
.nonce_initial_values
.insert(*contract_address, self.state.get_nonce_at(*contract_address)?);
}
Expand All @@ -183,42 +188,47 @@ impl<S: StateReader> From<S> for CachedState<S> {

impl<S: StateReader> StateReader for CachedState<S> {
fn get_storage_at(
&mut self,
&self,
contract_address: ContractAddress,
key: StorageKey,
) -> StateResult<StarkFelt> {
if self.cache.get_storage_at(contract_address, key).is_none() {
let mut cache = self.cache.borrow_mut();

if cache.get_storage_at(contract_address, key).is_none() {
let storage_value = self.state.get_storage_at(contract_address, key)?;
self.cache.set_storage_initial_value(contract_address, key, storage_value);
cache.set_storage_initial_value(contract_address, key, storage_value);
}

let value = self.cache.get_storage_at(contract_address, key).unwrap_or_else(|| {
let value = cache.get_storage_at(contract_address, key).unwrap_or_else(|| {
panic!("Cannot retrieve '{contract_address:?}' and '{key:?}' from the cache.")
});
Ok(*value)
}

fn get_nonce_at(&mut self, contract_address: ContractAddress) -> StateResult<Nonce> {
if self.cache.get_nonce_at(contract_address).is_none() {
fn get_nonce_at(&self, contract_address: ContractAddress) -> StateResult<Nonce> {
let mut cache = self.cache.borrow_mut();

if cache.get_nonce_at(contract_address).is_none() {
let nonce = self.state.get_nonce_at(contract_address)?;
self.cache.set_nonce_initial_value(contract_address, nonce);
cache.set_nonce_initial_value(contract_address, nonce);
}

let nonce = self
.cache
let nonce = cache
.get_nonce_at(contract_address)
.unwrap_or_else(|| panic!("Cannot retrieve '{contract_address:?}' from the cache."));

Ok(*nonce)
}

fn get_class_hash_at(&mut self, contract_address: ContractAddress) -> StateResult<ClassHash> {
if self.cache.get_class_hash_at(contract_address).is_none() {
fn get_class_hash_at(&self, contract_address: ContractAddress) -> StateResult<ClassHash> {
let mut cache = self.cache.borrow_mut();

if cache.get_class_hash_at(contract_address).is_none() {
let class_hash = self.state.get_class_hash_at(contract_address)?;
self.cache.set_class_hash_initial_value(contract_address, class_hash);
cache.set_class_hash_initial_value(contract_address, class_hash);
}

let class_hash = self
.cache
let class_hash = cache
.get_class_hash_at(contract_address)
.unwrap_or_else(|| panic!("Cannot retrieve '{contract_address:?}' from the cache."));
Ok(*class_hash)
Expand All @@ -227,39 +237,41 @@ impl<S: StateReader> StateReader for CachedState<S> {
#[allow(clippy::map_entry)]
// Clippy solution don't work because it required two mutable ref to self
// Could probably be solved with interior mutability
fn get_compiled_contract_class(&mut self, class_hash: ClassHash) -> StateResult<ContractClass> {
if !self.class_hash_to_class.contains_key(&class_hash) {
fn get_compiled_contract_class(&self, class_hash: ClassHash) -> StateResult<ContractClass> {
let class_hash_to_class = &mut *self.class_hash_to_class.borrow_mut();

if !class_hash_to_class.contains_key(&class_hash) {
let contract_class = self.global_class_hash_to_class().cache_get(&class_hash).cloned();

match contract_class {
Some(contract_class_from_global_cache) => {
self.class_hash_to_class.insert(class_hash, contract_class_from_global_cache);
class_hash_to_class.insert(class_hash, contract_class_from_global_cache);
}
None => {
let contract_class_from_db =
self.state.get_compiled_contract_class(class_hash)?;
self.class_hash_to_class.insert(class_hash, contract_class_from_db);
class_hash_to_class.insert(class_hash, contract_class_from_db);
}
}
}

let contract_class = self
.class_hash_to_class
let contract_class = class_hash_to_class
.get(&class_hash)
.cloned()
.expect("The class hash must appear in the cache.");

Ok(contract_class)
}

fn get_compiled_class_hash(&mut self, class_hash: ClassHash) -> StateResult<CompiledClassHash> {
if self.cache.get_compiled_class_hash(class_hash).is_none() {
fn get_compiled_class_hash(&self, class_hash: ClassHash) -> StateResult<CompiledClassHash> {
let mut cache = self.cache.borrow_mut();

if cache.get_compiled_class_hash(class_hash).is_none() {
let compiled_class_hash = self.state.get_compiled_class_hash(class_hash)?;
self.cache.set_compiled_class_hash_initial_value(class_hash, compiled_class_hash);
cache.set_compiled_class_hash_initial_value(class_hash, compiled_class_hash);
}

let compiled_class_hash = self
.cache
let compiled_class_hash = cache
.get_compiled_class_hash(class_hash)
.unwrap_or_else(|| panic!("Cannot retrieve '{class_hash:?}' from the cache."));
Ok(*compiled_class_hash)
Expand All @@ -273,7 +285,7 @@ impl<S: StateReader> State for CachedState<S> {
key: StorageKey,
value: StarkFelt,
) -> StateResult<()> {
self.cache.set_storage_value(contract_address, key, value);
self.cache.get_mut().set_storage_value(contract_address, key, value);

Ok(())
}
Expand All @@ -283,7 +295,7 @@ impl<S: StateReader> State for CachedState<S> {
let current_nonce_as_u64 = usize::try_from(current_nonce.0)? as u64;
let next_nonce_val = 1_u64 + current_nonce_as_u64;
let next_nonce = Nonce(StarkFelt::from(next_nonce_val));
self.cache.set_nonce_value(contract_address, next_nonce);
self.cache.get_mut().set_nonce_value(contract_address, next_nonce);

Ok(())
}
Expand All @@ -297,7 +309,7 @@ impl<S: StateReader> State for CachedState<S> {
return Err(StateError::OutOfRangeContractAddress);
}

self.cache.set_class_hash_write(contract_address, class_hash);
self.cache.get_mut().set_class_hash_write(contract_address, class_hash);
Ok(())
}

Expand All @@ -306,7 +318,7 @@ impl<S: StateReader> State for CachedState<S> {
class_hash: ClassHash,
contract_class: ContractClass,
) -> StateResult<()> {
self.class_hash_to_class.insert(class_hash, contract_class);
self.class_hash_to_class.get_mut().insert(class_hash, contract_class);
Ok(())
}

Expand All @@ -315,7 +327,7 @@ impl<S: StateReader> State for CachedState<S> {
class_hash: ClassHash,
compiled_class_hash: CompiledClassHash,
) -> StateResult<()> {
self.cache.set_compiled_class_hash_write(class_hash, compiled_class_hash);
self.cache.get_mut().set_compiled_class_hash_write(class_hash, compiled_class_hash);
Ok(())
}

Expand All @@ -326,7 +338,7 @@ impl<S: StateReader> State for CachedState<S> {
self.update_initial_values_of_write_only_access()
.unwrap_or_else(|_| panic!("Cannot convert stateDiff to CommitmentStateDiff."));

let state_cache = &self.cache;
let state_cache = self.cache.borrow();
let class_hash_updates = state_cache.get_class_hash_updates();
let storage_diffs = state_cache.get_storage_updates();
let nonces =
Expand Down Expand Up @@ -514,26 +526,26 @@ impl<'a, S: State + ?Sized> MutRefState<'a, S> {
/// Proxies inner object to expose `State` functionality.
impl<'a, S: State + ?Sized> StateReader for MutRefState<'a, S> {
fn get_storage_at(
&mut self,
&self,
contract_address: ContractAddress,
key: StorageKey,
) -> StateResult<StarkFelt> {
self.0.get_storage_at(contract_address, key)
}

fn get_nonce_at(&mut self, contract_address: ContractAddress) -> StateResult<Nonce> {
fn get_nonce_at(&self, contract_address: ContractAddress) -> StateResult<Nonce> {
self.0.get_nonce_at(contract_address)
}

fn get_class_hash_at(&mut self, contract_address: ContractAddress) -> StateResult<ClassHash> {
fn get_class_hash_at(&self, contract_address: ContractAddress) -> StateResult<ClassHash> {
self.0.get_class_hash_at(contract_address)
}

fn get_compiled_contract_class(&mut self, class_hash: ClassHash) -> StateResult<ContractClass> {
fn get_compiled_contract_class(&self, class_hash: ClassHash) -> StateResult<ContractClass> {
self.0.get_compiled_contract_class(class_hash)
}

fn get_compiled_class_hash(&mut self, class_hash: ClassHash) -> StateResult<CompiledClassHash> {
fn get_compiled_class_hash(&self, class_hash: ClassHash) -> StateResult<CompiledClassHash> {
self.0.get_compiled_class_hash(class_hash)
}
}
Expand Down Expand Up @@ -594,8 +606,8 @@ impl<'a, S: StateReader> TransactionalState<'a, S> {
let TransactionalState { cache, class_hash_to_class, global_class_hash_to_class, .. } =
self;
StagedTransactionalState {
cache,
class_hash_to_class,
cache: cache.into_inner(),
class_hash_to_class: class_hash_to_class.into_inner(),
global_class_hash_to_class,
tx_executed_class_hashes,
tx_visited_storage_entries,
Expand All @@ -605,10 +617,12 @@ impl<'a, S: StateReader> TransactionalState<'a, S> {
/// Commits changes in the child (wrapping) state to its parent.
pub fn commit(self) {
let state = self.state.0;
let child_cache = self.cache;
let child_cache = self.cache.into_inner();
state.update_cache(child_cache);
state
.update_contract_class_caches(self.class_hash_to_class, self.global_class_hash_to_class)
state.update_contract_class_caches(
self.class_hash_to_class.into_inner(),
self.global_class_hash_to_class,
)
}

/// Drops `self`.
Expand Down Expand Up @@ -701,7 +715,7 @@ impl GlobalContractCache {

/// Locks the cache for atomic access. Although conceptually shared, writing to this cache is
/// only possible for one writer at a time.
pub fn lock(&mut self) -> LockedContractClassCache<'_> {
pub fn lock(&self) -> LockedContractClassCache<'_> {
self.0.lock().expect("Global contract cache is poisoned.")
}

Expand Down
Loading

0 comments on commit 014b4a8

Please sign in to comment.