Skip to content
This repository has been archived by the owner on Aug 21, 2024. It is now read-only.

Merge main contract segmentation into main #1331

Merged
merged 8 commits into from
Jan 14, 2024
80 changes: 68 additions & 12 deletions crates/blockifier/src/execution/contract_class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,18 +173,8 @@ impl ContractClassV1 {
/// This is an empiric measurement of several bytecode lengths, which constitutes as the
/// dominant factor in it.
fn estimate_casm_hash_computation_resources(&self) -> VmExecutionResources {
let bytecode_length = self.bytecode_length() as f64;
let n_steps = (503.0 + bytecode_length * 5.7) as usize;
let n_poseidon_builtins = (10.9 + bytecode_length * 0.5) as usize;

VmExecutionResources {
n_steps,
n_memory_holes: 0,
builtin_instance_counter: HashMap::from([(
POSEIDON_BUILTIN_NAME.to_string(),
n_poseidon_builtins,
)]),
}
// TODO(lior): Use `bytecode_segment_lengths` from the class.
estimate_casm_hash_computation_resources(NestedIntList::Leaf(self.bytecode_length()))
}

pub fn try_from_json_string(raw_contract_class: &str) -> Result<ContractClassV1, ProgramError> {
Expand All @@ -195,6 +185,72 @@ impl ContractClassV1 {
}
}

// TODO(lior): Remove this and use `NestedIntList` from the cairo compiler repo once available.
#[derive(Clone, Debug)]
pub enum NestedIntList {
Leaf(usize),
Node(Vec<NestedIntList>),
}

/// Returns the estimated VM resources required for computing Casm hash (for Cairo 1 contracts).
///
/// Note: the function focuses on the bytecode size, and currently ignores the cost handling the
/// class entry points.
pub fn estimate_casm_hash_computation_resources(
bytecode_segment_lengths: NestedIntList,
) -> VmExecutionResources {
// The constants in this function were computed by running the Casm code on a few values
// of `bytecode_segment_lengths`.
match bytecode_segment_lengths {
NestedIntList::Leaf(length) => {
// The entire contract is a single segment (old Sierra contracts).
&VmExecutionResources {
n_steps: 474,
n_memory_holes: 0,
builtin_instance_counter: HashMap::from([(POSEIDON_BUILTIN_NAME.to_string(), 10)]),
} + &poseidon_hash_many_cost(length)
}
NestedIntList::Node(segments) => {
// The contract code is segmented by its functions.
let mut execution_resources = VmExecutionResources {
n_steps: 491,
n_memory_holes: 0,
builtin_instance_counter: HashMap::from([(POSEIDON_BUILTIN_NAME.to_string(), 11)]),
};
let base_segment_cost = VmExecutionResources {
n_steps: 24,
n_memory_holes: 1,
builtin_instance_counter: HashMap::from([(POSEIDON_BUILTIN_NAME.to_string(), 1)]),
};
for segment in segments {
let NestedIntList::Leaf(length) = segment else {
panic!(
"Estimating hash cost is only supported for segmentation depth at most 1."
);
};
execution_resources += &poseidon_hash_many_cost(length);
execution_resources += &base_segment_cost;
}
execution_resources
}
}
}

/// Returns the VM resources required for running `poseidon_hash_many` in the Starknet OS.
fn poseidon_hash_many_cost(data_length: usize) -> VmExecutionResources {
VmExecutionResources {
n_steps: (data_length / 10) * 55
+ ((data_length % 10) / 2) * 18
+ (data_length % 2) * 3
+ 21,
n_memory_holes: 0,
builtin_instance_counter: HashMap::from([(
POSEIDON_BUILTIN_NAME.to_string(),
data_length / 2 + 1,
)]),
}
}

#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct ContractClassV1Inner {
pub program: Program,
Expand Down
2 changes: 1 addition & 1 deletion crates/blockifier/src/execution/entry_point.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ pub struct ExecutionResources {
pub syscall_counter: SyscallCounter,
}

#[derive(Clone, Debug)]
#[derive(Debug)]
pub struct EntryPointExecutionContext {
pub block_context: BlockContext,
pub account_tx_context: AccountTransactionContext,
Expand Down
53 changes: 51 additions & 2 deletions crates/blockifier/src/execution/entry_point_execution.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::collections::HashSet;

use cairo_felt::Felt252;
use cairo_vm::serde::deserialize_program::BuiltinName;
use cairo_vm::types::relocatable::{MaybeRelocatable, Relocatable};
Expand Down Expand Up @@ -52,6 +54,11 @@ pub fn execute_entry_point_call(
resources: &mut ExecutionResources,
context: &mut EntryPointExecutionContext,
) -> EntryPointExecutionResult<CallInfo> {
// Fetch the class hash from `call`.
let class_hash = call.class_hash.ok_or(EntryPointExecutionError::InternalError(
"Class hash must not be None when executing an entry point.".into(),
))?;

let VmExecutionContext {
mut runner,
mut vm,
Expand All @@ -74,7 +81,8 @@ pub fn execute_entry_point_call(
let previous_vm_resources = syscall_handler.resources.vm_resources.clone();

// Execute.
let program_segment_size = contract_class.bytecode_length() + program_extra_data_length;
let bytecode_length = contract_class.bytecode_length();
let program_segment_size = bytecode_length + program_extra_data_length;
run_entry_point(
&mut vm,
&mut runner,
Expand All @@ -84,6 +92,15 @@ pub fn execute_entry_point_call(
program_segment_size,
)?;

// Collect the set PC values that were visited during the entry point execution.
register_visited_pcs(
&mut vm,
syscall_handler.state,
class_hash,
program_segment_size,
bytecode_length,
)?;

let call_info = finalize_execution(
vm,
runner,
Expand All @@ -101,6 +118,38 @@ pub fn execute_entry_point_call(
Ok(call_info)
}

// Collects the set PC values that were visited during the entry point execution.
fn register_visited_pcs(
vm: &mut VirtualMachine,
state: &mut dyn State,
class_hash: starknet_api::core::ClassHash,
program_segment_size: usize,
bytecode_length: usize,
) -> Result<(), EntryPointExecutionError> {
let mut class_visited_pcs = HashSet::new();
// Relocate the trace, putting the program segment at address 1 and the execution segment right
// after it.
// TODO(lior): Avoid unnecessary relocation once the VM has a non-relocated `get_trace()`
// function.
vm.relocate_trace(&[1, 1 + program_segment_size])?;
for trace_entry in vm.get_relocated_trace()? {
let pc = trace_entry.pc;
if pc < 1 {
return Err(EntryPointExecutionError::InternalError(format!(
"Invalid PC value {pc} in trace."
)));
}
let real_pc = pc - 1;
// Jumping to a PC that is not inside the bytecode is possible. For example, to obtain
// the builtin costs. Filter out these values.
if real_pc < bytecode_length {
class_visited_pcs.insert(real_pc);
}
}
state.add_visited_pcs(class_hash, &class_visited_pcs);
Ok(())
}

pub fn initialize_execution_context<'a>(
call: CallEntryPoint,
contract_class: &'a ContractClassV1,
Expand All @@ -114,7 +163,7 @@ pub fn initialize_execution_context<'a>(
let proof_mode = false;
let mut runner = CairoRunner::new(&contract_class.0.program, "starknet", proof_mode)?;

let trace_enabled = false;
let trace_enabled = true;
let mut vm = VirtualMachine::new(trace_enabled);

// Initialize program with all builtins.
Expand Down
5 changes: 5 additions & 0 deletions crates/blockifier/src/execution/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use cairo_vm::types::errors::math_errors::MathError;
use cairo_vm::vm::errors::cairo_run_errors::CairoRunError;
use cairo_vm::vm::errors::memory_errors::MemoryError;
use cairo_vm::vm::errors::runner_errors::RunnerError;
use cairo_vm::vm::errors::trace_errors::TraceError;
use cairo_vm::vm::errors::vm_errors::{VirtualMachineError, HINT_ERROR_STR};
use num_bigint::{BigInt, TryFromBigIntError};
use starknet_api::core::{ContractAddress, EntryPointSelector};
Expand Down Expand Up @@ -124,6 +125,8 @@ impl VirtualMachineExecutionError {
pub enum EntryPointExecutionError {
#[error("Execution failed. Failure reason: {}.", format_panic_data(.error_data))]
ExecutionFailed { error_data: Vec<StarkFelt> },
#[error("Internal error: {0}")]
InternalError(String),
#[error("Invalid input: {input_descriptor}; {info}")]
InvalidExecutionInput { input_descriptor: String, info: String },
#[error(transparent)]
Expand All @@ -134,6 +137,8 @@ pub enum EntryPointExecutionError {
RecursionDepthExceeded,
#[error(transparent)]
StateError(#[from] StateError),
#[error(transparent)]
TraceError(#[from] TraceError),
/// Gathers all errors from running the Cairo VM, excluding hints.
#[error(transparent)]
VirtualMachineExecutionError(#[from] VirtualMachineExecutionError),
Expand Down
36 changes: 32 additions & 4 deletions crates/blockifier/src/state/cached_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ pub struct CachedState<S: StateReader> {
class_hash_to_class: ContractClassMapping,
// Invariant: managed by CachedState.
global_class_hash_to_class: GlobalContractCache,
/// A map from class hash to the set of PC values that were visited in the class.
pub visited_pcs: HashMap<ClassHash, HashSet<usize>>,
}

impl<S: StateReader> CachedState<S> {
Expand All @@ -41,6 +43,7 @@ impl<S: StateReader> CachedState<S> {
cache: StateCache::default(),
class_hash_to_class: HashMap::default(),
global_class_hash_to_class,
visited_pcs: HashMap::default(),
}
}

Expand Down Expand Up @@ -135,6 +138,12 @@ impl<S: StateReader> CachedState<S> {
self.global_class_hash_to_class = global_contract_cache;
}

pub fn update_visited_pcs_cache(&mut self, visited_pcs: &HashMap<ClassHash, HashSet<usize>>) {
for (class_hash, class_visited_pcs) in visited_pcs {
self.add_visited_pcs(*class_hash, class_visited_pcs);
}
}

/// Updates cache with initial cell values for write-only access.
/// If written values match the original, the cell is unchanged and not counted as a
/// storage-change for fee calculation.
Expand Down Expand Up @@ -340,6 +349,10 @@ impl<S: StateReader> State for CachedState<S> {
address_to_nonce: IndexMap::from_iter(nonces),
}
}

fn add_visited_pcs(&mut self, class_hash: ClassHash, pcs: &HashSet<usize>) {
self.visited_pcs.entry(class_hash).or_default().extend(pcs);
}
}

#[cfg(any(feature = "testing", test))]
Expand All @@ -350,6 +363,7 @@ impl Default for CachedState<crate::test_utils::dict_state_reader::DictStateRead
cache: Default::default(),
class_hash_to_class: Default::default(),
global_class_hash_to_class: Default::default(),
visited_pcs: Default::default(),
}
}
}
Expand Down Expand Up @@ -579,6 +593,10 @@ impl<'a, S: State + ?Sized> State for MutRefState<'a, S> {
) -> StateResult<()> {
self.0.set_compiled_class_hash(class_hash, compiled_class_hash)
}

fn add_visited_pcs(&mut self, class_hash: ClassHash, pcs: &HashSet<usize>) {
self.0.add_visited_pcs(class_hash, pcs)
}
}

pub type TransactionalState<'a, S> = CachedState<MutRefState<'a, CachedState<S>>>;
Expand All @@ -591,14 +609,20 @@ impl<'a, S: StateReader> TransactionalState<'a, S> {
tx_executed_class_hashes: HashSet<ClassHash>,
tx_visited_storage_entries: HashSet<StorageEntry>,
) -> StagedTransactionalState {
let TransactionalState { cache, class_hash_to_class, global_class_hash_to_class, .. } =
self;
let TransactionalState {
cache,
class_hash_to_class,
global_class_hash_to_class,
visited_pcs,
..
} = self;
StagedTransactionalState {
cache,
class_hash_to_class,
global_class_hash_to_class,
tx_executed_class_hashes,
tx_visited_storage_entries,
visited_pcs,
}
}

Expand All @@ -607,8 +631,11 @@ impl<'a, S: StateReader> TransactionalState<'a, S> {
let state = self.state.0;
let child_cache = self.cache;
state.update_cache(child_cache);
state
.update_contract_class_caches(self.class_hash_to_class, self.global_class_hash_to_class)
state.update_contract_class_caches(
self.class_hash_to_class,
self.global_class_hash_to_class,
);
state.update_visited_pcs_cache(&self.visited_pcs);
}

/// Drops `self`.
Expand All @@ -626,6 +653,7 @@ pub struct StagedTransactionalState {
// Maintained for counting purposes.
pub tx_executed_class_hashes: HashSet<ClassHash>,
pub tx_visited_storage_entries: HashSet<StorageEntry>,
pub visited_pcs: HashMap<ClassHash, HashSet<usize>>,
}

/// Holds uncommitted changes induced on Starknet contracts.
Expand Down
7 changes: 7 additions & 0 deletions crates/blockifier/src/state/state_api.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::collections::HashSet;

use starknet_api::core::{ClassHash, CompiledClassHash, ContractAddress, Nonce};
use starknet_api::hash::StarkFelt;
use starknet_api::state::StorageKey;
Expand Down Expand Up @@ -103,4 +105,9 @@ pub trait State: StateReader {
) -> StateResult<()>;

fn to_state_diff(&mut self) -> CommitmentStateDiff;

/// Marks the given set of PC values as visited for the given class hash.
// TODO(lior): Once we have a BlockResources object, move this logic there. Make sure reverted
// entry points do not affect the final set of PCs.
fn add_visited_pcs(&mut self, class_hash: ClassHash, pcs: &HashSet<usize>);
}
15 changes: 14 additions & 1 deletion crates/native_blockifier/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ pub mod py_l1_handler;
pub mod py_state_diff;
#[cfg(any(feature = "testing", test))]
pub mod py_test_utils;
// TODO(Dori, 1/4/2023): If and when supported in the Python build environment, use #[cfg(test)].
pub mod py_testing_wrappers;
pub mod py_transaction;
pub mod py_transaction_execution_info;
pub mod py_utils;
Expand All @@ -27,7 +29,10 @@ use pyo3::prelude::*;
use storage::StorageConfig;

use crate::py_state_diff::PyStateDiff;
use crate::py_utils::raise_error_for_testing;
use crate::py_testing_wrappers::{
estimate_casm_hash_computation_resources_for_testing_list,
estimate_casm_hash_computation_resources_for_testing_single, raise_error_for_testing,
};

#[pymodule]
fn native_blockifier(py: Python<'_>, py_module: &PyModule) -> PyResult<()> {
Expand All @@ -53,6 +58,14 @@ fn native_blockifier(py: Python<'_>, py_module: &PyModule) -> PyResult<()> {
// TODO(Dori, 1/4/2023): If and when supported in the Python build environment, gate this code
// with #[cfg(test)].
py_module.add_function(wrap_pyfunction!(raise_error_for_testing, py)?)?;
py_module.add_function(wrap_pyfunction!(
estimate_casm_hash_computation_resources_for_testing_list,
py
)?)?;
py_module.add_function(wrap_pyfunction!(
estimate_casm_hash_computation_resources_for_testing_single,
py
)?)?;

Ok(())
}
Expand Down
4 changes: 3 additions & 1 deletion crates/native_blockifier/src/py_block_executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,9 @@ impl PyBlockExecutor {
self.tx_executor().execute(tx, raw_contract_class, charge_fee)
}

pub fn finalize(&mut self, is_pending_block: bool) -> PyStateDiff {
/// Returns the state diff and a list of contract class hash with the corresponding list of
/// visited PC values.
pub fn finalize(&mut self, is_pending_block: bool) -> (PyStateDiff, Vec<(PyFelt, Vec<usize>)>) {
log::debug!("Finalizing execution...");
let finalized_state = self.tx_executor().finalize(is_pending_block);
log::debug!("Finalized execution.");
Expand Down
Loading
Loading