Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Circular call logic and exception annotations #2028

Merged
merged 8 commits into from
Jun 24, 2020
45 changes: 45 additions & 0 deletions tests/functional/context/validation/test_cyclic_function_calls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import pytest

from vyper.ast import parse_to_ast
from vyper.context.validation.module import ModuleNodeVisitor
from vyper.exceptions import CallViolation


def test_cyclic_function_call(namespace):
code = """
@private
def foo():
self.bar()

@private
def bar():
self.foo()
"""
vyper_module = parse_to_ast(code)
with namespace.enter_builtin_scope():
with pytest.raises(CallViolation):
ModuleNodeVisitor(vyper_module, {}, namespace)


def test_multi_cyclic_function_call(namespace):
code = """
@private
def foo():
self.bar()

@private
def bar():
self.baz()

@private
def baz():
self.potato()

@private
def potato():
self.foo()
"""
vyper_module = parse_to_ast(code)
with namespace.enter_builtin_scope():
with pytest.raises(CallViolation):
ModuleNodeVisitor(vyper_module, {}, namespace)
101 changes: 101 additions & 0 deletions tests/functional/context/validation/test_for_loop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import pytest

from vyper.ast import parse_to_ast
from vyper.context.validation import validate_semantics
from vyper.exceptions import ConstancyViolation


def test_modify_iterator_function_outside_loop(namespace):
code = """

a: uint256[3]

@private
def foo():
self.a[0] = 1

@private
def bar():
self.foo()
for i in self.a:
pass
"""
vyper_module = parse_to_ast(code)
validate_semantics(vyper_module, {})


def test_pass_memory_var_to_other_function(namespace):
code = """

@private
def foo(a: uint256[3]) -> uint256[3]:
b: uint256[3] = a
b[1] = 42
return b


@private
def bar():
a: uint256[3] = [1,2,3]
for i in a:
self.foo(a)
"""
vyper_module = parse_to_ast(code)
validate_semantics(vyper_module, {})


def test_modify_iterator(namespace):
code = """

a: uint256[3]

@private
def bar():
for i in self.a:
self.a[0] = 1
"""
vyper_module = parse_to_ast(code)
with pytest.raises(ConstancyViolation):
validate_semantics(vyper_module, {})


def test_modify_iterator_function_call(namespace):
code = """

a: uint256[3]

@private
def foo():
self.a[0] = 1

@private
def bar():
for i in self.a:
self.foo()
"""
vyper_module = parse_to_ast(code)
with pytest.raises(ConstancyViolation):
validate_semantics(vyper_module, {})


def test_modify_iterator_recursive_function_call(namespace):
code = """

a: uint256[3]

@private
def foo():
self.a[0] = 1

@private
def bar():
self.foo()

@private
def baz():
for i in self.a:
self.bar()
"""
vyper_module = parse_to_ast(code)
with pytest.raises(ConstancyViolation):
validate_semantics(vyper_module, {})
8 changes: 4 additions & 4 deletions tests/parser/exceptions/test_call_violation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ def b():
p: int128 = self.a(10)
""",
"""
@public
def goo():
pass

@private
def foo():
self.goo()

@public
def goo():
self.foo()
""",
]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,16 +198,9 @@ def test_invalid_external_contract_call_declaration_1(assert_compile_failed, get
contract_1 = """
contract Bar:
def bar() -> int128: pass

bar_contract: Bar

