Skip to content

Commit

Permalink
Merge pull request #2108 from iamdefinitelyahuman/fix-array-index-typ…
Browse files Browse the repository at this point in the history
…echeck

Array index typecheck
  • Loading branch information
fubuloubu authored Jul 16, 2020
2 parents 71568de + 24b7a3d commit dc8ebc0
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 5 deletions.
86 changes: 86 additions & 0 deletions tests/functional/context/validation/test_array_index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import pytest

from vyper.ast import parse_to_ast
from vyper.context.validation import validate_semantics
from vyper.exceptions import (
ArrayIndexException,
InvalidReference,
InvalidType,
TypeMismatch,
UndeclaredDefinition,
)


@pytest.mark.parametrize("value", ["address", "Bytes[10]", "decimal", "bool"])
def test_type_mismatch(namespace, value):
code = f"""
a: uint256[3]
@internal
def foo(b: {value}):
self.a[b] = 12
"""
vyper_module = parse_to_ast(code)
with pytest.raises(TypeMismatch):
validate_semantics(vyper_module, {})


@pytest.mark.parametrize("value", ["1.0", "0.0", "'foo'", "0x00", "b'\x01'", "False"])
def test_invalid_literal(namespace, value):
code = f"""
a: uint256[3]
@internal
def foo():
self.a[{value}] = 12
"""
vyper_module = parse_to_ast(code)
with pytest.raises(InvalidType):
validate_semantics(vyper_module, {})


@pytest.mark.parametrize("value", [-1, 3, -(2 ** 127), 2 ** 127 - 1, 2 ** 256 - 1])
def test_out_of_bounds(namespace, value):
code = f"""
a: uint256[3]
@internal
def foo():
self.a[{value}] = 12
"""
vyper_module = parse_to_ast(code)
with pytest.raises(ArrayIndexException):
validate_semantics(vyper_module, {})


@pytest.mark.parametrize("value", ["b", "self.b"])
def test_undeclared_definition(namespace, value):
code = f"""
a: uint256[3]
@internal
def foo():
self.a[{value}] = 12
"""
vyper_module = parse_to_ast(code)
with pytest.raises(UndeclaredDefinition):
validate_semantics(vyper_module, {})


@pytest.mark.parametrize("value", ["a", "foo", "int128"])
def test_invalid_reference(namespace, value):
code = f"""
a: uint256[3]
@internal
def foo():
self.a[{value}] = 12
"""
vyper_module = parse_to_ast(code)
with pytest.raises(InvalidReference):
validate_semantics(vyper_module, {})
16 changes: 16 additions & 0 deletions vyper/context/types/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,26 @@ def compare_type(self, other):
pass
return isinstance(other, type(self))

def __repr__(self):
value = super().__repr__()
if value == object.__repr__(self):
# use `_description` when no parent class overrides the default python repr
return self._description
return value


class ArrayValueAbstractType(AbstractDataType):
"""
Abstract data class for single-value types occupying multiple memory slots.
"""

_description = "fixed size bytes array or string"


class BytesAbstractType(AbstractDataType):
"""Abstract data class for bytes types (bytes32, bytes[])."""

_description = "bytes"
_id = "bytes"


Expand All @@ -34,10 +44,14 @@ class NumericAbstractType(AbstractDataType):
Abstract data class for numeric types (capable of arithmetic).
"""

_description = "numeric value"


class IntegerAbstractType(NumericAbstractType):
"""Abstract data class for integer numeric types (int128, uint256)."""

_description = "integer"


class FixedAbstractType(NumericAbstractType):
"""
Expand All @@ -47,3 +61,5 @@ class FixedAbstractType(NumericAbstractType):
still be used to expect decimal values in anticipation of multiple decimal
types in a future release.
"""

_description = "decimal"
7 changes: 5 additions & 2 deletions vyper/context/types/indexable/sequence.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Tuple

from vyper import ast as vy_ast
from vyper.context import validation
from vyper.context.types.abstract import IntegerAbstractType
from vyper.context.types.bases import (
BaseTypeDefinition,
Expand Down Expand Up @@ -69,8 +70,10 @@ def get_index_type(self, node):
if isinstance(node, vy_ast.Int):
if node.value < 0:
raise ArrayIndexException("Vyper does not support negative indexing", node)
if node.value < 0 or node.value >= self.length:
if node.value >= self.length:
raise ArrayIndexException("Index out of range", node)
else:
validation.utils.validate_expected_type(node, IntegerAbstractType())
return self.value_type

def compare_type(self, other):
Expand Down Expand Up @@ -109,7 +112,7 @@ def get_index_type(self, node):
raise InvalidType("Tuple indexes must be literals", node)
if node.value < 0:
raise ArrayIndexException("Vyper does not support negative indexing", node)
if node.value < 0 or node.value >= self.length:
if node.value >= self.length:
raise ArrayIndexException("Index out of range", node)
return self.value_type[node.value]

Expand Down
9 changes: 6 additions & 3 deletions vyper/context/validation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,9 +350,12 @@ def validate_expected_type(node, expected_type):
if not isinstance(node, (vy_ast.List, vy_ast.Tuple)) and node.get_descendants(
vy_ast.Name, include_self=True
):
raise TypeMismatch(
f"Given reference has type {given_types[0]}, expected {expected_str}", node
)
given = given_types[0]
if isinstance(given, type) and types.BasePrimitive in given.mro():
raise InvalidReference(
f"'{given._id}' is a type - expected a literal or variable", node
)
raise TypeMismatch(f"Given reference has type {given}, expected {expected_str}", node)
else:
if len(given_types) == 1:
given_str = str(given_types[0])
Expand Down

0 comments on commit dc8ebc0

Please sign in to comment.