Skip to content

Commit

Permalink
feat: use DictManager's preimage field to track preimages of hashed…
Browse files Browse the repository at this point in the history
… keys (#788)

Closes #795

---------

Co-authored-by: Elias Tazartes <66871571+Eikix@users.noreply.github.com>
  • Loading branch information
enitrat and Eikix authored Feb 13, 2025
1 parent 5d4d479 commit a392270
Show file tree
Hide file tree
Showing 15 changed files with 189 additions and 78 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,4 @@ cairo-vm = { git = "https://github.com/lambdaclass/cairo-vm.git", tag = "v2.0.0-
] }

[patch."https://github.com/lambdaclass/cairo-vm.git"]
cairo-vm = { git = "https://github.com/kkrt-labs/cairo-vm", rev = "12c5b64ec4f1ee1fb8cfef4978267fb23f6bacce" }
cairo-vm = { git = "https://github.com/kkrt-labs/cairo-vm", rev = "7e0e13708d2265fadc913692709737b50908b97c" }
4 changes: 3 additions & 1 deletion cairo/tests/ethereum/cancun/test_trie.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
)
from ethereum_types.bytes import Bytes, Bytes32
from ethereum_types.numeric import U256, Uint
from hypothesis import example, given
from hypothesis import Verbosity, example, given, settings
from hypothesis import strategies as st

from cairo_addons.testing.errors import cairo_error, strict_raises
Expand Down Expand Up @@ -79,6 +79,7 @@ def test_common_prefix_length(self, cairo_run, a: Bytes, b: Bytes):
assert common_prefix_length(a, b) == cairo_run("common_prefix_length", a, b)

@given(a=..., b=...)
@settings(verbosity=Verbosity.quiet)
def test_common_prefix_length_should_fail(
self, cairo_programs, cairo_run_py, a: Bytes, b: Bytes
):
Expand All @@ -101,6 +102,7 @@ def test_nibble_list_to_compact(self, cairo_run, x, is_leaf: bool):
)

@given(x=nibble.filter(lambda x: len(x) != 0), is_leaf=...)
@settings(verbosity=Verbosity.quiet)
def test_nibble_list_to_compact_should_raise_when_wrong_remainder(
self, cairo_programs, cairo_run_py, x, is_leaf: bool
):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
ECRECOVER_ADDRESS,
PRE_COMPILED_CONTRACTS,
)
from hypothesis import example, given
from hypothesis import Verbosity, example, given, settings
from hypothesis import strategies as st

from cairo_addons.testing.errors import cairo_error
Expand All @@ -27,6 +27,7 @@ def test_precompile_table_lookup_invalid_addresses(self, cairo_run, address_int)
assert table_address == 0

@given(address=st.sampled_from(list(PRE_COMPILED_CONTRACTS.keys())))
@settings(verbosity=Verbosity.quiet)
def test_precompile_table_lookup_hint_index_out_of_bounds(
self, cairo_programs, cairo_run_py, address
):
Expand All @@ -42,6 +43,7 @@ def test_precompile_table_lookup_hint_index_out_of_bounds(
cairo_run_py("precompile_table_lookup", address_int)

@given(address=st.sampled_from(list(PRE_COMPILED_CONTRACTS.keys())))
@settings(verbosity=Verbosity.quiet)
def test_precompile_table_lookup_hint_index_different_address(
self, cairo_programs, cairo_run_py, address
):
Expand Down
4 changes: 3 additions & 1 deletion cairo/tests/legacy/utils/test_bytes_legacy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from hypothesis import given
from hypothesis import Verbosity, given, settings
from hypothesis.strategies import binary, integers
from starkware.cairo.lang.cairo_constants import DEFAULT_PRIME

Expand Down Expand Up @@ -39,6 +39,7 @@ def test_should_raise_when_value_sup_31_bytes(self, cairo_run, n):
# This test checks the function fails if the % base is removed from the hint
# All values up to 256 will have the same decomposition if the it is removed
@given(n=integers(min_value=256, max_value=2**248 - 1))
@settings(verbosity=Verbosity.quiet)
def test_should_raise_when_byte_value_not_modulo_base(
self, cairo_programs, cairo_run, n
):
Expand All @@ -63,6 +64,7 @@ def test_should_raise_when_byte_value_not_modulo_base(
!= 0
)
)
@settings(verbosity=Verbosity.quiet)
def test_should_raise_when_bytes_len_is_not_minimal(
self, cairo_programs, cairo_run, n
):
Expand Down
9 changes: 9 additions & 0 deletions cairo/tests/utils/args_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1182,17 +1182,26 @@ def default_factory():
]
)

all_preimages = {
poseidon_hash_many(k) if len(k) != 1 else k[0]: k for k in data.keys()
}

