Skip to content

Commit

Permalink
fix preimage-related issues
Browse files Browse the repository at this point in the history
  • Loading branch information
enitrat committed Feb 12, 2025
1 parent 477265c commit 1d14a4c
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 15 deletions.
40 changes: 40 additions & 0 deletions cairo/ethereum/cancun/trie.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -1147,6 +1147,13 @@ func _prepare_trie_inner_storage{
return mapping_ptr_end;
}

// Skip all None values, which are deleted trie entries
if (dict_ptr.new_value.value == 0) {
return _prepare_trie_inner_storage(
trie, dict_ptr + Bytes32U256DictAccess.SIZE, mapping_ptr_end
);
}

let preimage_b32 = _get_bytes32_preimage_for_key(
dict_ptr.key.value, cast(trie.value._data.value.dict_ptr, DictAccess*)
);
Expand Down Expand Up @@ -1672,102 +1679,135 @@ func _get_branches{poseidon_ptr: PoseidonBuiltin*}(obj: MappingBytesBytes, level
alloc_locals;

let (local branches: MappingBytesBytes*) = alloc();
let src_dict_end = obj.value.dict_ptr;

local value: Bytes;
local value_set: felt;

let (branches_0, value_0) = _get_branch_for_nibble_at_level(obj, 0, level.value);
assert branches[0] = branches_0;
let dst_dict_end = branches_0.value.dict_ptr;
%{ copy_preimages %}
if (value_0.value.len != 0) {
assert value = value_0;
assert value_set = 1;
}
let (branches_1, value_1) = _get_branch_for_nibble_at_level(obj, 1, level.value);
assert branches[1] = branches_1;
let dst_dict_end = branches_1.value.dict_ptr;
%{ copy_preimages %}
if (value_1.value.len != 0) {
assert value = value_1;
assert value_set = 1;
}
let (branches_2, value_2) = _get_branch_for_nibble_at_level(obj, 2, level.value);
assert branches[2] = branches_2;
let dst_dict_end = branches_2.value.dict_ptr;
%{ copy_preimages %}
if (value_2.value.len != 0) {
assert value = value_2;
assert value_set = 1;
}
let (branches_3, value_3) = _get_branch_for_nibble_at_level(obj, 3, level.value);
assert branches[3] = branches_3;
let dst_dict_end = branches_3.value.dict_ptr;
%{ copy_preimages %}
if (value_3.value.len != 0) {
assert value = value_3;
assert value_set = 1;
}
let (branches_4, value_4) = _get_branch_for_nibble_at_level(obj, 4, level.value);
assert branches[4] = branches_4;
let dst_dict_end = branches_4.value.dict_ptr;
%{ copy_preimages %}
if (value_4.value.len != 0) {
assert value = value_4;
assert value_set = 1;
}
let (branches_5, value_5) = _get_branch_for_nibble_at_level(obj, 5, level.value);
assert branches[5] = branches_5;
let dst_dict_end = branches_5.value.dict_ptr;
%{ copy_preimages %}
if (value_5.value.len != 0) {
assert value = value_5;
assert value_set = 1;
}
let (branches_6, value_6) = _get_branch_for_nibble_at_level(obj, 6, level.value);
assert branches[6] = branches_6;
let dst_dict_end = branches_6.value.dict_ptr;
%{ copy_preimages %}
if (value_6.value.len != 0) {
assert value = value_6;
assert value_set = 1;
}
let (branches_7, value_7) = _get_branch_for_nibble_at_level(obj, 7, level.value);
assert branches[7] = branches_7;
let dst_dict_end = branches_7.value.dict_ptr;
%{ copy_preimages %}
if (value_7.value.len != 0) {
assert value = value_7;
assert value_set = 1;
}
let (branches_8, value_8) = _get_branch_for_nibble_at_level(obj, 8, level.value);
assert branches[8] = branches_8;
let dst_dict_end = branches_8.value.dict_ptr;
%{ copy_preimages %}
if (value_8.value.len != 0) {
assert value = value_8;
assert value_set = 1;
}
let (branches_9, value_9) = _get_branch_for_nibble_at_level(obj, 9, level.value);
assert branches[9] = branches_9;
let dst_dict_end = branches_9.value.dict_ptr;
%{ copy_preimages %}
if (value_9.value.len != 0) {
assert value = value_9;
assert value_set = 1;
}
let (branches_10, value_10) = _get_branch_for_nibble_at_level(obj, 10, level.value);
assert branches[10] = branches_10;
let dst_dict_end = branches_10.value.dict_ptr;
%{ copy_preimages %}
if (value_10.value.len != 0) {
assert value = value_10;
assert value_set = 1;
}
let (branches_11, value_11) = _get_branch_for_nibble_at_level(obj, 11, level.value);
assert branches[11] = branches_11;
let dst_dict_end = branches_11.value.dict_ptr;
%{ copy_preimages %}
if (value_11.value.len != 0) {
assert value = value_11;
assert value_set = 1;
}
let (branches_12, value_12) = _get_branch_for_nibble_at_level(obj, 12, level.value);
assert branches[12] = branches_12;
let dst_dict_end = branches_12.value.dict_ptr;
%{ copy_preimages %}
if (value_12.value.len != 0) {
assert value = value_12;
assert value_set = 1;
}
let (branches_13, value_13) = _get_branch_for_nibble_at_level(obj, 13, level.value);
assert branches[13] = branches_13;
let dst_dict_end = branches_13.value.dict_ptr;
%{ copy_preimages %}
if (value_13.value.len != 0) {
assert value = value_13;
assert value_set = 1;
}
let (branches_14, value_14) = _get_branch_for_nibble_at_level(obj, 14, level.value);
assert branches[14] = branches_14;
let dst_dict_end = branches_14.value.dict_ptr;
%{ copy_preimages %}
if (value_14.value.len != 0) {
assert value = value_14;
assert value_set = 1;
}
let (branches_15, value_15) = _get_branch_for_nibble_at_level(obj, 15, level.value);
assert branches[15] = branches_15;
let dst_dict_end = branches_15.value.dict_ptr;
%{ copy_preimages %}
if (value_15.value.len != 0) {
assert value = value_15;
assert value_set = 1;
Expand Down
3 changes: 3 additions & 0 deletions cairo/tests/utils/args_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1182,6 +1182,9 @@ def default_factory():
]
)