@public
def foo(contract_address: contract(Boo)) -> int128:
self.bar_contract = Bar(contract_address)
return self.bar_contract.bar()
"""

assert_compile_failed(lambda: get_contract(contract_1), UnknownType)
assert_compile_failed(lambda: get_contract(contract_1), StructureException)


def test_invalid_external_contract_call_declaration_2(assert_compile_failed, get_contract):
Expand Down
5 changes: 4 additions & 1 deletion vyper/ast/nodes.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ def get_node(
ast_struct: Union[dict, python_ast.AST], parent: Optional[VyperNode] = ...
) -> VyperNode: ...

def compare_nodes(left_node: VyperNode, right_node: VyperNode) -> bool: ...

class VyperNode:
full_source_code: str = ...
def __init__(self, parent: Optional[VyperNode] = ..., **kwargs: dict) -> None: ...
Expand Down Expand Up @@ -110,7 +112,8 @@ class NameConstant(Constant): ...
class Name(VyperNode):
id: str = ...

class Expr(VyperNode): ...
class Expr(VyperNode):
value: VyperNode = ...

class UnaryOp(VyperNode):
op: VyperNode = ...
Expand Down
8 changes: 6 additions & 2 deletions vyper/context/types/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from vyper.context.types.bases import DataLocation
from vyper.context.types.utils import get_type_from_annotation
from vyper.context.validation.utils import validate_expected_type
from vyper.exceptions import StructureException
from vyper.exceptions import EventDeclarationException, StructureException

# NOTE: This implementation isn't as polished as it could be, because it will be
# replaced with a new struct-style syntax prior to the next release.
Expand Down Expand Up @@ -41,13 +41,17 @@ def from_annotation(
is_public: bool = False,
) -> "Event":
arguments = OrderedDict()
indexed = []
indexed: List = []
validate_call_args(node, 1)
if not isinstance(node.args[0], vy_ast.Dict):
raise StructureException("Invalid event declaration syntax", node.args[0])
for key, value in zip(node.args[0].keys, node.args[0].values):
if isinstance(value, vy_ast.Call) and value.get("func.id") == "indexed":
validate_call_args(value, 1)
if indexed.count(True) == 3:
raise EventDeclarationException(
"Event cannot have more than three indexed arguments", value
)
fubuloubu marked this conversation as resolved.
Show resolved Hide resolved
indexed.append(True)
value = value.args[0]
else:
Expand Down
3 changes: 3 additions & 0 deletions vyper/context/types/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def from_abi(cls, abi: Dict) -> "ContractFunctionType":
def from_FunctionDef(
cls,
node: vy_ast.FunctionDef,
is_constant: Optional[bool] = None,
is_public: Optional[bool] = None,
include_defaults: Optional[bool] = True,
) -> "ContractFunctionType":
Expand All @@ -142,6 +143,8 @@ def from_FunctionDef(
ContractFunctionType
"""
kwargs: Dict[str, Any] = {}
if is_constant is not None:
kwargs["is_constant"] = is_constant
if is_public is not None:
kwargs["is_public"] = is_public

Expand Down
2 changes: 1 addition & 1 deletion vyper/context/types/indexable/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class TupleDefinition(_SequenceDefinition):
def __init__(self, value_type: Tuple[BaseTypeDefinition, ...]) -> None:
# always use the most restrictive location re: modification
location = sorted((i.location for i in value_type), key=lambda k: k.value)[-1]
is_constant = next((True for i in value_type if getattr(i, 'is_constant', None)), False)
is_constant = next((True for i in value_type if getattr(i, "is_constant", None)), False)
super().__init__(
value_type, # type: ignore
len(value_type),
Expand Down
12 changes: 11 additions & 1 deletion vyper/context/types/meta/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,5 +153,15 @@ def _get_class_functions(base_node: vy_ast.ClassDef) -> OrderedDict:
for node in base_node.body:
if not isinstance(node, vy_ast.FunctionDef):
raise StructureException("Interfaces can only contain function definitions", node)
functions[node.name] = ContractFunctionType.from_FunctionDef(node, is_public=True)

if len(node.body) != 1 or node.body[0].get("value.id") not in ("constant", "modifying"):
raise StructureException(
"Interface function must be set as constant or modifying",
node.body[0] if node.body else node,
)

is_constant = bool(node.body[0].value.id == "constant")
fn = ContractFunctionType.from_FunctionDef(node, is_constant=is_constant, is_public=True)
functions[node.name] = fn

return functions
70 changes: 55 additions & 15 deletions vyper/context/validation/local.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
from typing import Optional

from vyper import ast as vy_ast
from vyper.ast.validation import validate_call_args
Expand Down Expand Up @@ -35,7 +36,7 @@
)


