Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Builtin costs rework #837

Merged
merged 16 commits into from
Oct 10, 2024
5 changes: 5 additions & 0 deletions programs/benches/factorial_2M.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ typedef struct factorial_return_values
} result;
} factorial_return_values_t;

extern uint64_t* builtin_costs;

static void run_bench(factorial_return_values_t*, uint64_t)
__attribute__((weakref("_mlir_ciface_factorial_2M::factorial_2M::main(f1)")));
Expand All @@ -25,6 +26,10 @@ int main()
{
factorial_return_values_t return_values;

uint64_t BuiltinCosts[7] = {1, 4050, 583, 4085, 491, 230, 604};
edg-l marked this conversation as resolved.
Show resolved Hide resolved

builtin_costs = &BuiltinCosts[0];

run_bench(&return_values, 0);
assert(return_values.result.discriminant == 0);

Expand Down
5 changes: 5 additions & 0 deletions programs/benches/fib_2M.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,18 @@ typedef struct fib_return_values
} result;
} fib_return_values_t;

extern uint64_t* builtin_costs;

static void run_bench(fib_return_values_t *, uint64_t)
__attribute__((weakref("_mlir_ciface_fib_2M::fib_2M::main(f1)")));


int main()
{
uint64_t BuiltinCosts[7] = {1, 4050, 583, 4085, 491, 230, 604};

builtin_costs = &BuiltinCosts[0];

fib_return_values_t return_values;

run_bench(&return_values, 0);
Expand Down
5 changes: 5 additions & 0 deletions programs/benches/logistic_map.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,18 @@ typedef struct map_return_values
} result;
} map_return_values_t;

extern uint64_t* builtin_costs;

static void run_bench(map_return_values_t *, uint64_t)
__attribute__((weakref("_mlir_ciface_logistic_map::logistic_map::main(f2)")));


