Skip to content

Commit

Permalink
adapt rust hints
Browse files Browse the repository at this point in the history
  • Loading branch information
enitrat committed Feb 13, 2025
1 parent d787502 commit 54153b7
Show file tree
Hide file tree
Showing 7 changed files with 131 additions and 53 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,4 @@ cairo-vm = { git = "https://github.com/lambdaclass/cairo-vm.git", tag = "v2.0.0-
] }

[patch."https://github.com/lambdaclass/cairo-vm.git"]
cairo-vm = { git = "https://github.com/kkrt-labs/cairo-vm", rev = "12c5b64ec4f1ee1fb8cfef4978267fb23f6bacce" }
cairo-vm = { git = "https://github.com/kkrt-labs/cairo-vm", rev = "354d2c5df4d68f68bf18b1bd1fab859f7a0056ef" }
5 changes: 4 additions & 1 deletion cairo/tests/utils/args_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1182,7 +1182,9 @@ def default_factory():
]
)

all_preimages = {poseidon_hash_many(k) if len(k) > 1 else k: k for k in data.keys()}
all_preimages = {
poseidon_hash_many(k) if len(k) != 1 else k[0]: k for k in data.keys()
}

segments.load_data(dict_ptr, initial_data)
current_ptr = dict_ptr + len(initial_data)
Expand All @@ -1199,6 +1201,7 @@ def default_factory():
default_value = (
data.default_factory() if isinstance(data, defaultdict) else None
)
dict_manager.preimages.update(all_preimages)
dict_manager.trackers[dict_ptr.segment_index] = RustDictTracker(
data=data,
current_ptr=current_ptr,
Expand Down
56 changes: 54 additions & 2 deletions crates/cairo-addons/src/vm/dict_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@ use cairo_vm::{
},
types::relocatable::MaybeRelocatable,
};
use pyo3::{prelude::*, types::PyTuple};
use pyo3::{
prelude::*,
types::{PyDict, PyTuple},
};
use std::{cell::RefCell, collections::HashMap, rc::Rc};

use super::{maybe_relocatable::PyMaybeRelocatable, relocatable::PyRelocatable};

#[derive(FromPyObject, Eq, PartialEq, Hash)]
#[derive(FromPyObject, Eq, PartialEq, Hash, Debug)]
pub enum PyDictKey {
#[pyo3(transparent)]
Simple(PyMaybeRelocatable),
Expand Down Expand Up @@ -82,6 +85,42 @@ impl PyTrackerMapping {
}
}

// Object returned by DictManager.preimages enabling access to the preimages by index and mutating
/// the preimages with manager.preimages[index] = preimage
#[pyclass(name = "PreimagesMapping", unsendable)]
pub struct PyPreimagesMapping {
inner: Rc<RefCell<RustDictManager>>,
}