all_preimages = {poseidon_hash_many(k) if len(k) > 1 else k: k for k in data.keys()}
data["preimages"] = all_preimages

segments.load_data(dict_ptr, initial_data)
current_ptr = dict_ptr + len(initial_data)

Expand Down
4 changes: 4 additions & 0 deletions cairo/tests/utils/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,8 @@ def _serialize_mapping_struct(
tracker_data = self.dict_manager.trackers[dict_ptr.segment_index].data
if isinstance(cairo_key_type, TypeFelt):
for key, value in tracker_data.items():
if key == "preimages":
continue
# We skip serialization of null pointers, but serialize values equal to zero
if value == 0 and self.is_pointer_wrapper(value_type.scope.path):
continue
Expand Down Expand Up @@ -620,6 +622,8 @@ def key_transform(k):
return python_key_type(k)

for cairo_key, cairo_value in tracker_data.items():
if cairo_key == "preimages":
continue
preimage = key_transform(cairo_key)

# For pointer types, a value of 0 means absent - should skip
Expand Down
10 changes: 6 additions & 4 deletions python/cairo-addons/src/cairo_addons/hints/dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,15 @@ def copy_dict_segment(
# Same as new_dict but supports a default value
base = segments.add()
assert base.segment_index not in dict_manager.trackers
copied_data = {
key: segments.gen_arg(value) for key, value in current_tracker.data.items()
}
# manually keep the "preimages" key in the dict
copied_data["preimages"] = current_tracker.data["preimages"]
dict_manager.trackers[base.segment_index] = DictTracker(
data=defaultdict(
current_tracker.data.default_factory,
{
key: segments.gen_arg(value)
for key, value in current_tracker.data.items()
},
copied_data,
),
current_ptr=base,
)
Expand Down
39 changes: 28 additions & 11 deletions python/cairo-addons/src/cairo_addons/hints/hashdict.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

@register_hint
def hashdict_read(dict_manager: DictManager, ids: VmConsts, memory: MemoryDict):
from starkware.cairo.lang.vm.crypto import poseidon_hash_many

dict_tracker = dict_manager.get_tracker(ids.dict_ptr)
dict_tracker.current_ptr += ids.DictAccess.SIZE
preimage = tuple([memory[ids.key + i] for i in range(ids.key_len)])
Expand All @@ -18,6 +20,12 @@ def hashdict_read(dict_manager: DictManager, ids: VmConsts, memory: MemoryDict):
else:
ids.value = dict_tracker.data.default_factory()

# Register the preimage in a special sub-dict of the tracker.
logged_key = poseidon_hash_many(preimage) if len(preimage) > 1 else preimage[0]
if not isinstance(dict_tracker.data.get("preimages"), dict):
dict_tracker.data["preimages"] = {}
dict_tracker.data["preimages"][logged_key] = preimage


@register_hint
def hashdict_read_from_key(
Expand Down Expand Up @@ -47,6 +55,14 @@ def hashdict_write(dict_manager: DictManager, ids: VmConsts, memory: MemoryDict)
ids.dict_ptr.prev_value = dict_tracker.data.default_factory()
dict_tracker.data[preimage] = ids.new_value

# Register the preimage in a special sub-dict of the tracker.
from starkware.cairo.lang.vm.crypto import poseidon_hash_many

logged_key = poseidon_hash_many(preimage) if len(preimage) > 1 else preimage[0]
if not isinstance(dict_tracker.data.get("preimages"), dict):
dict_tracker.data["preimages"] = {}
dict_tracker.data["preimages"][logged_key] = preimage


@register_hint
def get_keys_for_address_prefix(
Expand Down Expand Up @@ -95,22 +111,23 @@ def copy_hashdict_tracker_entry(dict_manager: DictManager, ids: VmConsts):
dict_tracker.data[preimage] = obj_tracker.data[preimage]


@register_hint
def copy_preimages(dict_manager: DictManager, ids: VmConsts):
src_tracker = dict_manager.get_tracker(ids.src_dict_end.address_)
dst_tracker = dict_manager.get_tracker(ids.dst_dict_end.address_)
dst_tracker.data["preimages"] = src_tracker.data["preimages"].copy()


def _get_preimage_for_hashed_key(
hashed_key: int,
dict_tracker: DictTracker,
) -> tuple:
from starkware.cairo.lang.vm.crypto import poseidon_hash_many

# Get the key in the dict that matches the hashed value
preimage = next(
key
for key in dict_tracker.data.keys()
if (
key[0] == hashed_key
if len(key) == 1
else poseidon_hash_many(key) == hashed_key
)
)
if not isinstance(dict_tracker.data.get("preimages"), dict):
raise Exception("No preimages found")
if hashed_key not in dict_tracker.data["preimages"]:
raise Exception("No preimage found for hashed key")
preimage = dict_tracker.data["preimages"][hashed_key]
return preimage


Expand Down

0 comments on commit 1d14a4c

Please sign in to comment.