segments.load_data(dict_ptr, initial_data)
current_ptr = dict_ptr + len(initial_data)

if isinstance(dict_manager, DictManager):
dict_manager.trackers[dict_ptr.segment_index] = DictTracker(
data=data, current_ptr=current_ptr
)
# Set a new field in the dict_manager to store all preimages.
if not hasattr(dict_manager, "preimages"):
dict_manager.preimages = {}
dict_manager.preimages.update(all_preimages)
else:
default_value = (
data.default_factory() if isinstance(data, defaultdict) else None
)
dict_manager.preimages.update(all_preimages)
dict_manager.trackers[dict_ptr.segment_index] = RustDictTracker(
data=data,
current_ptr=current_ptr,
Expand Down
65 changes: 63 additions & 2 deletions crates/cairo-addons/src/vm/dict_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@ use cairo_vm::{
},
types::relocatable::MaybeRelocatable,
};
use pyo3::{prelude::*, types::PyTuple};
use pyo3::{
prelude::*,
types::{PyDict, PyTuple},
};
use std::{cell::RefCell, collections::HashMap, rc::Rc};

use super::{maybe_relocatable::PyMaybeRelocatable, relocatable::PyRelocatable};

#[derive(FromPyObject, Eq, PartialEq, Hash)]
#[derive(FromPyObject, Eq, PartialEq, Hash, Debug)]
pub enum PyDictKey {
#[pyo3(transparent)]
Simple(PyMaybeRelocatable),
Expand Down Expand Up @@ -82,6 +85,51 @@ impl PyTrackerMapping {
}
}

// Object returned by DictManager.preimages enabling access to the preimages by index and mutating
/// the preimages with manager.preimages[index] = preimage
#[pyclass(name = "PreimagesMapping", unsendable)]
pub struct PyPreimagesMapping {
inner: Rc<RefCell<RustDictManager>>,
}

#[pymethods]
impl PyPreimagesMapping {
fn __getitem__(&self, key: PyMaybeRelocatable) -> PyResult<PyDictKey> {
Ok(self
.inner
.borrow()
.preimages
.get(&key.clone().into())
.cloned()
.ok_or_else(|| {
PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!("key {:?} not found", key))
})?
.into())
}

fn __setitem__(&mut self, key: PyMaybeRelocatable, value: PyDictKey) -> PyResult<()> {
self.inner.borrow_mut().preimages.insert(key.into(), value.into());
Ok(())
}

fn update(&mut self, other: Bound<'_, PyDict>) -> PyResult<()> {
let other_dict = other.extract::<HashMap<PyMaybeRelocatable, PyDictKey>>()?;
self.inner
.borrow_mut()
.preimages
.extend(other_dict.into_iter().map(|(k, v)| (k.into(), v.into())));
Ok(())
}

fn __repr__(&self) -> PyResult<String> {
let inner = self.inner.borrow();
let mut pairs: Vec<_> = inner.preimages.iter().collect();
pairs.sort_by(|(k1, _), (k2, _)| k1.partial_cmp(k2).unwrap_or(std::cmp::Ordering::Equal));
let data_str =
pairs.into_iter().map(|(k, v)| format!("{}: {}", k, v)).collect::<Vec<_>>().join(", ");
Ok(format!("PreimagesMapping({{{}}})", data_str))
}
}
#[pyclass(name = "DictManager", unsendable)]
pub struct PyDictManager {
pub inner: Rc<RefCell<RustDictManager>>,
Expand All @@ -99,6 +147,19 @@ impl PyDictManager {
Ok(PyTrackerMapping { inner: self.inner.clone() })
}

#[getter]
fn get_preimages(&self) -> PyResult<PyPreimagesMapping> {
Ok(PyPreimagesMapping { inner: self.inner.clone() })
}

#[setter]
fn set_preimages(&mut self, value: Bound<'_, PyDict>) -> PyResult<()> {
let preimages = value.extract::<HashMap<PyMaybeRelocatable, PyDictKey>>()?;
self.inner.borrow_mut().preimages =
preimages.into_iter().map(|(k, v)| (k.into(), v.into())).collect();
Ok(())
}

fn get_tracker(&self, ptr: PyRelocatable) -> PyResult<PyDictTracker> {
self.inner
.borrow()
Expand Down
6 changes: 6 additions & 0 deletions crates/cairo-addons/src/vm/felt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,9 @@ impl From<Felt252> for PyFelt {
Self { inner: felt }
}
}

impl From<PyFelt> for Felt252 {
fn from(felt: PyFelt) -> Self {
felt.inner
}
}
Loading

0 comments on commit a392270

Please sign in to comment.