Skip to content

Commit

Permalink
fix: call graph stability (#3370)
Browse files Browse the repository at this point in the history
the use of `set` (which does not guarantee order of its elements)
instead of `dict` (which guarantees insertion order) led to instability
across runs of the compiler between different versions of python. this
commit patches the problem by using a derivation of `dict` to track the
call graph instead of `set`.
  • Loading branch information
charles-cooper authored May 5, 2023
1 parent f914011 commit 5ae0a07
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 3 deletions.
68 changes: 68 additions & 0 deletions tests/parser/test_call_graph_stability.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import random
import string

import hypothesis.strategies as st
import pytest
from hypothesis import given, settings

import vyper.ast as vy_ast
from vyper.compiler.phases import CompilerData


# random names for functions
@settings(max_examples=20, deadline=None)
@given(
st.lists(
st.tuples(
st.sampled_from(["@pure", "@view", "@nonpayable", "@payable"]),
st.text(alphabet=string.ascii_lowercase, min_size=1),
),
unique_by=lambda x: x[1], # unique on function name
min_size=1,
max_size=10,
)
)
@pytest.mark.fuzzing
def test_call_graph_stability_fuzz(funcs):
def generate_func_def(mutability, func_name, i):
return f"""
@internal
{mutability}
def {func_name}() -> uint256:
return {i}
"""

func_defs = "\n".join(generate_func_def(m, s, i) for i, (m, s) in enumerate(funcs))

for _ in range(10):
func_names = [f for (_, f) in funcs]
random.shuffle(func_names)

self_calls = "\n".join(f" self.{f}()" for f in func_names)
code = f"""
{func_defs}
@external
def foo():
{self_calls}
"""
t = CompilerData(code)

# check the .called_functions data structure on foo() directly
foo = t.vyper_module_folded.get_children(vy_ast.FunctionDef, filters={"name": "foo"})[0]
foo_t = foo._metadata["type"]
assert [f.name for f in foo_t.called_functions] == func_names

# now for sanity, ensure the order that the function definitions appear
# in the IR is the same as the order of the calls
sigs = t.function_signatures
del sigs["foo"]
ir = t.ir_runtime
ir_funcs = []
# search for function labels
for d in ir.args: # currently: (seq ... (seq (label foo ...)) ...)
if d.value == "seq" and d.args[0].value == "label":
r = d.args[0].args[0].value
if isinstance(r, str) and r.startswith("internal"):
ir_funcs.append(r)
assert ir_funcs == [f.internal_function_label for f in sigs.values()]
6 changes: 3 additions & 3 deletions vyper/semantics/types/function.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re
import warnings
from collections import OrderedDict
from typing import Any, Dict, List, Optional, Set, Tuple
from typing import Any, Dict, List, Optional, Tuple

from vyper import ast as vy_ast
from vyper.ast.validation import validate_call_args
Expand All @@ -28,7 +28,7 @@
from vyper.semantics.types.shortcuts import UINT256_T
from vyper.semantics.types.subscriptable import TupleT
from vyper.semantics.types.utils import type_from_abi, type_from_annotation
from vyper.utils import keccak256
from vyper.utils import OrderedSet, keccak256


class ContractFunctionT(VyperType):
Expand Down Expand Up @@ -89,7 +89,7 @@ def __init__(
self.nonreentrant = nonreentrant

# a list of internal functions this function calls
self.called_functions: Set["ContractFunctionT"] = set()
self.called_functions = OrderedSet()

# special kwargs that are allowed in call site
self.call_site_kwargs = {
Expand Down
13 changes: 13 additions & 0 deletions vyper/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,19 @@
from vyper.exceptions import DecimalOverrideException, InvalidLiteral


class OrderedSet(dict):
"""
a minimal "ordered set" class. this is needed in some places
because, while dict guarantees you can recover insertion order
vanilla sets do not.
no attempt is made to fully implement the set API, will add
functionality as needed.
"""

def add(self, item):
self[item] = None


class DecimalContextOverride(decimal.Context):
def __setattr__(self, name, value):
if name == "prec":
Expand Down

0 comments on commit 5ae0a07

Please sign in to comment.