diff --git a/tests/functional/context/validation/test_cyclic_function_calls.py b/tests/functional/context/validation/test_cyclic_function_calls.py new file mode 100644 index 0000000000..a79f2bc62e --- /dev/null +++ b/tests/functional/context/validation/test_cyclic_function_calls.py @@ -0,0 +1,45 @@ +import pytest + +from vyper.ast import parse_to_ast +from vyper.context.validation.module import ModuleNodeVisitor +from vyper.exceptions import CallViolation + + +def test_cyclic_function_call(namespace): + code = """ +@private +def foo(): + self.bar() + +@private +def bar(): + self.foo() + """ + vyper_module = parse_to_ast(code) + with namespace.enter_builtin_scope(): + with pytest.raises(CallViolation): + ModuleNodeVisitor(vyper_module, {}, namespace) + + +def test_multi_cyclic_function_call(namespace): + code = """ +@private +def foo(): + self.bar() + +@private +def bar(): + self.baz() + +@private +def baz(): + self.potato() + +@private +def potato(): + self.foo() + """ + vyper_module = parse_to_ast(code) + with namespace.enter_builtin_scope(): + with pytest.raises(CallViolation): + ModuleNodeVisitor(vyper_module, {}, namespace) diff --git a/tests/functional/context/validation/test_for_loop.py b/tests/functional/context/validation/test_for_loop.py new file mode 100644 index 0000000000..d10f22c59f --- /dev/null +++ b/tests/functional/context/validation/test_for_loop.py @@ -0,0 +1,101 @@ +import pytest + +from vyper.ast import parse_to_ast +from vyper.context.validation import validate_semantics +from vyper.exceptions import ConstancyViolation + + +def test_modify_iterator_function_outside_loop(namespace): + code = """ + +a: uint256[3] + +@private +def foo(): + self.a[0] = 1 + +@private +def bar(): + self.foo() + for i in self.a: + pass + """ + vyper_module = parse_to_ast(code) + validate_semantics(vyper_module, {}) + + +def test_pass_memory_var_to_other_function(namespace): + code = """ + +@private +def foo(a: uint256[3]) -> uint256[3]: + b: uint256[3] = a + b[1] = 42 + return b + + +@private +def bar(): + a: uint256[3] = [1,2,3] + for i in a: + self.foo(a) + """ + vyper_module = parse_to_ast(code) + validate_semantics(vyper_module, {}) + + +def test_modify_iterator(namespace): + code = """ + +a: uint256[3] + +@private +def bar(): + for i in self.a: + self.a[0] = 1 + """ + vyper_module = parse_to_ast(code) + with pytest.raises(ConstancyViolation): + validate_semantics(vyper_module, {}) + + +def test_modify_iterator_function_call(namespace): + code = """ + +a: uint256[3] + +@private +def foo(): + self.a[0] = 1 + +@private +def bar(): + for i in self.a: + self.foo() + """ + vyper_module = parse_to_ast(code) + with pytest.raises(ConstancyViolation): + validate_semantics(vyper_module, {}) + + +def test_modify_iterator_recursive_function_call(namespace): + code = """ + +a: uint256[3] + +@private +def foo(): + self.a[0] = 1 + +@private +def bar(): + self.foo() + +@private +def baz(): + for i in self.a: + self.bar() + """ + vyper_module = parse_to_ast(code) + with pytest.raises(ConstancyViolation): + validate_semantics(vyper_module, {}) diff --git a/tests/parser/exceptions/test_call_violation.py b/tests/parser/exceptions/test_call_violation.py index b50fb87205..7d2a7073f7 100644 --- a/tests/parser/exceptions/test_call_violation.py +++ b/tests/parser/exceptions/test_call_violation.py @@ -19,13 +19,13 @@ def b(): p: int128 = self.a(10) """, """ +@public +def goo(): + pass + @private def foo(): self.goo() - -@public -def goo(): - self.foo() """, ] diff --git a/tests/parser/features/external_contracts/test_modifiable_external_contract_calls.py b/tests/parser/features/external_contracts/test_modifiable_external_contract_calls.py index 3cab606b22..0d3cc0ce90 100644 --- a/tests/parser/features/external_contracts/test_modifiable_external_contract_calls.py +++ b/tests/parser/features/external_contracts/test_modifiable_external_contract_calls.py @@ -198,16 +198,9 @@ def test_invalid_external_contract_call_declaration_1(assert_compile_failed, get contract_1 = """ contract Bar: def bar() -> int128: pass - -bar_contract: Bar - -@public -def foo(contract_address: contract(Boo)) -> int128: - self.bar_contract = Bar(contract_address) - return self.bar_contract.bar() """ - assert_compile_failed(lambda: get_contract(contract_1), UnknownType) + assert_compile_failed(lambda: get_contract(contract_1), StructureException) def test_invalid_external_contract_call_declaration_2(assert_compile_failed, get_contract): diff --git a/vyper/ast/nodes.pyi b/vyper/ast/nodes.pyi index 7272e7558a..73a628c54b 100644 --- a/vyper/ast/nodes.pyi +++ b/vyper/ast/nodes.pyi @@ -13,6 +13,8 @@ def get_node( ast_struct: Union[dict, python_ast.AST], parent: Optional[VyperNode] = ... ) -> VyperNode: ... +def compare_nodes(left_node: VyperNode, right_node: VyperNode) -> bool: ... + class VyperNode: full_source_code: str = ... def __init__(self, parent: Optional[VyperNode] = ..., **kwargs: dict) -> None: ... @@ -110,7 +112,8 @@ class NameConstant(Constant): ... class Name(VyperNode): id: str = ... -class Expr(VyperNode): ... +class Expr(VyperNode): + value: VyperNode = ... class UnaryOp(VyperNode): op: VyperNode = ... diff --git a/vyper/context/types/event.py b/vyper/context/types/event.py index 64dfdf5f29..d683733755 100644 --- a/vyper/context/types/event.py +++ b/vyper/context/types/event.py @@ -6,7 +6,7 @@ from vyper.context.types.bases import DataLocation from vyper.context.types.utils import get_type_from_annotation from vyper.context.validation.utils import validate_expected_type -from vyper.exceptions import StructureException +from vyper.exceptions import EventDeclarationException, StructureException # NOTE: This implementation isn't as polished as it could be, because it will be # replaced with a new struct-style syntax prior to the next release. @@ -41,13 +41,17 @@ def from_annotation( is_public: bool = False, ) -> "Event": arguments = OrderedDict() - indexed = [] + indexed: List = [] validate_call_args(node, 1) if not isinstance(node.args[0], vy_ast.Dict): raise StructureException("Invalid event declaration syntax", node.args[0]) for key, value in zip(node.args[0].keys, node.args[0].values): if isinstance(value, vy_ast.Call) and value.get("func.id") == "indexed": validate_call_args(value, 1) + if indexed.count(True) == 3: + raise EventDeclarationException( + "Event cannot have more than three indexed arguments", value + ) indexed.append(True) value = value.args[0] else: diff --git a/vyper/context/types/function.py b/vyper/context/types/function.py index f6add6279b..fef6062909 100644 --- a/vyper/context/types/function.py +++ b/vyper/context/types/function.py @@ -120,6 +120,7 @@ def from_abi(cls, abi: Dict) -> "ContractFunctionType": def from_FunctionDef( cls, node: vy_ast.FunctionDef, + is_constant: Optional[bool] = None, is_public: Optional[bool] = None, include_defaults: Optional[bool] = True, ) -> "ContractFunctionType": @@ -142,6 +143,8 @@ def from_FunctionDef( ContractFunctionType """ kwargs: Dict[str, Any] = {} + if is_constant is not None: + kwargs["is_constant"] = is_constant if is_public is not None: kwargs["is_public"] = is_public diff --git a/vyper/context/types/indexable/sequence.py b/vyper/context/types/indexable/sequence.py index a3d67def64..85e601a437 100644 --- a/vyper/context/types/indexable/sequence.py +++ b/vyper/context/types/indexable/sequence.py @@ -92,7 +92,7 @@ class TupleDefinition(_SequenceDefinition): def __init__(self, value_type: Tuple[BaseTypeDefinition, ...]) -> None: # always use the most restrictive location re: modification location = sorted((i.location for i in value_type), key=lambda k: k.value)[-1] - is_constant = next((True for i in value_type if getattr(i, 'is_constant', None)), False) + is_constant = next((True for i in value_type if getattr(i, "is_constant", None)), False) super().__init__( value_type, # type: ignore len(value_type), diff --git a/vyper/context/types/meta/interface.py b/vyper/context/types/meta/interface.py index 27f207df3d..b44d1025a3 100644 --- a/vyper/context/types/meta/interface.py +++ b/vyper/context/types/meta/interface.py @@ -153,5 +153,15 @@ def _get_class_functions(base_node: vy_ast.ClassDef) -> OrderedDict: for node in base_node.body: if not isinstance(node, vy_ast.FunctionDef): raise StructureException("Interfaces can only contain function definitions", node) - functions[node.name] = ContractFunctionType.from_FunctionDef(node, is_public=True) + + if len(node.body) != 1 or node.body[0].get("value.id") not in ("constant", "modifying"): + raise StructureException( + "Interface function must be set as constant or modifying", + node.body[0] if node.body else node, + ) + + is_constant = bool(node.body[0].value.id == "constant") + fn = ContractFunctionType.from_FunctionDef(node, is_constant=is_constant, is_public=True) + functions[node.name] = fn + return functions diff --git a/vyper/context/validation/local.py b/vyper/context/validation/local.py index 3e88f03b47..b330f96dc4 100644 --- a/vyper/context/validation/local.py +++ b/vyper/context/validation/local.py @@ -1,4 +1,5 @@ import copy +from typing import Optional from vyper import ast as vy_ast from vyper.ast.validation import validate_call_args @@ -35,7 +36,7 @@ ) -def validate_functions(vy_module): +def validate_functions(vy_module: vy_ast.Module) -> None: """Analyzes a vyper ast and validates the function-level namespaces.""" @@ -44,14 +45,14 @@ def validate_functions(vy_module): for node in vy_module.get_children(vy_ast.FunctionDef): with namespace.enter_scope(): try: - FunctionNodeVisitor(node, namespace) + FunctionNodeVisitor(vy_module, node, namespace) except VyperException as e: err_list.append(e) err_list.raise_if_not_empty() -def _is_terminus_node(node): +def _is_terminus_node(node: vy_ast.VyperNode) -> bool: if getattr(node, "_is_terminus", None): return True if isinstance(node, vy_ast.Expr) and isinstance(node.value, vy_ast.Call): @@ -73,6 +74,24 @@ def check_for_terminus(node_list: list) -> bool: return False +def _check_iterator_assign( + target_node: vy_ast.VyperNode, search_node: vy_ast.VyperNode +) -> Optional[vy_ast.VyperNode]: + similar_nodes = [ + n + for n in search_node.get_descendants(type(target_node)) + if vy_ast.compare_nodes(target_node, n) + ] + + for node in similar_nodes: + # raise if the node is the target of an assignment statement + assign_node = node.get_ancestor((vy_ast.Assign, vy_ast.AugAssign)) + if assign_node and node in assign_node.target.get_descendants(include_self=True): + return node + + return None + + class FunctionNodeVisitor(VyperNodeVisitorBase): ignored_types = ( @@ -83,7 +102,10 @@ class FunctionNodeVisitor(VyperNodeVisitorBase): ) scope_name = "function" - def __init__(self, fn_node: vy_ast.FunctionDef, namespace: dict) -> None: + def __init__( + self, vyper_module: vy_ast.Module, fn_node: vy_ast.FunctionDef, namespace: dict + ) -> None: + self.vyper_module = vyper_module self.fn_node = fn_node self.namespace = namespace self.func = namespace["self"].get_member(fn_node.name, fn_node) @@ -269,17 +291,35 @@ def visit_For(self, node): raise StructureException("Cannot iterate over a nested list", node.iter) if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)): - # find references to the iterated node within the for-loop body - similar_nodes = [ - n - for n in node.get_descendants(type(node.iter)) - if vy_ast.compare_nodes(node.iter, n) - ] - for n in similar_nodes: - # raise if the node is the target of an assignment statement - assign = n.get_ancestor((vy_ast.Assign, vy_ast.AugAssign)) - if assign and n in assign.target.get_descendants(include_self=True): - raise ConstancyViolation("Cannot alter array during iteration", n) + # check for references to the iterated value within the body of the loop + assign = _check_iterator_assign(node.iter, node) + if assign: + raise ConstancyViolation("Cannot modify array during iteration", assign) + + if node.iter.get("value.id") == "self": + # check if iterated value may be modified by function calls inside the loop + iter_name = node.iter.attr + for call_node in node.get_descendants(vy_ast.Call, {"func.value.id": "self"}): + fn_name = call_node.func.attr + + fn_node = self.vyper_module.get_children(vy_ast.FunctionDef, {"name": fn_name})[0] + if _check_iterator_assign(node.iter, fn_node): + # check for direct modification + raise ConstancyViolation( + f"Cannot call '{fn_name}' inside for loop, it potentially " + f"modifies iterated storage variable '{iter_name}'", + call_node, + ) + + for name in self.namespace["self"].members[fn_name].recursive_calls: + # check for indirect modification + fn_node = self.vyper_module.get_children(vy_ast.FunctionDef, {"name": name})[0] + if _check_iterator_assign(node.iter, fn_node): + raise ConstancyViolation( + f"Cannot call '{fn_name}' inside for loop, it may call to '{name}' " + f"which potentially modifies iterated storage variable '{iter_name}'", + call_node, + ) for type_ in type_list: type_ = copy.deepcopy(type_) diff --git a/vyper/context/validation/module.py b/vyper/context/validation/module.py index 4b66002120..11954aa21a 100644 --- a/vyper/context/validation/module.py +++ b/vyper/context/validation/module.py @@ -1,6 +1,6 @@ import importlib import pkgutil -from typing import Union +from typing import Optional, Union import vyper.interfaces from vyper import ast as vy_ast @@ -12,6 +12,7 @@ from vyper.context.validation.base import VyperNodeVisitorBase from vyper.context.validation.utils import validate_expected_type from vyper.exceptions import ( + CallViolation, CompilerPanic, ConstancyViolation, ExceptionList, @@ -30,6 +31,19 @@ def add_module_namespace(vy_module: vy_ast.Module, interface_codes: InterfaceDic ModuleNodeVisitor(vy_module, interface_codes, namespace) +def _find_cyclic_call(fn_names: list, self_members: dict) -> Optional[list]: + if fn_names[-1] not in self_members: + return None + internal_calls = self_members[fn_names[-1]].internal_calls + for name in internal_calls: + if name in fn_names: + return fn_names + [name] + sequence = _find_cyclic_call(fn_names + [name], self_members) + if sequence: + return sequence + return None + + class ModuleNodeVisitor(VyperNodeVisitorBase): scope_name = "module" @@ -57,6 +71,51 @@ def __init__( if count == len(module_nodes): err_list.raise_if_not_empty() + # get list of internal function calls made by each function + call_function_names = set() + self_members = namespace["self"].members + for node in self.ast.get_children(vy_ast.FunctionDef): + call_function_names.add(node.name) + self_members[node.name].internal_calls = set( + i.func.attr for i in node.get_descendants(vy_ast.Call, {"func.value.id": "self"}) + ) + if node.name in self_members[node.name].internal_calls: + self_node = node.get_descendants( + vy_ast.Attribute, {"value.id": "self", "attr": node.name} + )[0] + raise CallViolation(f"Function '{node.name}' calls into itself", self_node) + + for fn_name in sorted(call_function_names): + + if fn_name not in self_members: + # the referenced function does not exist - this is an issue, but we'll report + # it later when parsing the function so we can give more meaningful output + continue + + # check for circular function calls + sequence = _find_cyclic_call([fn_name], self_members) + if sequence is not None: + nodes = [] + for i in range(len(sequence) - 1): + fn_node = self.ast.get_children(vy_ast.FunctionDef, {"name": sequence[i]})[0] + call_node = fn_node.get_descendants( + vy_ast.Attribute, {"value.id": "self", "attr": sequence[i + 1]} + )[0] + nodes.append(call_node) + + raise CallViolation("Contract contains cyclic function call", *nodes) + + # get complete list of functions that are reachable from this function + function_set = set(i for i in self_members[fn_name].internal_calls if i in self_members) + while True: + expanded = set(x for i in function_set for x in self_members[i].internal_calls) + expanded |= function_set + if expanded == function_set: + break + function_set = expanded + + self_members[fn_name].recursive_calls = function_set + def visit_AnnAssign(self, node): name = node.get("target.id") if name is None: diff --git a/vyper/exceptions.py b/vyper/exceptions.py index 651d8fe198..7bd1fb746b 100644 --- a/vyper/exceptions.py +++ b/vyper/exceptions.py @@ -11,6 +11,7 @@ class ExceptionList(list): raised Exception to this list and call raise_if_not_empty once the task is completed. """ + def raise_if_not_empty(self): if len(self) == 1: raise self[0] @@ -31,7 +32,8 @@ class VyperException(Exception): This exception is not raised directly. Other exceptions inherit it in order to display source annotations in the error string. """ - def __init__(self, message='Error Message not found.', item=None): + + def __init__(self, message="Error Message not found.", *items): """ Exception initializer. @@ -39,60 +41,70 @@ def __init__(self, message='Error Message not found.', item=None): --------- message : str Error message to display with the exception. - item : VyperNode | tuple, optional - Vyper ast node or tuple of (lineno, col_offset) indicating where - the exception occured. + *items : VyperNode | tuple, optional + Vyper ast node(s) indicating where the exception occured. Source + annotation is generated in the order the nodes are given. A single + tuple of (lineno, col_offset) is also understood to support the old + API, but new exceptions should not use this approach. """ self.message = message self.lineno = None self.col_offset = None - if isinstance(item, tuple): - self.lineno, self.col_offset = item[:2] - elif hasattr(item, 'lineno'): - self.lineno = item.lineno - self.col_offset = item.col_offset - self.source_code = item.full_source_code + if len(items) == 1 and isinstance(items[0], tuple): + self.lineno, self.col_offset = items[0][:2] + else: + self.nodes = items + if items: + self.source_code = items[0].full_source_code - def with_annotation(self, node): + def with_annotation(self, *nodes): """ Creates a copy of this exception with a modified source annotation. Arguments --------- - node : VyperNode - AST node to obtain the source offset from. + *node : VyperNode + AST node(s) to use in the annotation. Returns ------- - A copy of the exception with the new offset applied. + A copy of the exception with the new node offset(s) applied. """ exc = copy.copy(self) - exc.lineno = node.lineno - exc.col_offset = node.col_offset - exc.source_code = node.full_source_code + exc.source_code = nodes[0].full_source_code + exc.nodes = nodes return exc def __str__(self): - lineno, col_offset = self.lineno, self.col_offset + from vyper import ast as vy_ast + from vyper.utils import annotate_source_code - if lineno is not None and hasattr(self, 'source_code'): - from vyper.utils import annotate_source_code + if not hasattr(self, "source_code"): + if self.lineno is not None and self.col_offset is not None: + return f"line {self.lineno}:{self.col_offset} {self.message}" + else: + return self.message + msg = f"{self.message}\n" + for node in self.nodes: source_annotation = annotate_source_code( self.source_code, - lineno, - col_offset, + node.lineno, + node.col_offset, context_lines=VYPER_ERROR_CONTEXT_LINES, line_numbers=VYPER_ERROR_LINE_NUMBERS, ) - col_offset_str = '' if col_offset is None else str(col_offset) - return f'line {lineno}:{col_offset_str} {self.message}\n{source_annotation}' - elif lineno is not None and col_offset is not None: - return f'line {lineno}:{col_offset} {self.message}' + if isinstance(node, vy_ast.VyperNode): + fn_node = node.get_ancestor(vy_ast.FunctionDef) + if fn_node: + msg += f"function '{fn_node.name}', " - return self.message + col_offset_str = "" if node.col_offset is None else str(node.col_offset) + msg += f"line {node.lineno}:{col_offset_str} \n{source_annotation}\n" + + return msg class SyntaxException(VyperException): @@ -231,6 +243,7 @@ class VyperInternalException(Exception): Internal exceptions are raised as a means of passing information between compiler processes. They should never be exposed to the user. """ + def __init__(self, message=""): self.message = message diff --git a/vyper/parser/stmt.py b/vyper/parser/stmt.py index a5f5e2bf15..55f9b95ad7 100644 --- a/vyper/parser/stmt.py +++ b/vyper/parser/stmt.py @@ -68,7 +68,7 @@ def parse_AnnAssign(self): constants=self.context.constants, ) varname = self.stmt.target.id - pos = self.context.new_variable(varname, typ) + pos = self.context.new_variable(varname, typ, pos=self.stmt) if self.stmt.value is None: return diff --git a/vyper/signatures/event_signature.py b/vyper/signatures/event_signature.py index 48ac9d7530..3aef4a40ed 100644 --- a/vyper/signatures/event_signature.py +++ b/vyper/signatures/event_signature.py @@ -1,9 +1,5 @@ from vyper import ast as vy_ast -from vyper.exceptions import ( - EventDeclarationException, - InvalidType, - VariableDeclarationException, -) +from vyper.exceptions import EventDeclarationException, TypeCheckFailure from vyper.signatures.function_signature import VariableRecord from vyper.types import ByteArrayType, canonicalize_type, get_size_of_type from vyper.utils import ( @@ -45,14 +41,11 @@ def from_declaration(cls, code, global_ctx): for i in range(len(keys)): typ = values[i] if not isinstance(keys[i], vy_ast.Name): - raise EventDeclarationException( - 'Invalid key type, expected a valid name.', - keys[i], - ) + raise TypeCheckFailure('Invalid key type, expected a valid name.') if not isinstance(typ, (vy_ast.Name, vy_ast.Call, vy_ast.Subscript)): - raise EventDeclarationException('Invalid event argument type.', typ) + raise TypeCheckFailure('Invalid event argument type.') if isinstance(typ, vy_ast.Call) and not isinstance(typ.func, vy_ast.Name): - raise EventDeclarationException('Invalid event argument type', typ) + raise TypeCheckFailure('Invalid event argument type') arg = keys[i].id arg_item = keys[i] is_indexed = False @@ -68,14 +61,11 @@ def from_declaration(cls, code, global_ctx): if isinstance(typ, vy_ast.Subscript) and getattr(typ.value, 'id', None) == 'bytes' and typ.slice.value.n > 32 and is_indexed: # noqa: E501 raise EventDeclarationException("Indexed arguments are limited to 32 bytes") if topics_count > 4: - raise EventDeclarationException( - f"Maximum of 3 topics {topics_count - 1} given", - arg, - ) + raise TypeCheckFailure("Too many indexed arguments") if not isinstance(arg, str): - raise VariableDeclarationException("Argument name invalid", arg) + raise TypeCheckFailure("Argument name invalid") if not typ: - raise InvalidType("Argument must have type", arg) + raise TypeCheckFailure("Argument must have type") check_valid_varname( arg, global_ctx._structs, @@ -84,10 +74,7 @@ def from_declaration(cls, code, global_ctx): error_prefix="Event argument name invalid or reserved.", ) if arg in (x.name for x in args): - raise VariableDeclarationException( - "Duplicate function argument name: " + arg, - arg_item, - ) + raise TypeCheckFailure(f"Duplicate function argument name: {arg}") # Can struct be logged? parsed_type = global_ctx.parse_type(typ, None) args.append(VariableRecord(arg, pos, parsed_type, False))