#[pymethods]
impl PyPreimagesMapping {
fn __getitem__(&self, key: PyMaybeRelocatable) -> PyResult<PyDictKey> {
Ok(self.inner.borrow().preimages.get(&key.into()).cloned().unwrap().into())
}

fn __setitem__(&mut self, key: PyMaybeRelocatable, value: PyDictKey) -> PyResult<()> {
self.inner.borrow_mut().preimages.insert(key.into(), value.into());
Ok(())
}

fn update(&mut self, other: Bound<'_, PyDict>) -> PyResult<()> {
let other_dict = other.extract::<HashMap<PyMaybeRelocatable, PyDictKey>>()?;
self.inner
.borrow_mut()
.preimages
.extend(other_dict.into_iter().map(|(k, v)| (k.into(), v.into())));
Ok(())
}

fn __repr__(&self) -> PyResult<String> {
let inner = self.inner.borrow();
let mut pairs: Vec<_> = inner.preimages.iter().collect();
pairs.sort_by(|(k1, _), (k2, _)| k1.partial_cmp(k2).unwrap_or(std::cmp::Ordering::Equal));
let data_str =
pairs.into_iter().map(|(k, v)| format!("{}: {}", k, v)).collect::<Vec<_>>().join(", ");
Ok(format!("PreimagesMapping({{{}}})", data_str))
}
}
#[pyclass(name = "DictManager", unsendable)]
pub struct PyDictManager {
pub inner: Rc<RefCell<RustDictManager>>,
Expand All @@ -99,6 +138,19 @@ impl PyDictManager {
Ok(PyTrackerMapping { inner: self.inner.clone() })
}

#[getter]
fn get_preimages(&self) -> PyResult<PyPreimagesMapping> {
Ok(PyPreimagesMapping { inner: self.inner.clone() })
}

#[setter]
fn set_preimages(&mut self, value: Bound<'_, PyDict>) -> PyResult<()> {
let preimages = value.extract::<HashMap<PyMaybeRelocatable, PyDictKey>>()?;
self.inner.borrow_mut().preimages =
preimages.into_iter().map(|(k, v)| (k.into(), v.into())).collect();
Ok(())
}

fn get_tracker(&self, ptr: PyRelocatable) -> PyResult<PyDictTracker> {
self.inner
.borrow()
Expand Down
6 changes: 6 additions & 0 deletions crates/cairo-addons/src/vm/felt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,9 @@ impl From<Felt252> for PyFelt {
Self { inner: felt }
}
}

impl From<PyFelt> for Felt252 {
fn from(felt: PyFelt) -> Self {
felt.inner
}
}
102 changes: 60 additions & 42 deletions crates/cairo-addons/src/vm/hint_definitions/hashdict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ pub fn hashdict_read() -> Hint {
// Get dictionary pointer and setup tracker
let dict_ptr = get_ptr_from_var_name("dict_ptr", vm, ids_data, ap_tracking)?;
let dict_manager_ref = exec_scopes.get_dict_manager()?;
let mut dict = dict_manager_ref.borrow_mut();
let tracker = dict.get_tracker_mut(dict_ptr)?;
let mut dict_manager = dict_manager_ref.borrow_mut();
let tracker = dict_manager.get_tracker_mut(dict_ptr)?;
tracker.current_ptr.offset += DICT_ACCESS_SIZE;

let key = get_ptr_from_var_name("key", vm, ids_data, ap_tracking)?;
Expand All @@ -65,7 +65,11 @@ pub fn hashdict_read() -> Hint {

tracker.get_value(&dict_key).and_then(|value| {
insert_value_from_var_name("value", value.clone(), vm, ids_data, ap_tracking)
})
})?;

let hashed_key = compute_hash_key(&dict_key, key_len);
dict_manager.preimages.insert(hashed_key.into(), dict_key);
Ok(())
},
)
}
Expand All @@ -82,8 +86,8 @@ pub fn hashdict_write() -> Hint {
// Get dictionary pointer and setup tracker
let dict_ptr = get_ptr_from_var_name("dict_ptr", vm, ids_data, ap_tracking)?;
let dict_manager_ref = exec_scopes.get_dict_manager()?;
let mut dict = dict_manager_ref.borrow_mut();
let tracker = dict.get_tracker_mut(dict_ptr)?;
let mut dict_manager = dict_manager_ref.borrow_mut();
let tracker = dict_manager.get_tracker_mut(dict_ptr)?;
tracker.current_ptr.offset += DICT_ACCESS_SIZE;

let key = get_ptr_from_var_name("key", vm, ids_data, ap_tracking)?;
Expand All @@ -107,6 +111,8 @@ pub fn hashdict_write() -> Hint {
})?;
tracker.insert_value(&dict_key, &new_value);

