Skip to content

Commit

Permalink
fix: dict gas (#567)
Browse files Browse the repository at this point in the history
* update runtime bindings of dict

* update build_squash include gas refund

* unskip tests
  • Loading branch information
greged93 authored May 3, 2024
1 parent 4e1b12a commit bd23b1c
Show file tree
Hide file tree
Showing 13 changed files with 132 additions and 27 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion runtime/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ crate-type = ["rlib", "cdylib", "staticlib"]
starknet-types-core = { version = "0.1.0", default-features = false, features = [
"serde",
] }
cairo-lang-runner = "2.5.4"
cairo-lang-sierra-gas = "2.5.4"
libc = "0.2"
starknet-crypto = "0.6"
starknet-curve = "0.4"
Expand Down
40 changes: 31 additions & 9 deletions runtime/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
#![allow(non_snake_case)]

use cairo_lang_sierra_gas::core_libfunc_cost::{
DICT_SQUASH_REPEATED_ACCESS_COST, DICT_SQUASH_UNIQUE_KEY_COST,
};
use lazy_static::lazy_static;
use starknet_crypto::FieldElement;
use starknet_curve::AffinePoint;
Expand All @@ -11,6 +14,8 @@ lazy_static! {
"1809251394333065606848661391547535052811553607665798349986546028067936010240"
)
.unwrap();
pub static ref DICT_GAS_REFUND_PER_ACCESS: u64 =
(DICT_SQUASH_UNIQUE_KEY_COST.cost() - DICT_SQUASH_REPEATED_ACCESS_COST.cost()) as u64;
}

