diff --git a/tests/functional/semantics/analysis/test_for_loop.py b/tests/functional/semantics/analysis/test_for_loop.py index 8707b4c326..0d61a8f8f8 100644 --- a/tests/functional/semantics/analysis/test_for_loop.py +++ b/tests/functional/semantics/analysis/test_for_loop.py @@ -1,7 +1,12 @@ import pytest from vyper.ast import parse_to_ast -from vyper.exceptions import ImmutableViolation, TypeMismatch +from vyper.exceptions import ( + ArgumentException, + ImmutableViolation, + StateAccessViolation, + TypeMismatch, +) from vyper.semantics.analysis import validate_semantics @@ -59,6 +64,34 @@ def bar(): validate_semantics(vyper_module, {}) +def test_bad_keywords(namespace): + code = """ + +@internal +def bar(n: uint256): + x: uint256 = 0 + for i in range(n, boundddd=10): + x += i + """ + vyper_module = parse_to_ast(code) + with pytest.raises(ArgumentException): + validate_semantics(vyper_module, {}) + + +def test_bad_bound(namespace): + code = """ + +@internal +def bar(n: uint256): + x: uint256 = 0 + for i in range(n, bound=n): + x += i + """ + vyper_module = parse_to_ast(code) + with pytest.raises(StateAccessViolation): + validate_semantics(vyper_module, {}) + + def test_modify_iterator_function_call(namespace): code = """ diff --git a/tests/parser/features/iteration/test_for_range.py b/tests/parser/features/iteration/test_for_range.py index 30f4bb87e3..395dd28231 100644 --- a/tests/parser/features/iteration/test_for_range.py +++ b/tests/parser/features/iteration/test_for_range.py @@ -14,6 +14,23 @@ def repeat(z: int128) -> int128: assert c.repeat(9) == 54 +def test_range_bound(get_contract, assert_tx_failed): + code = """ +@external +def repeat(n: uint256) -> uint256: + x: uint256 = 0 + for i in range(n, bound=6): + x += i + return x + """ + c = get_contract(code) + for n in range(7): + assert c.repeat(n) == sum(range(n)) + + # check codegen inserts assertion for n greater than bound + assert_tx_failed(lambda: c.repeat(7)) + + def test_digit_reverser(get_contract_with_gas_estimation): digit_reverser = """ @external diff --git a/vyper/codegen/stmt.py b/vyper/codegen/stmt.py index 91d45f4916..86ea1813ea 100644 --- a/vyper/codegen/stmt.py +++ b/vyper/codegen/stmt.py @@ -258,11 +258,17 @@ def _parse_For_range(self): arg0 = self.stmt.iter.args[0] num_of_args = len(self.stmt.iter.args) + kwargs = { + s.arg: Expr.parse_value_expr(s.value, self.context) + for s in self.stmt.iter.keywords or [] + } + # Type 1 for, e.g. for i in range(10): ... if num_of_args == 1: - arg0_val = self._get_range_const_value(arg0) + n = Expr.parse_value_expr(arg0, self.context) start = IRnode.from_list(0, typ=iter_typ) - rounds = arg0_val + rounds = n + rounds_bound = kwargs.get("bound", rounds) # Type 2 for, e.g. for i in range(100, 110): ... elif self._check_valid_range_constant(self.stmt.iter.args[1]).is_literal: @@ -270,6 +276,7 @@ def _parse_For_range(self): arg1_val = self._get_range_const_value(self.stmt.iter.args[1]) start = IRnode.from_list(arg0_val, typ=iter_typ) rounds = IRnode.from_list(arg1_val - arg0_val, typ=iter_typ) + rounds_bound = rounds # Type 3 for, e.g. for i in range(x, x + 10): ... else: @@ -278,9 +285,10 @@ def _parse_For_range(self): start = Expr.parse_value_expr(arg0, self.context) _, hi = start.typ.int_bounds start = clamp("le", start, hi + 1 - rounds) + rounds_bound = rounds - r = rounds if isinstance(rounds, int) else rounds.value - if r < 1: + bound = rounds_bound if isinstance(rounds_bound, int) else rounds_bound.value + if bound < 1: return varname = self.stmt.target.id @@ -294,7 +302,10 @@ def _parse_For_range(self): loop_body.append(["mstore", iptr, i]) loop_body.append(parse_body(self.stmt.body, self.context)) - ir_node = IRnode.from_list(["repeat", i, start, rounds, rounds, loop_body]) + # NOTE: codegen for `repeat` inserts an assertion that rounds <= rounds_bound. + # if we ever want to remove that, we need to manually add the assertion + # where it makes sense. + ir_node = IRnode.from_list(["repeat", i, start, rounds, rounds_bound, loop_body]) del self.context.forvars[varname] return ir_node diff --git a/vyper/ir/compile_ir.py b/vyper/ir/compile_ir.py index 5e29bad0b5..bba3b34515 100644 --- a/vyper/ir/compile_ir.py +++ b/vyper/ir/compile_ir.py @@ -413,9 +413,8 @@ def _height_of(witharg): ) # stack: i, rounds, rounds_bound # assert rounds <= rounds_bound - # TODO this runtime assertion should never fail for + # TODO this runtime assertion shouldn't fail for # internally generated repeats. - # maybe drop it or jump to 0xFE o.extend(["DUP2", "GT"] + _assert_false()) # stack: i, rounds diff --git a/vyper/semantics/analysis/annotation.py b/vyper/semantics/analysis/annotation.py index 3ea0319b54..d309f102cd 100644 --- a/vyper/semantics/analysis/annotation.py +++ b/vyper/semantics/analysis/annotation.py @@ -95,6 +95,9 @@ def visit_For(self, node): iter_type = node.target._metadata["type"] for a in node.iter.args: self.expr_visitor.visit(a, iter_type) + for a in node.iter.keywords: + if a.arg == "bound": + self.expr_visitor.visit(a.value, iter_type) class ExpressionAnnotationVisitor(_AnnotationVisitorBase): diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index c99b582ad3..c0c05325f2 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -346,17 +346,30 @@ def visit_For(self, node): raise IteratorException( "Cannot iterate over the result of a function call", node.iter ) - validate_call_args(node.iter, (1, 2)) + validate_call_args(node.iter, (1, 2), kwargs=["bound"]) args = node.iter.args + kwargs = {s.arg: s.value for s in node.iter.keywords or []} if len(args) == 1: # range(CONSTANT) - if not isinstance(args[0], vy_ast.Num): - raise StateAccessViolation("Value must be a literal", node) - if args[0].value <= 0: - raise StructureException("For loop must have at least 1 iteration", args[0]) - validate_expected_type(args[0], IntegerT.any()) - type_list = get_possible_types_from_node(args[0]) + n = args[0] + bound = kwargs.pop("bound", None) + validate_expected_type(n, IntegerT.any()) + + if bound is None: + if not isinstance(n, vy_ast.Num): + raise StateAccessViolation("Value must be a literal", n) + if n.value <= 0: + raise StructureException("For loop must have at least 1 iteration", args[0]) + type_list = get_possible_types_from_node(n) + + else: + if not isinstance(bound, vy_ast.Num): + raise StateAccessViolation("bound must be a literal", bound) + if bound.value <= 0: + raise StructureException("bound must be at least 1", args[0]) + type_list = get_common_types(n, bound) + else: validate_expected_type(args[0], IntegerT.any()) type_list = get_common_types(*args)