let hashed_key = compute_hash_key(&dict_key, key_len);
dict_manager.preimages.insert(hashed_key.into(), dict_key);
Ok(())
},
)
Expand Down Expand Up @@ -217,14 +223,16 @@ pub fn hashdict_read_from_key() -> Hint {
// Get dictionary tracker
let dict_ptr = get_ptr_from_var_name("dict_ptr_stop", vm, ids_data, ap_tracking)?;
let dict_manager_ref = exec_scopes.get_dict_manager()?;
let mut dict = dict_manager_ref.borrow_mut();
let tracker = dict.get_tracker_mut(dict_ptr)?;
let mut dict_manager = dict_manager_ref.borrow_mut();
let preimages = &dict_manager.preimages.clone();
let tracker = dict_manager.get_tracker_mut(dict_ptr)?;

// Find matching preimage and get its value. This hint can also be called on non-hashed
// keys.
let simple_key = DictKey::Simple(hashed_key.into());
let preimage =
_get_preimage_for_hashed_key(hashed_key, tracker).unwrap_or(&simple_key).clone();
let preimage = _get_preimage_for_hashed_key(hashed_key.into(), preimages)
.unwrap_or(&simple_key)
.clone();
let value = tracker
.get_value(&preimage)
.map_err(|_| {
Expand Down Expand Up @@ -253,13 +261,12 @@ pub fn get_preimage_for_key() -> Hint {
let hashed_key = get_integer_from_var_name("key", vm, ids_data, ap_tracking)?;

// Get dictionary tracker
let dict_ptr = get_ptr_from_var_name("dict_ptr_stop", vm, ids_data, ap_tracking)?;
let dict_manager_ref = exec_scopes.get_dict_manager()?;
let dict = dict_manager_ref.borrow();
let tracker = dict.get_tracker(dict_ptr)?;
let dict_manager = dict_manager_ref.borrow();
let preimages = &dict_manager.preimages;

// Find matching preimage
let preimage = _get_preimage_for_hashed_key(hashed_key, tracker)?;
let preimage = _get_preimage_for_hashed_key(hashed_key.into(), preimages)?;

// Write preimage data to memory
let preimage_data_ptr =
Expand Down Expand Up @@ -297,13 +304,14 @@ pub fn copy_hashdict_tracker_entry() -> Hint {
get_ptr_from_var_name("source_ptr_stop", vm, ids_data, ap_tracking)?;
let dest_ptr = get_ptr_from_var_name("dest_ptr", vm, ids_data, ap_tracking)?;
let dict_manager_ref = exec_scopes.get_dict_manager()?;
let mut dict = dict_manager_ref.borrow_mut();
let preimages = &dict_manager_ref.borrow().preimages;
let mut dict_manager = dict_manager_ref.borrow_mut();

let source_tracker = dict.get_tracker_mut(source_ptr_stop)?;
let source_tracker = dict_manager.get_tracker_mut(source_ptr_stop)?;

// Find matching preimage from source tracker data
let key_hash = get_integer_from_var_name("source_key", vm, ids_data, ap_tracking)?;
let preimage = _get_preimage_for_hashed_key(key_hash, source_tracker)?.clone();
let preimage = _get_preimage_for_hashed_key(key_hash.into(), preimages)?.clone();
let value = source_tracker
.get_value(&preimage)
.map_err(|_| {
Expand All @@ -314,7 +322,7 @@ pub fn copy_hashdict_tracker_entry() -> Hint {
.clone();

// Update destination tracker
let dest_tracker = dict.get_tracker_mut(dest_ptr)?;
let dest_tracker = dict_manager.get_tracker_mut(dest_ptr)?;
dest_tracker.current_ptr.offset += DICT_ACCESS_SIZE;
dest_tracker.insert_value(&preimage, &value.clone());

Expand All @@ -335,18 +343,24 @@ pub fn track_precompiles() -> Hint {
// Get dictionary pointer and setup tracker
let dict_ptr = get_ptr_from_var_name("dict_ptr", vm, ids_data, ap_tracking)?;
let dict_manager_ref = exec_scopes.get_dict_manager()?;
let mut dict = dict_manager_ref.borrow_mut();
let tracker = dict.get_tracker_mut(dict_ptr)?;
let mut dict_manager = dict_manager_ref.borrow_mut();
let tracker = dict_manager.get_tracker_mut(dict_ptr)?;

let precompiles = Precompiles::cancun().addresses().collect::<Vec<_>>();
for address in &precompiles {
let preimage =
vec![MaybeRelocatable::Int(Felt252::from_bytes_le_slice(&address.0 .0))];
tracker
.insert_value(&DictKey::Compound(preimage), &MaybeRelocatable::Int(1.into()));
tracker.current_ptr.offset += precompiles.len() * DICT_ACCESS_SIZE;

let mut preimage_entries = Vec::new();
for precompile_address in &precompiles {
let address_felt = Felt252::from_bytes_le_slice(&precompile_address.0 .0);
let preimage = vec![MaybeRelocatable::Int(address_felt)];
let dict_key = DictKey::Compound(preimage);
tracker.insert_value(&dict_key, &MaybeRelocatable::Int(1.into()));
preimage_entries.push((address_felt, dict_key));
}

tracker.current_ptr.offset += precompiles.len() * DICT_ACCESS_SIZE;
for (address_felt, dict_key) in preimage_entries {
dict_manager.preimages.insert(address_felt.into(), dict_key);
}

Ok(())
},
Expand All @@ -371,24 +385,28 @@ fn build_compound_key(

/// Helper function to find a preimage in a tracker's dictionary given a hashed key
fn _get_preimage_for_hashed_key(
hashed_key: Felt252,
tracker: &DictTracker,
hashed_key: MaybeRelocatable,
preimages: &HashMap<MaybeRelocatable, DictKey>,
) -> Result<&DictKey, HintError> {
tracker
.get_dictionary_ref()
.keys()
.find(|key| match key {
preimages.get(&hashed_key).ok_or_else(|| {
HintError::CustomHint(format!("No preimage found for hashed key {}", hashed_key).into())
})
}

/// Helper function to compute the hash key from a DictKey
fn compute_hash_key(dict_key: &DictKey, key_len: usize) -> Felt252 {
if key_len != 1 {
match dict_key {
DictKey::Compound(values) => {
let felt_values: Vec<Felt252> = values.iter().filter_map(|v| v.get_int()).collect();
if felt_values.len() == 1 {
felt_values[0] == hashed_key
} else {
poseidon_hash_many(felt_values.iter()) == hashed_key
}
let ints: Vec<Felt252> = values.iter().map(|v| v.get_int().unwrap()).collect();
poseidon_hash_many(&ints)
}
_ => false,
})
.ok_or_else(|| {
HintError::CustomHint(format!("No preimage found for hashed key {}", hashed_key).into())
})
DictKey::Simple(_) => panic!("Unreachable"),
}
} else {
match dict_key {
DictKey::Compound(values) => values[0].get_int().unwrap(),
DictKey::Simple(value) => value.get_int().unwrap(),
}
}
}
11 changes: 5 additions & 6 deletions python/cairo-addons/src/cairo_addons/hints/hashdict.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ def hashdict_read(dict_manager: DictManager, ids: VmConsts, memory: MemoryDict):
else:
ids.value = dict_tracker.data.default_factory()

# Register the preimage in a special sub-dict of the tracker.
hashed_key = poseidon_hash_many(preimage) if len(preimage) > 1 else preimage[0]
hashed_key = poseidon_hash_many(preimage) if len(preimage) != 1 else preimage[0]
dict_manager.preimages[hashed_key] = preimage


Expand All @@ -43,6 +42,8 @@ def hashdict_read_from_key(

@register_hint
def hashdict_write(dict_manager: DictManager, ids: VmConsts, memory: MemoryDict):
from starkware.cairo.lang.vm.crypto import poseidon_hash_many

dict_tracker = dict_manager.get_tracker(ids.dict_ptr)
dict_tracker.current_ptr += ids.DictAccess.SIZE
preimage = tuple([memory[ids.key + i] for i in range(ids.key_len)])
Expand All @@ -53,10 +54,7 @@ def hashdict_write(dict_manager: DictManager, ids: VmConsts, memory: MemoryDict)
ids.dict_ptr.prev_value = dict_tracker.data.default_factory()
dict_tracker.data[preimage] = ids.new_value

# Register the preimage in a special sub-dict of the tracker.
from starkware.cairo.lang.vm.crypto import poseidon_hash_many

hashed_key = poseidon_hash_many(preimage) if len(preimage) > 1 else preimage[0]
hashed_key = poseidon_hash_many(preimage) if len(preimage) != 1 else preimage[0]
dict_manager.preimages[hashed_key] = preimage


Expand Down Expand Up @@ -124,5 +122,6 @@ def track_precompiles(
for key in PRE_COMPILED_CONTRACTS.keys():
preimage = (int.from_bytes(key, "little"),)
dict_tracker.data[preimage] = 1
dict_manager.preimages[preimage[0]] = preimage

dict_tracker.current_ptr += len(PRE_COMPILED_CONTRACTS) * ids.DictAccess.SIZE

0 comments on commit 54153b7

Please sign in to comment.