def validate_functions(vy_module):
def validate_functions(vy_module: vy_ast.Module) -> None:

"""Analyzes a vyper ast and validates the function-level namespaces."""

Expand All @@ -44,14 +45,14 @@ def validate_functions(vy_module):
for node in vy_module.get_children(vy_ast.FunctionDef):
with namespace.enter_scope():
try:
FunctionNodeVisitor(node, namespace)
FunctionNodeVisitor(vy_module, node, namespace)
except VyperException as e:
err_list.append(e)

err_list.raise_if_not_empty()


def _is_terminus_node(node):
def _is_terminus_node(node: vy_ast.VyperNode) -> bool:
if getattr(node, "_is_terminus", None):
return True
if isinstance(node, vy_ast.Expr) and isinstance(node.value, vy_ast.Call):
Expand All @@ -73,6 +74,24 @@ def check_for_terminus(node_list: list) -> bool:
return False


def _check_iterator_assign(
target_node: vy_ast.VyperNode, search_node: vy_ast.VyperNode
) -> Optional[vy_ast.VyperNode]:
similar_nodes = [
n
for n in search_node.get_descendants(type(target_node))
if vy_ast.compare_nodes(target_node, n)
]

for node in similar_nodes:
# raise if the node is the target of an assignment statement
assign_node = node.get_ancestor((vy_ast.Assign, vy_ast.AugAssign))
if assign_node and node in assign_node.target.get_descendants(include_self=True):
return node

return None


class FunctionNodeVisitor(VyperNodeVisitorBase):

ignored_types = (
Expand All @@ -83,7 +102,10 @@ class FunctionNodeVisitor(VyperNodeVisitorBase):
)
scope_name = "function"

def __init__(self, fn_node: vy_ast.FunctionDef, namespace: dict) -> None:
def __init__(
self, vyper_module: vy_ast.Module, fn_node: vy_ast.FunctionDef, namespace: dict
) -> None:
self.vyper_module = vyper_module
self.fn_node = fn_node
self.namespace = namespace
self.func = namespace["self"].get_member(fn_node.name, fn_node)
Expand Down Expand Up @@ -269,17 +291,35 @@ def visit_For(self, node):
raise StructureException("Cannot iterate over a nested list", node.iter)

if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)):
# find references to the iterated node within the for-loop body
similar_nodes = [
n
for n in node.get_descendants(type(node.iter))
if vy_ast.compare_nodes(node.iter, n)
]
for n in similar_nodes:
# raise if the node is the target of an assignment statement
assign = n.get_ancestor((vy_ast.Assign, vy_ast.AugAssign))
if assign and n in assign.target.get_descendants(include_self=True):
raise ConstancyViolation("Cannot alter array during iteration", n)
# check for references to the iterated value within the body of the loop
assign = _check_iterator_assign(node.iter, node)
if assign:
raise ConstancyViolation("Cannot modify array during iteration", assign)

if node.iter.get("value.id") == "self":
# check if iterated value may be modified by function calls inside the loop
iter_name = node.iter.attr
for call_node in node.get_descendants(vy_ast.Call, {"func.value.id": "self"}):
fn_name = call_node.func.attr

fn_node = self.vyper_module.get_children(vy_ast.FunctionDef, {"name": fn_name})[0]
if _check_iterator_assign(node.iter, fn_node):
# check for direct modification
raise ConstancyViolation(
f"Cannot call '{fn_name}' inside for loop, it potentially "
f"modifies iterated storage variable '{iter_name}'",
call_node,
)

for name in self.namespace["self"].members[fn_name].recursive_calls:
# check for indirect modification
fn_node = self.vyper_module.get_children(vy_ast.FunctionDef, {"name": name})[0]
if _check_iterator_assign(node.iter, fn_node):
raise ConstancyViolation(
f"Cannot call '{fn_name}' inside for loop, it may call to '{name}' "
f"which potentially modifies iterated storage variable '{iter_name}'",
call_node,
)

for type_ in type_list:
type_ = copy.deepcopy(type_)
Expand Down
Loading