/// Based on `cairo-lang-runner`'s implementation.
Expand Down Expand Up @@ -152,7 +157,7 @@ pub unsafe extern "C" fn cairo_native__libfunc__hades_permutation(
/// definitely unsafe to use manually.
#[no_mangle]
pub unsafe extern "C" fn cairo_native__alloc_dict() -> *mut std::ffi::c_void {
Box::into_raw(Box::<HashMap<[u8; 32], NonNull<std::ffi::c_void>>>::default()) as _
Box::into_raw(Box::<(HashMap<[u8; 32], NonNull<std::ffi::c_void>>, u64)>::default()) as _
}

/// Frees the dictionary.
Expand All @@ -163,28 +168,31 @@ pub unsafe extern "C" fn cairo_native__alloc_dict() -> *mut std::ffi::c_void {
/// definitely unsafe to use manually.
#[no_mangle]
pub unsafe extern "C" fn cairo_native__dict_free(
ptr: *mut HashMap<[u8; 32], NonNull<std::ffi::c_void>>,
ptr: *mut (HashMap<[u8; 32], NonNull<std::ffi::c_void>>, u64),
) {
let mut map = Box::from_raw(ptr);

// Free the entries manually.
for (_, entry) in map.drain() {
for (_, entry) in map.as_mut().0.drain() {
libc::free(entry.as_ptr().cast());
}
}

/// Gets the value for a given key, the returned pointer is null if not found.
/// Increments the access count.
///
/// # Safety
///
/// This function is intended to be called from MLIR, deals with pointers, and is therefore
/// definitely unsafe to use manually.
#[no_mangle]
pub unsafe extern "C" fn cairo_native__dict_get(
ptr: *const HashMap<[u8; 32], NonNull<std::ffi::c_void>>,
ptr: *mut (HashMap<[u8; 32], NonNull<std::ffi::c_void>>, u64),
key: &[u8; 32],
) -> *const std::ffi::c_void {
let map: &HashMap<[u8; 32], NonNull<std::ffi::c_void>> = &*ptr;
) -> *mut std::ffi::c_void {
let dict: &mut (HashMap<[u8; 32], NonNull<std::ffi::c_void>>, u64) = &mut *ptr;
let map = &dict.0;
dict.1 += 1;

if let Some(v) = map.get(key) {
v.as_ptr()
Expand All @@ -201,12 +209,12 @@ pub unsafe extern "C" fn cairo_native__dict_get(
/// definitely unsafe to use manually.
#[no_mangle]
pub unsafe extern "C" fn cairo_native__dict_insert(
ptr: *mut HashMap<[u8; 32], NonNull<std::ffi::c_void>>,
ptr: *mut (HashMap<[u8; 32], NonNull<std::ffi::c_void>>, u64),
key: &[u8; 32],
value: NonNull<std::ffi::c_void>,
) -> *mut std::ffi::c_void {
let ptr = ptr.cast::<HashMap<[u8; 32], NonNull<std::ffi::c_void>>>();
let old_ptr = (*ptr).insert(*key, value);
let dict = &mut *ptr;
let old_ptr = dict.0.insert(*key, value);

if let Some(v) = old_ptr {
v.as_ptr()
Expand All @@ -215,6 +223,20 @@ pub unsafe extern "C" fn cairo_native__dict_insert(
}
}

/// Compute the total gas refund for the dictionary at squash time.
///
/// # Safety
///
/// This function is intended to be called from MLIR, deals with pointers, and is therefore
/// definitely unsafe to use manually.
#[no_mangle]
pub unsafe extern "C" fn cairo_native__dict_gas_refund(
ptr: *const (HashMap<[u8; 32], NonNull<std::ffi::c_void>>, u64),
) -> u64 {
let dict = &*ptr;
(dict.1 - dict.0.len() as u64) * *DICT_GAS_REFUND_PER_ACCESS
}

/// Compute `ec_point_from_x_nz(x)` and store it.
///
/// # Panics
Expand Down
31 changes: 28 additions & 3 deletions src/libfuncs/felt252_dict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ use cairo_lang_sierra::{
program_registry::ProgramRegistry,
};
use melior::{
ir::{Block, Location},
dialect::arith,
ir::{r#type::IntegerType, Block, Location},
Context,
};

Expand Down Expand Up @@ -67,19 +68,43 @@ pub fn build_squash<'ctx, 'this>(
entry: &'this Block<'ctx>,
location: Location<'ctx>,
helper: &LibfuncHelper<'ctx, 'this>,
_metadata: &MetadataStorage,
metadata: &mut MetadataStorage,
_info: &SignatureOnlyConcreteLibfunc,
) -> Result<()> {
let range_check =
super::increment_builtin_counter(context, entry, location, entry.argument(0)?.into())?;
let gas_builtin = entry.argument(1)?.into();
let segment_arena =
super::increment_builtin_counter(context, entry, location, entry.argument(2)?.into())?;
let dict_ptr = entry.argument(3)?.into();

let runtime_bindings = metadata
.get_mut::<RuntimeBindingsMeta>()
.expect("Runtime library not available.");

let gas_refund = runtime_bindings
.dict_gas_refund(context, helper, entry, dict_ptr, location)?
.result(0)?
.into();
let gas_refund = entry
.append_operation(arith::extui(
gas_refund,
IntegerType::new(context, 128).into(),
location,
))
.result(0)?
.into();

let new_gas_builtin = entry
.append_operation(arith::addi(gas_builtin, gas_refund, location))
.result(0)?
.into();

entry.append_operation(helper.br(
0,
&[
range_check,
entry.argument(1)?.into(),
new_gas_builtin,
segment_arena,
entry.argument(3)?.into(),
],
Expand Down
48 changes: 48 additions & 0 deletions src/metadata/runtime_bindings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ enum RuntimeBinding {
EcStateTryFinalizeNz,
DictNew,
DictGet,
DictGasRefund,
DictInsert,
DictFree,
}
Expand Down Expand Up @@ -657,6 +658,53 @@ impl RuntimeBindingsMeta {
location,
)))
}

/// Register if necessary, then invoke the `dict_gas_refund()` function.
///
/// Compute the total gas refund for the dictionary.
///
/// Returns a u64 of the result.
#[allow(clippy::too_many_arguments)]
pub fn dict_gas_refund<'c, 'a>(
&mut self,
context: &'c Context,
module: &Module,
block: &'a Block<'c>,
dict_ptr: Value<'c, 'a>, // ptr to the dict
location: Location<'c>,
) -> Result<OperationRef<'c, 'a>>
where
'c: 'a,
{
if self.active_map.insert(RuntimeBinding::DictGasRefund) {
module.body().append_operation(func::func(
context,
StringAttribute::new(context, "cairo_native__dict_gas_refund"),
TypeAttribute::new(
FunctionType::new(
context,
&[llvm::r#type::opaque_pointer(context)],
&[IntegerType::new(context, 64).into()],
)
.into(),
),
Region::new(),
&[(
Identifier::new(context, "sym_visibility"),
StringAttribute::new(context, "private").into(),
)],
Location::unknown(context),
));
}

Ok(block.append_operation(func::call(
context,
FlatSymbolRefAttribute::new(context, "cairo_native__dict_gas_refund"),
&[dict_ptr],
&[IntegerType::new(context, 64).into()],
location,
)))
}
}

impl Default for RuntimeBindingsMeta {
Expand Down
5 changes: 3 additions & 2 deletions src/types/felt252_dict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
//!
//! A key value storage for values whose type implement Copy. The key is always a felt.
//!
//! This type is represented as a pointer to a heap allocated Rust hashmap, interacted through the runtime functions to
//! insert and get elements.
//! This type is represented as a pointer to a tuple of a heap allocated Rust hashmap along with a u64
//! used to count accesses to the dictionary. The type is interacted through the runtime functions to
//! insert, get elements and increment the access counter.

use super::WithSelf;
use crate::{error::Result, metadata::MetadataStorage};
Expand Down
1 change: 1 addition & 0 deletions src/types/gas_builtin.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//! # Gas builtin type
//!
//! The gas builtin is just a number indicating how many
//! gas units have been used.

use super::WithSelf;
use crate::{error::Result, metadata::MetadataStorage};
Expand Down
4 changes: 0 additions & 4 deletions src/types/nullable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,6 @@ fn snapshot_take<'ctx, 'this>(
metadata.insert(ReallocBindingsMeta::new(context, helper));
}

// let elem_snapshot_take = metadata
// .get::<SnapshotClonesMeta<TType, TLibfunc>>()
// .and_then(|meta| meta.wrap_invoke(&info.ty));

let elem_layout = registry.get_type(&info.ty)?.layout(registry)?;

let k0 = entry
Expand Down
7 changes: 7 additions & 0 deletions src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,13 @@ pub fn register_runtime_symbols(engine: &ExecutionEngine) {
NonNull<std::ffi::c_void>,
) -> *mut std::ffi::c_void as *mut (),
);

engine.register_symbol(
"cairo_native__dict_gas_refund",
cairo_native_runtime::cairo_native__dict_gas_refund
as *const fn(*const std::ffi::c_void, NonNull<std::ffi::c_void>) -> u64
as *mut (),
);
}
}

Expand Down
10 changes: 8 additions & 2 deletions tests/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -321,8 +321,6 @@ pub fn run_native_starknet_contract(
/// Given the result of the cairo-vm and cairo-native of the same program, it compares
/// the results automatically, triggering a proptest assert if there is a mismatch.
///
/// If ignore_gas is false, it will check whether the resulting gas matches.
///
/// Left of report of the assert is the cairo vm result, right side is cairo native
#[track_caller]
pub fn compare_outputs(
Expand Down Expand Up @@ -547,6 +545,14 @@ pub fn compare_outputs(
.map(|x| x.starts_with("core::panics::PanicResult"))
.unwrap_or(false);

assert_eq!(
vm_result
.gas_counter
.clone()
.unwrap_or_else(|| Felt252::from(0)),
Felt252::from(native_result.remaining_gas.unwrap_or(0)),
);

let vm_result = match &vm_result.value {
RunResultValue::Success(values) => {
if returns_panic {
Expand Down
6 changes: 3 additions & 3 deletions tests/tests/alexandria.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ fn compare_inputless_function(function_name: &str) {
#[test_case("aliquot_sum")]
#[test_case("extended_euclidean_algorithm")]
// alexandria_data_structures
#[test_case("vec" => ignore["Gas mismatch"])]
#[test_case("stack" => ignore["Gas mismatch"])]
#[test_case("vec")]
#[test_case("stack")]
#[test_case("queue")]
#[test_case("bit_array" => ignore["Gas mismatch"])]
#[test_case("bit_array")]
// alexandria_encoding
#[test_case("base64_encode" => ignore["Gas mismatch"])]
#[test_case("reverse_bits")]
Expand Down
2 changes: 1 addition & 1 deletion tests/tests/cases.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ use test_case::test_case;
#[test_case("tests/cases/returns/simple.cairo")]
#[test_case("tests/cases/returns/tuple.cairo")]
// dict
#[test_case("tests/cases/dict/insert_get.cairo" => ignore["gas mismatch"])]
#[test_case("tests/cases/dict/insert_get.cairo")]
// uint
#[test_case("tests/cases/uint/compare.cairo")]
#[test_case("tests/cases/uint/consts.cairo")]
Expand Down
1 change: 0 additions & 1 deletion tests/tests/dict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ lazy_static! {

proptest! {
#[test]
#[ignore = "gas mismatch in dicts"]
fn dict_get_insert_proptest(a in any_felt(), b in any_felt()) {
let program = &DICT_GET_INSERT;
let result_vm = run_vm_program(
Expand Down

0 comments on commit bd23b1c

Please sign in to comment.