Skip to content

Commit

Permalink
feat: allow range(x, y, bound=N) (#3679)
Browse files Browse the repository at this point in the history
- allow range where both start and end arguments are variables, so long
  as a bound is supplied

- ban range expressions of the form `range(x, x + N)` since the new form
  is cleaner and supersedes it.

- also do a bit of refactoring of the codegen for range

---------

Co-authored-by: Charles Cooper <cooper.charles.m@gmail.com>
  • Loading branch information
DanielSchiavini and charles-cooper authored Dec 24, 2023
1 parent 2e41873 commit 5319cfb
Show file tree
Hide file tree
Showing 10 changed files with 390 additions and 147 deletions.
8 changes: 5 additions & 3 deletions docs/control-structures.rst
Original file line number Diff line number Diff line change
Expand Up @@ -287,9 +287,11 @@ Another use of range can be with ``START`` and ``STOP`` bounds.
Here, ``START`` and ``STOP`` are literal integers, with ``STOP`` being a greater value than ``START``. ``i`` begins as ``START`` and increments by one until it is equal to ``STOP``.

Finally, it is possible to use ``range`` with runtime `start` and `stop` values as long as a constant `bound` value is provided.
In this case, Vyper checks at runtime that `end - start <= bound`.
``N`` must be a compile-time constant.

.. code-block:: python
for i in range(a, a + N):
for i in range(start, end, bound=N):
...
``a`` is a variable with an integer type and ``N`` is a literal integer greater than zero. ``i`` begins as ``a`` and increments by one until it is equal to ``a + N``. If ``a + N`` would overflow, execution will revert.
19 changes: 13 additions & 6 deletions tests/functional/codegen/features/iteration/test_for_in_list.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
from decimal import Decimal

import pytest
Expand Down Expand Up @@ -700,13 +701,16 @@ def foo():
""",
StateAccessViolation,
),
"""
(
"""
@external
def foo():
a: int128 = 6
for i in range(a,a-3):
pass
""",
StateAccessViolation,
),
# invalid argument length
(
"""
Expand Down Expand Up @@ -789,10 +793,13 @@ def test_for() -> int128:
),
]

BAD_CODE = [code if isinstance(code, tuple) else (code, StructureException) for code in BAD_CODE]
for_code_regex = re.compile(r"for .+ in (.*):")
bad_code_names = [
f"{i} {for_code_regex.search(code).group(1)}" for i, (code, _) in enumerate(BAD_CODE)
]


@pytest.mark.parametrize("code", BAD_CODE)
def test_bad_code(assert_compile_failed, get_contract, code):
err = StructureException
if not isinstance(code, str):
code, err = code
@pytest.mark.parametrize("code,err", BAD_CODE, ids=bad_code_names)
def test_bad_code(assert_compile_failed, get_contract, code, err):
assert_compile_failed(lambda: get_contract(code), err)
116 changes: 107 additions & 9 deletions tests/functional/codegen/features/iteration/test_for_range.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,102 @@ def repeat(n: uint256) -> uint256:
c.repeat(7)


def test_range_bound_constant_end(get_contract, tx_failed):
code = """
@external
def repeat(n: uint256) -> uint256:
x: uint256 = 0
for i in range(n, 7, bound=6):
x += i + 1
return x
"""
c = get_contract(code)
for n in range(1, 5):
assert c.repeat(n) == sum(i + 1 for i in range(n, 7))

# check assertion for `start <= end`
with tx_failed():
c.repeat(8)
# check assertion for `start + bound <= end`
with tx_failed():
c.repeat(0)


def test_range_bound_two_args(get_contract, tx_failed):
code = """
@external
def repeat(n: uint256) -> uint256:
x: uint256 = 0
for i in range(1, n, bound=6):
x += i + 1
return x
"""
c = get_contract(code)
for n in range(1, 8):
assert c.repeat(n) == sum(i + 1 for i in range(1, n))

# check assertion for `start <= end`
with tx_failed():
c.repeat(0)

# check codegen inserts assertion for `start + bound <= end`
with tx_failed():
c.repeat(8)


def test_range_bound_two_runtime_args(get_contract, tx_failed):
code = """
@external
def repeat(start: uint256, end: uint256) -> uint256:
x: uint256 = 0
for i in range(start, end, bound=6):
x += i
return x
"""
c = get_contract(code)
for n in range(0, 7):
assert c.repeat(0, n) == sum(range(0, n))
assert c.repeat(n, n * 2) == sum(range(n, n * 2))

# check assertion for `start <= end`
with tx_failed():
c.repeat(1, 0)
with tx_failed():
c.repeat(7, 0)
with tx_failed():
c.repeat(8, 7)

# check codegen inserts assertion for `start + bound <= end`
with tx_failed():
c.repeat(0, 7)
with tx_failed():
c.repeat(14, 21)


def test_range_overflow(get_contract, tx_failed):
code = """
@external
def get_last(start: uint256, end: uint256) -> uint256:
x: uint256 = 0
for i in range(start, end, bound=6):
x = i
return x
"""
c = get_contract(code)
UINT_MAX = 2**256 - 1
assert c.get_last(UINT_MAX, UINT_MAX) == 0 # initial value of x

for n in range(1, 6):
assert c.get_last(UINT_MAX - n, UINT_MAX) == UINT_MAX - 1

# check for `start + bound <= end`, overflow cases
for n in range(1, 7):
with tx_failed():
c.get_last(UINT_MAX - n, 0)
with tx_failed():
c.get_last(UINT_MAX, UINT_MAX - n)


def test_digit_reverser(get_contract_with_gas_estimation):
digit_reverser = """
@external
Expand Down Expand Up @@ -89,7 +185,7 @@ def test_offset_repeater_2(get_contract_with_gas_estimation, typ):
@external
def sum(frm: {typ}, to: {typ}) -> {typ}:
out: {typ} = 0
for i in range(frm, frm + 101):
for i in range(frm, frm + 101, bound=101):
if i == to:
break
out = out + i
Expand Down Expand Up @@ -146,26 +242,28 @@ def foo(a: {typ}) -> {typ}:
assert c.foo(100) == 31337


# test that we can get to the upper range of an integer
@pytest.mark.parametrize("typ", ["uint8", "int128", "uint256"])
def test_for_range_edge(get_contract, typ):
"""
Check that we can get to the upper range of an integer.
Note that to avoid overflow in the bounds check for range(),
we need to calculate i+1 inside the loop.
"""
code = f"""
@external
def test():
found: bool = False
x: {typ} = max_value({typ})
for i in range(x, x + 1):
if i == max_value({typ}):
for i in range(x - 1, x, bound=1):
if i + 1 == max_value({typ}):
found = True
assert found
found = False
x = max_value({typ}) - 1
for i in range(x, x + 2):
if i == max_value({typ}):
for i in range(x - 1, x + 1, bound=2):
if i + 1 == max_value({typ}):
found = True
assert found
"""
c = get_contract(code)
Expand All @@ -178,7 +276,7 @@ def test_for_range_oob_check(get_contract, tx_failed, typ):
@external
def test():
x: {typ} = max_value({typ})
for i in range(x, x+2):
for i in range(x, x + 2, bound=2):
pass
"""
c = get_contract(code)
Expand Down
4 changes: 2 additions & 2 deletions tests/functional/codegen/integration/test_crowdfund.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def finalize():
@external
def refund():
ind: int128 = self.refundIndex
for i in range(ind, ind + 30):
for i in range(ind, ind + 30, bound=30):
if i >= self.nextFunderIndex:
self.refundIndex = self.nextFunderIndex
return
Expand Down Expand Up @@ -147,7 +147,7 @@ def finalize():
@external
def refund():
ind: int128 = self.refundIndex
for i in range(ind, ind + 30):
for i in range(ind, ind + 30, bound=30):
if i >= self.nextFunderIndex:
self.refundIndex = self.nextFunderIndex
return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,6 @@ def foo():
""",
"""
@external
def foo(x: int128):
y: int128 = 7
for i in range(x, x + y):
pass
""",
"""
@external
def foo():
x: String[100] = "these bytes are nо gооd because the o's are from the Russian alphabet"
""",
Expand Down
Loading

0 comments on commit 5319cfb

Please sign in to comment.