From 7115e388a8e36cd9c32a3400eee948366779b5db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20Delabrouille?= Date: Wed, 10 Jan 2024 23:19:54 +0100 Subject: [PATCH 1/3] fix: state trait does not require &mut self fix: sachedState use interior mutability --- crates/blockifier/src/state/cached_state.rs | 144 ++++++++++-------- .../blockifier/src/state/cached_state_test.rs | 26 ++-- crates/blockifier/src/state/state_api.rs | 10 +- .../src/test_utils/dict_state_reader.rs | 10 +- .../src/state_readers/papyrus_state.rs | 13 +- .../src/state_readers/py_state_reader.rs | 10 +- 6 files changed, 111 insertions(+), 102 deletions(-) diff --git a/crates/blockifier/src/state/cached_state.rs b/crates/blockifier/src/state/cached_state.rs index 718101d1eb..b9f7dd8012 100644 --- a/crates/blockifier/src/state/cached_state.rs +++ b/crates/blockifier/src/state/cached_state.rs @@ -1,3 +1,4 @@ +use std::cell::RefCell; use std::collections::{HashMap, HashSet}; use std::sync::{Arc, Mutex, MutexGuard}; @@ -28,8 +29,8 @@ pub type ContractClassMapping = HashMap; pub struct CachedState { pub state: S, // Invariant: read/write access is managed by CachedState. - cache: StateCache, - class_hash_to_class: ContractClassMapping, + cache: RefCell, + class_hash_to_class: RefCell, // Invariant: managed by CachedState. global_class_hash_to_class: GlobalContractCache, /// A map from class hash to the set of PC values that were visited in the class. @@ -40,8 +41,8 @@ impl CachedState { pub fn new(state: S, global_class_hash_to_class: GlobalContractCache) -> Self { Self { state, - cache: StateCache::default(), - class_hash_to_class: HashMap::default(), + cache: RefCell::new(StateCache::default()), + class_hash_to_class: RefCell::new(HashMap::default()), global_class_hash_to_class, visited_pcs: HashMap::default(), } @@ -60,20 +61,21 @@ impl CachedState { /// root); the state updates correspond to them. pub fn get_actual_state_changes(&mut self) -> StateResult { self.update_initial_values_of_write_only_access()?; + let cache = self.cache.borrow(); Ok(StateChanges { - storage_updates: self.cache.get_storage_updates(), - nonce_updates: self.cache.get_nonce_updates(), + storage_updates: cache.get_storage_updates(), + nonce_updates: cache.get_nonce_updates(), // Class hash updates (deployed contracts + replace_class syscall). - class_hash_updates: self.cache.get_class_hash_updates(), + class_hash_updates: cache.get_class_hash_updates(), // Compiled class hash updates (declare Cairo 1 contract). - compiled_class_hash_updates: self.cache.get_compiled_class_hash_updates(), + compiled_class_hash_updates: cache.get_compiled_class_hash_updates(), }) } /// Drains contract-class cache collected during execution and updates the global cache. pub fn move_classes_to_global_cache(&mut self) { - let contract_class_updates: Vec<_> = self.class_hash_to_class.drain().collect(); + let contract_class_updates: Vec<_> = self.class_hash_to_class.get_mut().drain().collect(); for (key, value) in contract_class_updates { self.global_class_hash_to_class().cache_set(key, value); } @@ -82,16 +84,17 @@ impl CachedState { // Locks the Mutex and unwraps the MutexGuard, thus exposing the internal cache // store. The Guard will panic only if the Mutex panics during the lock operation, but // this shouldn't happen in our flow. - // Note: `&mut` is used since the LRU cache updates internal counters on reads. - pub fn global_class_hash_to_class(&mut self) -> LockedContractClassCache<'_> { + pub fn global_class_hash_to_class(&self) -> LockedContractClassCache<'_> { self.global_class_hash_to_class.lock() } pub fn update_cache(&mut self, cache_updates: StateCache) { - self.cache.nonce_writes.extend(cache_updates.nonce_writes); - self.cache.class_hash_writes.extend(cache_updates.class_hash_writes); - self.cache.storage_writes.extend(cache_updates.storage_writes); - self.cache.compiled_class_hash_writes.extend(cache_updates.compiled_class_hash_writes); + let mut cache = self.cache.borrow_mut(); + + cache.nonce_writes.extend(cache_updates.nonce_writes); + cache.class_hash_writes.extend(cache_updates.class_hash_writes); + cache.storage_writes.extend(cache_updates.storage_writes); + cache.compiled_class_hash_writes.extend(cache_updates.compiled_class_hash_writes); } pub fn update_contract_class_caches( @@ -99,7 +102,7 @@ impl CachedState { local_contract_cache_updates: ContractClassMapping, global_contract_cache: GlobalContractCache, ) { - self.class_hash_to_class.extend(local_contract_cache_updates); + self.class_hash_to_class.get_mut().extend(local_contract_cache_updates); self.global_class_hash_to_class = global_contract_cache; } @@ -115,31 +118,33 @@ impl CachedState { /// Same for class hash and nonce writes. // TODO(Noa, 30/07/23): Consider adding DB getters in bulk (via a DB read transaction). fn update_initial_values_of_write_only_access(&mut self) -> StateResult<()> { + let cache = &mut *self.cache.borrow_mut(); + // Eliminate storage writes that are identical to the initial value (no change). Assumes // that `set_storage_at` does not affect the state field. - for contract_storage_key in self.cache.storage_writes.keys() { - if !self.cache.storage_initial_values.contains_key(contract_storage_key) { + for contract_storage_key in cache.storage_writes.keys() { + if !cache.storage_initial_values.contains_key(contract_storage_key) { // First access to this cell was write; cache initial value. - self.cache.storage_initial_values.insert( + cache.storage_initial_values.insert( *contract_storage_key, self.state.get_storage_at(contract_storage_key.0, contract_storage_key.1)?, ); } } - for contract_address in self.cache.class_hash_writes.keys() { - if !self.cache.class_hash_initial_values.contains_key(contract_address) { + for contract_address in cache.class_hash_writes.keys() { + if !cache.class_hash_initial_values.contains_key(contract_address) { // First access to this cell was write; cache initial value. - self.cache + cache .class_hash_initial_values .insert(*contract_address, self.state.get_class_hash_at(*contract_address)?); } } - for contract_address in self.cache.nonce_writes.keys() { - if !self.cache.nonce_initial_values.contains_key(contract_address) { + for contract_address in cache.nonce_writes.keys() { + if !cache.nonce_initial_values.contains_key(contract_address) { // First access to this cell was write; cache initial value. - self.cache + cache .nonce_initial_values .insert(*contract_address, self.state.get_nonce_at(*contract_address)?); } @@ -161,42 +166,47 @@ impl From for CachedState { impl StateReader for CachedState { fn get_storage_at( - &mut self, + &self, contract_address: ContractAddress, key: StorageKey, ) -> StateResult { - if self.cache.get_storage_at(contract_address, key).is_none() { + let mut cache = self.cache.borrow_mut(); + + if cache.get_storage_at(contract_address, key).is_none() { let storage_value = self.state.get_storage_at(contract_address, key)?; - self.cache.set_storage_initial_value(contract_address, key, storage_value); + cache.set_storage_initial_value(contract_address, key, storage_value); } - let value = self.cache.get_storage_at(contract_address, key).unwrap_or_else(|| { + let value = cache.get_storage_at(contract_address, key).unwrap_or_else(|| { panic!("Cannot retrieve '{contract_address:?}' and '{key:?}' from the cache.") }); Ok(*value) } - fn get_nonce_at(&mut self, contract_address: ContractAddress) -> StateResult { - if self.cache.get_nonce_at(contract_address).is_none() { + fn get_nonce_at(&self, contract_address: ContractAddress) -> StateResult { + let mut cache = self.cache.borrow_mut(); + + if cache.get_nonce_at(contract_address).is_none() { let nonce = self.state.get_nonce_at(contract_address)?; - self.cache.set_nonce_initial_value(contract_address, nonce); + cache.set_nonce_initial_value(contract_address, nonce); } - let nonce = self - .cache + let nonce = cache .get_nonce_at(contract_address) .unwrap_or_else(|| panic!("Cannot retrieve '{contract_address:?}' from the cache.")); + Ok(*nonce) } - fn get_class_hash_at(&mut self, contract_address: ContractAddress) -> StateResult { - if self.cache.get_class_hash_at(contract_address).is_none() { + fn get_class_hash_at(&self, contract_address: ContractAddress) -> StateResult { + let mut cache = self.cache.borrow_mut(); + + if cache.get_class_hash_at(contract_address).is_none() { let class_hash = self.state.get_class_hash_at(contract_address)?; - self.cache.set_class_hash_initial_value(contract_address, class_hash); + cache.set_class_hash_initial_value(contract_address, class_hash); } - let class_hash = self - .cache + let class_hash = cache .get_class_hash_at(contract_address) .unwrap_or_else(|| panic!("Cannot retrieve '{contract_address:?}' from the cache.")); Ok(*class_hash) @@ -205,24 +215,25 @@ impl StateReader for CachedState { #[allow(clippy::map_entry)] // Clippy solution don't work because it required two mutable ref to self // Could probably be solved with interior mutability - fn get_compiled_contract_class(&mut self, class_hash: ClassHash) -> StateResult { - if !self.class_hash_to_class.contains_key(&class_hash) { + fn get_compiled_contract_class(&self, class_hash: ClassHash) -> StateResult { + let class_hash_to_class = &mut *self.class_hash_to_class.borrow_mut(); + + if !class_hash_to_class.contains_key(&class_hash) { let contract_class = self.global_class_hash_to_class().cache_get(&class_hash).cloned(); match contract_class { Some(contract_class_from_global_cache) => { - self.class_hash_to_class.insert(class_hash, contract_class_from_global_cache); + class_hash_to_class.insert(class_hash, contract_class_from_global_cache); } None => { let contract_class_from_db = self.state.get_compiled_contract_class(class_hash)?; - self.class_hash_to_class.insert(class_hash, contract_class_from_db); + class_hash_to_class.insert(class_hash, contract_class_from_db); } } } - let contract_class = self - .class_hash_to_class + let contract_class = class_hash_to_class .get(&class_hash) .cloned() .expect("The class hash must appear in the cache."); @@ -230,14 +241,15 @@ impl StateReader for CachedState { Ok(contract_class) } - fn get_compiled_class_hash(&mut self, class_hash: ClassHash) -> StateResult { - if self.cache.get_compiled_class_hash(class_hash).is_none() { + fn get_compiled_class_hash(&self, class_hash: ClassHash) -> StateResult { + let mut cache = self.cache.borrow_mut(); + + if cache.get_compiled_class_hash(class_hash).is_none() { let compiled_class_hash = self.state.get_compiled_class_hash(class_hash)?; - self.cache.set_compiled_class_hash_initial_value(class_hash, compiled_class_hash); + cache.set_compiled_class_hash_initial_value(class_hash, compiled_class_hash); } - let compiled_class_hash = self - .cache + let compiled_class_hash = cache .get_compiled_class_hash(class_hash) .unwrap_or_else(|| panic!("Cannot retrieve '{class_hash:?}' from the cache.")); Ok(*compiled_class_hash) @@ -251,7 +263,7 @@ impl State for CachedState { key: StorageKey, value: StarkFelt, ) -> StateResult<()> { - self.cache.set_storage_value(contract_address, key, value); + self.cache.get_mut().set_storage_value(contract_address, key, value); Ok(()) } @@ -264,7 +276,7 @@ impl State for CachedState { usize::try_from(current_nonce.0)?.try_into().expect("Failed to convert usize to u64."); let next_nonce_val = 1_u64 + current_nonce_as_u64; let next_nonce = Nonce(StarkFelt::from(next_nonce_val)); - self.cache.set_nonce_value(contract_address, next_nonce); + self.cache.get_mut().set_nonce_value(contract_address, next_nonce); Ok(()) } @@ -278,7 +290,7 @@ impl State for CachedState { return Err(StateError::OutOfRangeContractAddress); } - self.cache.set_class_hash_write(contract_address, class_hash); + self.cache.get_mut().set_class_hash_write(contract_address, class_hash); Ok(()) } @@ -287,7 +299,7 @@ impl State for CachedState { class_hash: ClassHash, contract_class: ContractClass, ) -> StateResult<()> { - self.class_hash_to_class.insert(class_hash, contract_class); + self.class_hash_to_class.get_mut().insert(class_hash, contract_class); Ok(()) } @@ -296,7 +308,7 @@ impl State for CachedState { class_hash: ClassHash, compiled_class_hash: CompiledClassHash, ) -> StateResult<()> { - self.cache.set_compiled_class_hash_write(class_hash, compiled_class_hash); + self.cache.get_mut().set_compiled_class_hash_write(class_hash, compiled_class_hash); Ok(()) } @@ -307,7 +319,7 @@ impl State for CachedState { self.update_initial_values_of_write_only_access() .unwrap_or_else(|_| panic!("Cannot convert stateDiff to CommitmentStateDiff.")); - let state_cache = &self.cache; + let state_cache = self.cache.borrow(); let class_hash_updates = state_cache.get_class_hash_updates(); let storage_diffs = state_cache.get_storage_updates(); let nonces = state_cache.get_nonce_updates(); @@ -501,26 +513,26 @@ impl<'a, S: State + ?Sized> MutRefState<'a, S> { /// Proxies inner object to expose `State` functionality. impl<'a, S: State + ?Sized> StateReader for MutRefState<'a, S> { fn get_storage_at( - &mut self, + &self, contract_address: ContractAddress, key: StorageKey, ) -> StateResult { self.0.get_storage_at(contract_address, key) } - fn get_nonce_at(&mut self, contract_address: ContractAddress) -> StateResult { + fn get_nonce_at(&self, contract_address: ContractAddress) -> StateResult { self.0.get_nonce_at(contract_address) } - fn get_class_hash_at(&mut self, contract_address: ContractAddress) -> StateResult { + fn get_class_hash_at(&self, contract_address: ContractAddress) -> StateResult { self.0.get_class_hash_at(contract_address) } - fn get_compiled_contract_class(&mut self, class_hash: ClassHash) -> StateResult { + fn get_compiled_contract_class(&self, class_hash: ClassHash) -> StateResult { self.0.get_compiled_contract_class(class_hash) } - fn get_compiled_class_hash(&mut self, class_hash: ClassHash) -> StateResult { + fn get_compiled_class_hash(&self, class_hash: ClassHash) -> StateResult { self.0.get_compiled_class_hash(class_hash) } } @@ -591,8 +603,8 @@ impl<'a, S: StateReader> TransactionalState<'a, S> { .. } = self; StagedTransactionalState { - cache, - class_hash_to_class, + cache: cache.into_inner(), + class_hash_to_class: class_hash_to_class.into_inner(), global_class_hash_to_class, tx_executed_class_hashes, tx_visited_storage_entries, @@ -604,10 +616,10 @@ impl<'a, S: StateReader> TransactionalState<'a, S> { /// Commits changes in the child (wrapping) state to its parent. pub fn commit(self) { let state = self.state.0; - let child_cache = self.cache; + let child_cache = self.cache.into_inner(); state.update_cache(child_cache); state.update_contract_class_caches( - self.class_hash_to_class, + self.class_hash_to_class.into_inner(), self.global_class_hash_to_class, ); state.update_visited_pcs_cache(&self.visited_pcs); @@ -810,7 +822,7 @@ pub const GLOBAL_CONTRACT_CACHE_SIZE_FOR_TEST: usize = 100; impl GlobalContractCache { /// Locks the cache for atomic access. Although conceptually shared, writing to this cache is /// only possible for one writer at a time. - pub fn lock(&mut self) -> LockedContractClassCache<'_> { + pub fn lock(&self) -> LockedContractClassCache<'_> { self.0.lock().expect("Global contract cache is poisoned.") } diff --git a/crates/blockifier/src/state/cached_state_test.rs b/crates/blockifier/src/state/cached_state_test.rs index a8467c43cb..0a55238ab7 100644 --- a/crates/blockifier/src/state/cached_state_test.rs +++ b/crates/blockifier/src/state/cached_state_test.rs @@ -23,17 +23,17 @@ fn set_initial_state_values( class_hash_initial_values: HashMap, storage_initial_values: HashMap, ) { - assert!(state.cache == StateCache::default(), "Cache already initialized."); + assert!(*state.cache.borrow() == StateCache::default(), "Cache already initialized."); - state.class_hash_to_class = class_hash_to_class; - state.cache.class_hash_initial_values.extend(class_hash_initial_values); - state.cache.nonce_initial_values.extend(nonce_initial_values); - state.cache.storage_initial_values.extend(storage_initial_values); + state.class_hash_to_class.replace(class_hash_to_class); + state.cache.get_mut().class_hash_initial_values.extend(class_hash_initial_values); + state.cache.get_mut().nonce_initial_values.extend(nonce_initial_values); + state.cache.get_mut().storage_initial_values.extend(storage_initial_values); } #[test] fn get_uninitialized_storage_value() { - let mut state: CachedState = CachedState::default(); + let state: CachedState = CachedState::default(); let contract_address = contract_address!("0x1"); let key = StorageKey(patricia_key!("0x10")); @@ -98,7 +98,7 @@ fn cast_between_storage_mapping_types() { #[test] fn get_uninitialized_value() { - let mut state: CachedState = CachedState::default(); + let state: CachedState = CachedState::default(); let contract_address = contract_address!("0x1"); assert_eq!(state.get_nonce_at(contract_address).unwrap(), Nonce::default()); @@ -140,7 +140,7 @@ fn get_and_increment_nonce() { fn get_contract_class() { // Positive flow. let existing_class_hash = class_hash!(TEST_CLASS_HASH); - let mut state = deprecated_create_test_state(); + let state = deprecated_create_test_state(); assert_eq!( state.get_compiled_contract_class(existing_class_hash).unwrap(), get_test_contract_class() @@ -156,7 +156,7 @@ fn get_contract_class() { #[test] fn get_uninitialized_class_hash_value() { - let mut state: CachedState = CachedState::default(); + let state: CachedState = CachedState::default(); let valid_contract_address = contract_address!("0x1"); assert_eq!(state.get_class_hash_at(valid_contract_address).unwrap(), ClassHash::default()); @@ -382,22 +382,22 @@ fn test_state_changes_merge( fn global_contract_cache_is_used() { // Initialize the global cache with a single class, and initialize an empty state with this // cache. - let mut global_cache = GlobalContractCache::new(GLOBAL_CONTRACT_CACHE_SIZE_FOR_TEST); + let global_cache = GlobalContractCache::new(GLOBAL_CONTRACT_CACHE_SIZE_FOR_TEST); let class_hash = class_hash!(TEST_CLASS_HASH); let contract_class = get_test_contract_class(); global_cache.lock().cache_set(class_hash, contract_class.clone()); assert_eq!(global_cache.lock().cache_size(), 1); - let mut state = CachedState::new(DictStateReader::default(), global_cache.clone()); + let state = CachedState::new(DictStateReader::default(), global_cache.clone()); // Assert local cache is initialized empty even if global cache is not empty. - assert!(state.class_hash_to_class.get(&class_hash).is_none()); + assert!(state.class_hash_to_class.borrow().get(&class_hash).is_none()); // Check state uses the global cache. assert_eq!(state.get_compiled_contract_class(class_hash).unwrap(), contract_class); assert_eq!(global_cache.lock().cache_hits().unwrap(), 1); assert_eq!(global_cache.lock().cache_size(), 1); // Verify local cache is also updated. - assert_eq!(state.class_hash_to_class.get(&class_hash).unwrap(), &contract_class); + assert_eq!(state.class_hash_to_class.borrow().get(&class_hash).unwrap(), &contract_class); // Idempotency: getting the same class again uses the local cache. assert_eq!(state.get_compiled_contract_class(class_hash).unwrap(), contract_class); diff --git a/crates/blockifier/src/state/state_api.rs b/crates/blockifier/src/state/state_api.rs index 8999f2ef71..e70deccad8 100644 --- a/crates/blockifier/src/state/state_api.rs +++ b/crates/blockifier/src/state/state_api.rs @@ -27,24 +27,24 @@ pub trait StateReader { /// its address). /// Default: 0 for an uninitialized contract address. fn get_storage_at( - &mut self, + &self, contract_address: ContractAddress, key: StorageKey, ) -> StateResult; /// Returns the nonce of the given contract instance. /// Default: 0 for an uninitialized contract address. - fn get_nonce_at(&mut self, contract_address: ContractAddress) -> StateResult; + fn get_nonce_at(&self, contract_address: ContractAddress) -> StateResult; /// Returns the class hash of the contract class at the given contract instance. /// Default: 0 (uninitialized class hash) for an uninitialized contract address. - fn get_class_hash_at(&mut self, contract_address: ContractAddress) -> StateResult; + fn get_class_hash_at(&self, contract_address: ContractAddress) -> StateResult; /// Returns the contract class of the given class hash. - fn get_compiled_contract_class(&mut self, class_hash: ClassHash) -> StateResult; + fn get_compiled_contract_class(&self, class_hash: ClassHash) -> StateResult; /// Returns the compiled class hash of the given class hash. - fn get_compiled_class_hash(&mut self, class_hash: ClassHash) -> StateResult; + fn get_compiled_class_hash(&self, class_hash: ClassHash) -> StateResult; /// Returns the storage value representing the balance (in fee token) at the given address. // TODO(Dori, 1/7/2023): When a standard representation for large integers is set, change the diff --git a/crates/blockifier/src/test_utils/dict_state_reader.rs b/crates/blockifier/src/test_utils/dict_state_reader.rs index 25f486064c..0337403c09 100644 --- a/crates/blockifier/src/test_utils/dict_state_reader.rs +++ b/crates/blockifier/src/test_utils/dict_state_reader.rs @@ -21,7 +21,7 @@ pub struct DictStateReader { impl StateReader for DictStateReader { fn get_storage_at( - &mut self, + &self, contract_address: ContractAddress, key: StorageKey, ) -> StateResult { @@ -30,12 +30,12 @@ impl StateReader for DictStateReader { Ok(value) } - fn get_nonce_at(&mut self, contract_address: ContractAddress) -> StateResult { + fn get_nonce_at(&self, contract_address: ContractAddress) -> StateResult { let nonce = self.address_to_nonce.get(&contract_address).copied().unwrap_or_default(); Ok(nonce) } - fn get_compiled_contract_class(&mut self, class_hash: ClassHash) -> StateResult { + fn get_compiled_contract_class(&self, class_hash: ClassHash) -> StateResult { let contract_class = self.class_hash_to_class.get(&class_hash).cloned(); match contract_class { Some(contract_class) => Ok(contract_class), @@ -43,14 +43,14 @@ impl StateReader for DictStateReader { } } - fn get_class_hash_at(&mut self, contract_address: ContractAddress) -> StateResult { + fn get_class_hash_at(&self, contract_address: ContractAddress) -> StateResult { let class_hash = self.address_to_class_hash.get(&contract_address).copied().unwrap_or_default(); Ok(class_hash) } fn get_compiled_class_hash( - &mut self, + &self, class_hash: ClassHash, ) -> StateResult { let compiled_class_hash = diff --git a/crates/native_blockifier/src/state_readers/papyrus_state.rs b/crates/native_blockifier/src/state_readers/papyrus_state.rs index 259c6a7246..c72b8a1645 100644 --- a/crates/native_blockifier/src/state_readers/papyrus_state.rs +++ b/crates/native_blockifier/src/state_readers/papyrus_state.rs @@ -36,7 +36,7 @@ impl PapyrusReader { // Currently unused - will soon replace the same `impl` for `PapyrusStateReader`. impl StateReader for PapyrusReader { fn get_storage_at( - &mut self, + &self, contract_address: ContractAddress, key: StorageKey, ) -> StateResult { @@ -47,7 +47,7 @@ impl StateReader for PapyrusReader { .map_err(|error| StateError::StateReadError(error.to_string())) } - fn get_nonce_at(&mut self, contract_address: ContractAddress) -> StateResult { + fn get_nonce_at(&self, contract_address: ContractAddress) -> StateResult { let state_number = StateNumber(self.latest_block); match self .reader()? @@ -60,7 +60,7 @@ impl StateReader for PapyrusReader { } } - fn get_class_hash_at(&mut self, contract_address: ContractAddress) -> StateResult { + fn get_class_hash_at(&self, contract_address: ContractAddress) -> StateResult { let state_number = StateNumber(self.latest_block); match self .reader()? @@ -75,7 +75,7 @@ impl StateReader for PapyrusReader { /// Returns a V1 contract if found, or a V0 contract if a V1 contract is not /// found, or an `Error` otherwise. - fn get_compiled_contract_class(&mut self, class_hash: ClassHash) -> StateResult { + fn get_compiled_contract_class(&self, class_hash: ClassHash) -> StateResult { let state_number = StateNumber(self.latest_block); let class_declaration_block_number = self .reader()? @@ -112,10 +112,7 @@ impl StateReader for PapyrusReader { } } - fn get_compiled_class_hash( - &mut self, - _class_hash: ClassHash, - ) -> StateResult { + fn get_compiled_class_hash(&self, _class_hash: ClassHash) -> StateResult { todo!() } } diff --git a/crates/native_blockifier/src/state_readers/py_state_reader.rs b/crates/native_blockifier/src/state_readers/py_state_reader.rs index 23d92b7eac..169f84a845 100644 --- a/crates/native_blockifier/src/state_readers/py_state_reader.rs +++ b/crates/native_blockifier/src/state_readers/py_state_reader.rs @@ -32,7 +32,7 @@ impl PyStateReader { impl StateReader for PyStateReader { fn get_storage_at( - &mut self, + &self, contract_address: ContractAddress, key: StorageKey, ) -> StateResult { @@ -44,7 +44,7 @@ impl StateReader for PyStateReader { .map_err(|err| StateError::StateReadError(err.to_string())) } - fn get_nonce_at(&mut self, contract_address: ContractAddress) -> StateResult { + fn get_nonce_at(&self, contract_address: ContractAddress) -> StateResult { Python::with_gil(|py| -> PyResult { let args = (ON_CHAIN_STORAGE_DOMAIN, PyFelt::from(contract_address)); self.state_reader_proxy.as_ref(py).call_method1("get_nonce_at", args)?.extract() @@ -53,7 +53,7 @@ impl StateReader for PyStateReader { .map_err(|err| StateError::StateReadError(err.to_string())) } - fn get_class_hash_at(&mut self, contract_address: ContractAddress) -> StateResult { + fn get_class_hash_at(&self, contract_address: ContractAddress) -> StateResult { Python::with_gil(|py| -> PyResult { let args = (PyFelt::from(contract_address),); self.state_reader_proxy.as_ref(py).call_method1("get_class_hash_at", args)?.extract() @@ -62,7 +62,7 @@ impl StateReader for PyStateReader { .map_err(|err| StateError::StateReadError(err.to_string())) } - fn get_compiled_contract_class(&mut self, class_hash: ClassHash) -> StateResult { + fn get_compiled_contract_class(&self, class_hash: ClassHash) -> StateResult { Python::with_gil(|py| -> Result { let args = (PyFelt::from(class_hash),); let py_raw_compiled_class: PyRawCompiledClass = self @@ -82,7 +82,7 @@ impl StateReader for PyStateReader { }) } - fn get_compiled_class_hash(&mut self, class_hash: ClassHash) -> StateResult { + fn get_compiled_class_hash(&self, class_hash: ClassHash) -> StateResult { Python::with_gil(|py| -> PyResult { let args = (PyFelt::from(class_hash),); self.state_reader_proxy From eee9e4d15d9315ba5130d0b37c8ea33c703d171b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20Delabrouille?= Date: Wed, 10 Jan 2024 23:44:39 +0100 Subject: [PATCH 2/3] chore: remove clippy allow --- crates/blockifier/src/state/cached_state.rs | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/crates/blockifier/src/state/cached_state.rs b/crates/blockifier/src/state/cached_state.rs index b9f7dd8012..cc63ef0493 100644 --- a/crates/blockifier/src/state/cached_state.rs +++ b/crates/blockifier/src/state/cached_state.rs @@ -212,23 +212,21 @@ impl StateReader for CachedState { Ok(*class_hash) } - #[allow(clippy::map_entry)] - // Clippy solution don't work because it required two mutable ref to self - // Could probably be solved with interior mutability fn get_compiled_contract_class(&self, class_hash: ClassHash) -> StateResult { let class_hash_to_class = &mut *self.class_hash_to_class.borrow_mut(); - if !class_hash_to_class.contains_key(&class_hash) { + if let std::collections::hash_map::Entry::Vacant(e) = class_hash_to_class.entry(class_hash) + { let contract_class = self.global_class_hash_to_class().cache_get(&class_hash).cloned(); match contract_class { Some(contract_class_from_global_cache) => { - class_hash_to_class.insert(class_hash, contract_class_from_global_cache); + e.insert(contract_class_from_global_cache); } None => { let contract_class_from_db = self.state.get_compiled_contract_class(class_hash)?; - class_hash_to_class.insert(class_hash, contract_class_from_db); + e.insert(contract_class_from_db); } } } From d3dc3dcc4dc0351940f6685458c0ea63feb60685 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20Delabrouille?= Date: Thu, 15 Feb 2024 12:02:12 +0100 Subject: [PATCH 3/3] fix: small review fix --- crates/blockifier/src/state/cached_state.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/crates/blockifier/src/state/cached_state.rs b/crates/blockifier/src/state/cached_state.rs index cc63ef0493..1dd0ef2480 100644 --- a/crates/blockifier/src/state/cached_state.rs +++ b/crates/blockifier/src/state/cached_state.rs @@ -29,6 +29,7 @@ pub type ContractClassMapping = HashMap; pub struct CachedState { pub state: S, // Invariant: read/write access is managed by CachedState. + // Using interior mutability to update caches during `State`'s immutable getters. cache: RefCell, class_hash_to_class: RefCell, // Invariant: managed by CachedState. @@ -215,18 +216,19 @@ impl StateReader for CachedState { fn get_compiled_contract_class(&self, class_hash: ClassHash) -> StateResult { let class_hash_to_class = &mut *self.class_hash_to_class.borrow_mut(); - if let std::collections::hash_map::Entry::Vacant(e) = class_hash_to_class.entry(class_hash) + if let std::collections::hash_map::Entry::Vacant(vacant_entry) = + class_hash_to_class.entry(class_hash) { let contract_class = self.global_class_hash_to_class().cache_get(&class_hash).cloned(); match contract_class { Some(contract_class_from_global_cache) => { - e.insert(contract_class_from_global_cache); + vacant_entry.insert(contract_class_from_global_cache); } None => { let contract_class_from_db = self.state.get_compiled_contract_class(class_hash)?; - e.insert(contract_class_from_db); + vacant_entry.insert(contract_class_from_db); } } }