int main()
{
uint64_t BuiltinCosts[7] = {1, 4050, 583, 4085, 491, 230, 604};

builtin_costs = &BuiltinCosts[0];

map_return_values_t return_values;

run_bench(&return_values, 0);
Expand Down
28 changes: 26 additions & 2 deletions src/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ use melior::{
arith::CmpiPredicate,
cf, func, index,
llvm::{self, LoadStoreOptions},
memref,
memref, ods,
},
ir::{
attribute::{
Expand Down Expand Up @@ -135,6 +135,30 @@ pub fn compile(
}
}

{
// Add the builtin_costs global.
// We always add it because symbol look up otherwise can panic.
let region = Region::new();
let location = Location::unknown(context);
let block = region.append_block(Block::new(&[]));
let value = block.append_op_result(
ods::llvm::mlir_zero(context, llvm::r#type::pointer(context, 0), location).into(),
)?;
block.append_operation(melior::dialect::llvm::r#return(Some(value), location));

module.body().append_operation(
ods::llvm::mlir_global(
context,
region,
TypeAttribute::new(llvm::r#type::pointer(context, 0)),
StringAttribute::new(context, "builtin_costs"),
Attribute::parse(context, "#llvm.linkage<external>").unwrap(),
location,
)
.into(),
);
}
edg-l marked this conversation as resolved.
Show resolved Hide resolved

// Sierra programs have the following structure:
// 1. Type declarations, one per line.
// 2. Libfunc declarations, one per line.
Expand Down Expand Up @@ -446,7 +470,7 @@ fn compile_func(
initial_state,
|statement_idx, mut state| {
if let Some(gas_metadata) = metadata.get::<GasMetadata>() {
let gas_cost = gas_metadata.get_gas_cost_for_statement(statement_idx);
let gas_cost = gas_metadata.get_gas_costs_for_statement(statement_idx);
metadata.remove::<GasCost>();
metadata.insert(GasCost(gas_cost));
}
Expand Down
70 changes: 50 additions & 20 deletions src/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::{
execution_result::{BuiltinStats, ExecutionResult},
starknet::{handler::StarknetSyscallHandlerCallbacks, StarknetSyscallHandler},
types::TypeBuilder,
utils::{libc_free, RangeExt},
utils::{libc_free, BuiltinCosts, RangeExt},
values::Value,
};
use bumpalo::Bump;
Expand Down Expand Up @@ -69,6 +69,7 @@ extern "C" {
fn invoke_dynamic(
registry: &ProgramRegistry<CoreType, CoreLibfunc>,
function_ptr: *const c_void,
builtin_costs_ptr: Option<*mut c_void>,
function_signature: &FunctionSignature,
args: &[Value],
gas: u128,
Expand Down Expand Up @@ -141,6 +142,15 @@ fn invoke_dynamic(
previous_syscall_handler
});

// Order matters, for the libfunc impl
let builtin_costs: [u64; 7] = BuiltinCosts::default().into();

if let Some(builtin_costs_ptr) = builtin_costs_ptr {
unsafe {
*builtin_costs_ptr.cast() = builtin_costs.as_ptr();
}
}

// Generate argument list.
let mut iter = args.iter();
for item in function_signature.param_types.iter().filter_map(|type_id| {
Expand All @@ -166,6 +176,14 @@ fn invoke_dynamic(
(syscall_handler as *mut StarknetSyscallHandlerCallbacks<_>)
.to_bytes(&mut invoke_data)?;
}
CoreTypeConcrete::BuiltinCosts(_) => {
// This builtin should never be an argument but just in case.
if let Some(builtin_costs_ptr) = builtin_costs_ptr {
builtin_costs_ptr.to_bytes(&mut invoke_data)?;
} else {
(builtin_costs.as_ptr()).to_bytes(&mut invoke_data)?;
}
}
type_info if type_info.is_builtin() => 0u64.to_bytes(&mut invoke_data)?,
type_info => JitValueWithInfoWrapper {
value: iter.next().unwrap(),
Expand Down Expand Up @@ -249,26 +267,38 @@ fn invoke_dynamic(
},
_ if type_info.is_builtin() => {
if !type_info.is_zst(registry)? {
let value = match &mut return_ptr {
Some(return_ptr) => unsafe { *read_value::<u64>(return_ptr) },
None => ret_registers[0],
} as usize;

match type_info {
CoreTypeConcrete::Bitwise(_) => builtin_stats.bitwise = value,
CoreTypeConcrete::EcOp(_) => builtin_stats.ec_op = value,
CoreTypeConcrete::RangeCheck(_) => builtin_stats.range_check = value,
CoreTypeConcrete::Pedersen(_) => builtin_stats.pedersen = value,
CoreTypeConcrete::Poseidon(_) => builtin_stats.poseidon = value,
CoreTypeConcrete::SegmentArena(_) => builtin_stats.segment_arena = value,
CoreTypeConcrete::RangeCheck96(_) => builtin_stats.range_check_96 = value,
CoreTypeConcrete::Circuit(CircuitTypeConcrete::AddMod(_)) => {
builtin_stats.circuit_add = value
}
CoreTypeConcrete::Circuit(CircuitTypeConcrete::MulMod(_)) => {
builtin_stats.circuit_mul = value
if let CoreTypeConcrete::BuiltinCosts(_) = type_info {
// todo: should we use this value?
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@edg-l Can you create an issue to track this?

let _value = match &mut return_ptr {
Some(return_ptr) => unsafe { *read_value::<*mut u64>(return_ptr) },
None => ret_registers[0] as *mut u64,
};
} else {
let value = match &mut return_ptr {
Some(return_ptr) => unsafe { *read_value::<u64>(return_ptr) },
None => ret_registers[0],
} as usize;

match type_info {
CoreTypeConcrete::Bitwise(_) => builtin_stats.bitwise = value,
CoreTypeConcrete::EcOp(_) => builtin_stats.ec_op = value,
CoreTypeConcrete::RangeCheck(_) => builtin_stats.range_check = value,
CoreTypeConcrete::Pedersen(_) => builtin_stats.pedersen = value,
CoreTypeConcrete::Poseidon(_) => builtin_stats.poseidon = value,
CoreTypeConcrete::SegmentArena(_) => {
builtin_stats.segment_arena = value
}
CoreTypeConcrete::RangeCheck96(_) => {
builtin_stats.range_check_96 = value
}
CoreTypeConcrete::Circuit(CircuitTypeConcrete::AddMod(_)) => {
builtin_stats.circuit_add = value
}
CoreTypeConcrete::Circuit(CircuitTypeConcrete::MulMod(_)) => {
builtin_stats.circuit_mul = value
}
_ => unreachable!("{type_id:?}"),
}
_ => unreachable!("{type_id:?}"),
}
}
}
Expand Down
12 changes: 12 additions & 0 deletions src/executor/aot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ impl AotNativeExecutor {
super::invoke_dynamic(
&self.registry,
self.find_function_ptr(function_id),
self.find_symbol_ptr("builtin_costs"),
self.extract_signature(function_id),
args,
available_gas,
Expand All @@ -103,6 +104,7 @@ impl AotNativeExecutor {
super::invoke_dynamic(
&self.registry,
self.find_function_ptr(function_id),
self.find_symbol_ptr("builtin_costs"),
self.extract_signature(function_id),
args,
available_gas,
Expand All @@ -125,6 +127,7 @@ impl AotNativeExecutor {
ContractExecutionResult::from_execution_result(super::invoke_dynamic(
&self.registry,
self.find_function_ptr(function_id),
self.find_symbol_ptr("builtin_costs"),
self.extract_signature(function_id),
&[Value::Struct {
fields: vec![Value::Array(
Expand Down Expand Up @@ -152,6 +155,15 @@ impl AotNativeExecutor {
}
}

pub fn find_symbol_ptr(&self, name: &str) -> Option<*mut c_void> {
unsafe {
self.library
.get::<*mut ()>(name.as_bytes())
.ok()
.map(|x| x.into_raw().into_raw())
}
}

fn extract_signature(&self, function_id: &FunctionId) -> &FunctionSignature {
&self.registry.get_function(function_id).unwrap().signature
}
Expand Down
30 changes: 30 additions & 0 deletions src/executor/contract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ use crate::{
types::TypeBuilder,
utils::{
decode_error_message, generate_function_name, get_integer_layout, libc_free, libc_malloc,
BuiltinCosts,
},
OptLevel,
};
Expand Down Expand Up @@ -97,6 +98,7 @@ enum BuiltinType {
CircuitMul,
Gas,
System,
BuiltinCosts,
}

impl AotContractExecutor {
Expand Down Expand Up @@ -209,12 +211,23 @@ impl AotContractExecutor {
function_id: &FunctionId,
args: &[Felt],
gas: Option<u128>,
builtin_costs: Option<BuiltinCosts>,
mut syscall_handler: impl StarknetSyscallHandler,
) -> Result<ContractExecutionResult> {
let arena = Bump::new();
let mut invoke_data = Vec::<u8>::new();

let function_ptr = self.find_function_ptr(function_id, true)?;
let builtin_costs_ptr = self.find_symbol_ptr("builtin_costs");

let builtin_costs = builtin_costs.unwrap_or_default();
let builtin_costs: [u64; 7] = builtin_costs.into();

if let Some(builtin_costs_ptr) = builtin_costs_ptr {
unsafe {
*builtin_costs_ptr.cast() = builtin_costs.as_ptr();
}
}

// it can vary from contract to contract thats why we need to store/ load it.
// substract 2, which are the gas and syscall builtin
Expand Down Expand Up @@ -320,6 +333,11 @@ impl AotContractExecutor {
let ptr = return_ptr.cast::<*mut ()>();
*return_ptr = unsafe { NonNull::new_unchecked(ptr.as_ptr().add(1)).cast() };
}
BuiltinType::BuiltinCosts => {
let ptr = return_ptr.cast::<*mut ()>();
*return_ptr = unsafe { NonNull::new_unchecked(ptr.as_ptr().add(1)).cast() };
// ptr holds the builtin costs, but they dont change, so its of no use, but we read to advance the ptr.
}
x => {
let value = unsafe { *read_value::<u64>(return_ptr) } as usize;

Expand All @@ -335,6 +353,7 @@ impl AotContractExecutor {
BuiltinType::CircuitMul => builtin_stats.circuit_mul = value,
BuiltinType::Gas => {}
BuiltinType::System => {}
BuiltinType::BuiltinCosts => {}
}
}
}
Expand Down Expand Up @@ -433,6 +452,15 @@ impl AotContractExecutor {
.into_raw()
})
}

pub fn find_symbol_ptr(&self, name: &str) -> Option<*mut c_void> {
unsafe {
self.library
.get::<*mut ()>(name.as_bytes())
.ok()
.map(|x| x.into_raw().into_raw())
}
}
}

impl Drop for AotContractExecutor {
Expand Down Expand Up @@ -516,6 +544,7 @@ mod tests {
entrypoint_function_id,
&[2.into()],
Some(u64::MAX as u128),
None,
&mut StubSyscallHandler::default(),
)
.unwrap();
Expand All @@ -541,6 +570,7 @@ mod tests {
entrypoint_function_id,
&[],
Some(u64::MAX as u128),
None,
&mut StubSyscallHandler::default(),
)
.unwrap();
Expand Down
13 changes: 13 additions & 0 deletions src/executor/jit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ impl<'m> JitNativeExecutor<'m> {
super::invoke_dynamic(
&self.registry,
self.find_function_ptr(function_id),
self.find_symbol_ptr("builtin_costs"),
self.extract_signature(function_id).unwrap(),
args,
available_gas,
Expand All @@ -102,6 +103,7 @@ impl<'m> JitNativeExecutor<'m> {
super::invoke_dynamic(
&self.registry,
self.find_function_ptr(function_id),
self.find_symbol_ptr("builtin_costs"),
self.extract_signature(function_id).unwrap(),
args,
available_gas,
Expand All @@ -124,6 +126,7 @@ impl<'m> JitNativeExecutor<'m> {
ContractExecutionResult::from_execution_result(super::invoke_dynamic(
&self.registry,
self.find_function_ptr(function_id),
self.find_symbol_ptr("builtin_costs"),
self.extract_signature(function_id).unwrap(),
&[Value::Struct {
fields: vec![Value::Array(
Expand All @@ -145,6 +148,16 @@ impl<'m> JitNativeExecutor<'m> {
self.engine.lookup(&function_name) as *mut c_void
}

pub fn find_symbol_ptr(&self, name: &str) -> Option<*mut c_void> {
let ptr = self.engine.lookup(name) as *mut c_void;

if ptr.is_null() {
None
} else {
Some(ptr)
}
}

fn extract_signature(&self, function_id: &FunctionId) -> Option<&FunctionSignature> {
self.program_registry()
.get_function(function_id)
Expand Down
Loading
Loading