Skip to content

Commit

Permalink
rename validate_expected_type to infer_type and have it return a type
Browse files Browse the repository at this point in the history
it also tags the node with the inferred type
  • Loading branch information
charles-cooper committed Feb 10, 2024
1 parent 8ccacb3 commit 6b9fff2
Show file tree
Hide file tree
Showing 10 changed files with 51 additions and 50 deletions.
6 changes: 3 additions & 3 deletions tests/functional/builtins/folding/test_bitwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from tests.utils import parse_and_fold
from vyper.exceptions import InvalidType, OverflowException
from vyper.semantics.analysis.utils import validate_expected_type
from vyper.semantics.analysis.utils import infer_type
from vyper.semantics.types.shortcuts import INT256_T, UINT256_T
from vyper.utils import unsigned_to_signed

Expand Down Expand Up @@ -55,7 +55,7 @@ def foo(a: uint256, b: uint256) -> uint256:
# force bounds check, no-op because validate_numeric_bounds
# already does this, but leave in for hygiene (in case
# more types are added).
validate_expected_type(new_node, UINT256_T)
_ = infer_type(new_node, UINT256_T)
# compile time behavior does not match runtime behavior.
# compile-time will throw on OOB, runtime will wrap.
except OverflowException: # here: check the wrapped value matches runtime
Expand All @@ -81,7 +81,7 @@ def foo(a: int256, b: uint256) -> int256:
vyper_ast = parse_and_fold(f"{a} {op} {b}")
old_node = vyper_ast.body[0].value
new_node = old_node.get_folded_value()
validate_expected_type(new_node, INT256_T) # force bounds check
_ = infer_type(new_node, INT256_T) # force bounds check
# compile time behavior does not match runtime behavior.
# compile-time will throw on OOB, runtime will wrap.
except (InvalidType, OverflowException):
Expand Down
4 changes: 2 additions & 2 deletions vyper/builtins/_signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from vyper.semantics.analysis.utils import (
check_modifiability,
get_exact_type_from_node,
validate_expected_type,
infer_type,
)
from vyper.semantics.types import TYPE_T, KwargSettings, VyperType
from vyper.semantics.types.utils import type_from_annotation
Expand Down Expand Up @@ -99,7 +99,7 @@ def _validate_single(self, arg: vy_ast.VyperNode, expected_type: VyperType) -> N
# for its side effects (will throw if is not a type)
type_from_annotation(arg)
else:
validate_expected_type(arg, expected_type)
infer_type(arg, expected_type)

def _validate_arg_types(self, node: vy_ast.Call) -> None:
num_args = len(self._inputs) # the number of args the signature indicates
Expand Down
11 changes: 5 additions & 6 deletions vyper/builtins/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
get_common_types,
get_exact_type_from_node,
get_possible_types_from_node,
validate_expected_type,
infer_type,
)
from vyper.semantics.types import (
TYPE_T,
Expand Down Expand Up @@ -508,8 +508,7 @@ def infer_arg_types(self, node, expected_return_typ=None):
ret = []
prev_typeclass = None
for arg in node.args:
validate_expected_type(arg, (BytesT.any(), StringT.any(), BytesM_T.any()))
arg_t = get_possible_types_from_node(arg).pop()
arg_t = infer_type(arg, (BytesT.any(), StringT.any(), BytesM_T.any()))
current_typeclass = "String" if isinstance(arg_t, StringT) else "Bytes"
if prev_typeclass and current_typeclass != prev_typeclass:
raise TypeMismatch(
Expand Down Expand Up @@ -865,7 +864,7 @@ def infer_kwarg_types(self, node):
"Output type must be one of integer, bytes32 or address", node.keywords[0].value
)
output_typedef = TYPE_T(output_type)
node.keywords[0].value._metadata["type"] = output_typedef
#node.keywords[0].value._metadata["type"] = output_typedef
else:
output_typedef = TYPE_T(BYTES32_T)

Expand Down Expand Up @@ -2376,8 +2375,8 @@ def infer_kwarg_types(self, node):
ret = {}
for kwarg in node.keywords:
kwarg_name = kwarg.arg
validate_expected_type(kwarg.value, self._kwargs[kwarg_name].typ)
ret[kwarg_name] = get_exact_type_from_node(kwarg.value)
typ = infer_type(kwarg.value, self._kwargs[kwarg_name].typ)
ret[kwarg_name] = typ
return ret

def fetch_call_return(self, node):
Expand Down
13 changes: 4 additions & 9 deletions vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
get_exact_type_from_node,
get_expr_info,
get_possible_types_from_node,
validate_expected_type,
infer_type,
)
from vyper.semantics.data_locations import DataLocation

Expand Down Expand Up @@ -254,7 +254,7 @@ def _validate_revert_reason(self, msg_node: vy_ast.VyperNode) -> None:
self.expr_visitor.visit(msg_node, get_exact_type_from_node(msg_node))
elif not (isinstance(msg_node, vy_ast.Name) and msg_node.id == "UNREACHABLE"):
try:
validate_expected_type(msg_node, StringT(1024))
_ = infer_type(msg_node, StringT(1024))
except TypeMismatch as e:
raise InvalidType("revert reason must fit within String[1024]") from e
self.expr_visitor.visit(msg_node, get_exact_type_from_node(msg_node))
Expand Down Expand Up @@ -563,15 +563,10 @@ def scope_name(self):

def visit(self, node, typ):
if typ is not VOID_TYPE and not isinstance(typ, TYPE_T):
validate_expected_type(node, typ)
infer_type(node, expected_type=typ)

# recurse and typecheck in case we are being fed the wrong type for
# some reason.
super().visit(node, typ)

# annotate
node._metadata["type"] = typ

if not isinstance(typ, TYPE_T):
info = get_expr_info(node) # get_expr_info fills in node._expr_info

Expand Down Expand Up @@ -793,7 +788,7 @@ def visit_Tuple(self, node: vy_ast.Tuple, typ: VyperType) -> None:
# don't recurse; can't annotate AST children of type definition
return

# these guarantees should be provided by validate_expected_type
# these guarantees should be provided by infer_type
assert isinstance(typ, TupleT)
assert len(node.elements) == len(typ.member_types)

Expand Down
2 changes: 1 addition & 1 deletion vyper/semantics/analysis/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ def _validate_self_namespace():
if node.is_constant:
assert node.value is not None # checked in VariableDecl.validate()

ExprVisitor().visit(node.value, type_) # performs validate_expected_type
ExprVisitor().visit(node.value, type_) # performs type validation

if not check_modifiability(node.value, Modifiability.CONSTANT):
raise StateAccessViolation("Value must be a literal", node.value)
Expand Down
35 changes: 21 additions & 14 deletions vyper/semantics/analysis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def types_from_BinOp(self, node):
# can be different types
types_list = get_possible_types_from_node(node.left)
# check rhs is unsigned integer
validate_expected_type(node.right, IntegerT.unsigneds())
_ = infer_type(node.right, IntegerT.unsigneds())
else:
types_list = get_common_types(node.left, node.right)

Expand Down Expand Up @@ -319,7 +319,7 @@ def types_from_Constant(self, node):
raise InvalidLiteral(f"Could not determine type for literal value '{node.value}'", node)

def types_from_IfExp(self, node):
validate_expected_type(node.test, BoolT())
_ = infer_type(node.test, expected_type=BoolT())
types_list = get_common_types(node.body, node.orelse)

if not types_list:
Expand Down Expand Up @@ -529,14 +529,14 @@ def _validate_literal_array(node, expected):

for item in node.elements:
try:
validate_expected_type(item, expected.value_type)
_ = infer_type(item, expected.value_type)
except (InvalidType, TypeMismatch):
return False

return True


def validate_expected_type(node, expected_type):
def infer_type(node, expected_type):
"""
Validate that the given node matches the expected type(s)
Expand All @@ -551,8 +551,15 @@ def validate_expected_type(node, expected_type):
Returns
-------
None
The inferred type. The inferred type must be a concrete type which
is compatible with the expected type (although the expected type may
be generic).
"""
ret = _infer_type_helper(node, expected_type)
node._metadata["type"] = ret
return ret

def _infer_type_helper(node, expected_type):
if not isinstance(expected_type, tuple):
expected_type = (expected_type,)

Expand All @@ -561,15 +568,15 @@ def validate_expected_type(node, expected_type):
for t in possible_tuple_types:
if len(t.member_types) != len(node.elements):
continue
for item_ast, item_type in zip(node.elements, t.member_types):
ret = []
for item_ast, expected_item_type in zip(node.elements, t.member_types):
try:
validate_expected_type(item_ast, item_type)
return
item_t = infer_type(item_ast, expected_type=expected_item_type)
ret.append(item_t)
except VyperException:
pass
else:
# fail block
pass
break # go to fail block
else:
return TupleT(tuple(ret))

given_types = _ExprAnalyser().get_possible_types_from_node(node)

Expand All @@ -579,11 +586,11 @@ def validate_expected_type(node, expected_type):
if not isinstance(expected, (DArrayT, SArrayT)):
continue
if _validate_literal_array(node, expected):
return
return expected
else:
for given, expected in itertools.product(given_types, expected_type):
if expected.compare_type(given):
return
return given

# validation failed, prepare a meaningful error message
if len(expected_type) > 1:
Expand Down
10 changes: 5 additions & 5 deletions vyper/semantics/types/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from vyper.semantics.analysis.utils import (
check_modifiability,
get_exact_type_from_node,
validate_expected_type,
infer_type,
)
from vyper.semantics.data_locations import DataLocation
from vyper.semantics.types.base import KwargSettings, VyperType
Expand Down Expand Up @@ -542,7 +542,7 @@ def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]:
raise CallViolation("Cannot send ether to nonpayable function", kwarg_node)

for arg, expected in zip(node.args, self.argument_types):
validate_expected_type(arg, expected)
infer_type(arg, expected)

# TODO this should be moved to validate_call_args
for kwarg in node.keywords:
Expand All @@ -553,7 +553,7 @@ def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]:
f"`{kwarg.arg}=` specified but {self.name}() does not return anything",
kwarg.value,
)
validate_expected_type(kwarg.value, kwarg_settings.typ)
infer_type(kwarg.value, kwarg_settings.typ)
if kwarg_settings.require_literal:
if not isinstance(kwarg.value, vy_ast.Constant):
raise InvalidType(
Expand Down Expand Up @@ -730,7 +730,7 @@ def _parse_args(
value = funcdef.args.defaults[i - n_positional_args]
if not check_modifiability(value, Modifiability.RUNTIME_CONSTANT):
raise StateAccessViolation("Value must be literal or environment variable", value)
validate_expected_type(value, type_)
infer_type(value, expected_type=type_)
keyword_args.append(KeywordArg(argname, type_, value, ast_source=arg))

argnames.add(argname)
Expand Down Expand Up @@ -788,7 +788,7 @@ def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]:
assert len(node.args) == len(self.arg_types) # validate_call_args postcondition
for arg, expected_type in zip(node.args, self.arg_types):
# CMC 2022-04-01 this should probably be in the validation module
validate_expected_type(arg, expected_type)
infer_type(arg, expected_type=expected_type)

return self.return_type

Expand Down
6 changes: 3 additions & 3 deletions vyper/semantics/types/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from vyper.semantics.analysis.base import Modifiability, VarInfo
from vyper.semantics.analysis.utils import (
check_modifiability,
validate_expected_type,
infer_type,
validate_unique_method_ids,
)
from vyper.semantics.data_locations import DataLocation
Expand Down Expand Up @@ -83,8 +83,8 @@ def _ctor_call_return(self, node: vy_ast.Call) -> "InterfaceT":

def _ctor_arg_types(self, node):
validate_call_args(node, 1)
validate_expected_type(node.args[0], AddressT())
return [AddressT()]
typ = infer_type(node.args[0], AddressT())
return [typ]

def _ctor_kwarg_types(self, node):
return {}
Expand Down
8 changes: 4 additions & 4 deletions vyper/semantics/types/subscriptable.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ def getter_signature(self) -> Tuple[Tuple, Optional[VyperType]]:

def validate_index_type(self, node):
# TODO: break this cycle
from vyper.semantics.analysis.utils import validate_expected_type
from vyper.semantics.analysis.utils import infer_type

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
vyper.semantics.analysis.utils
begins an import cycle.

validate_expected_type(node, self.key_type)
infer_type(node, self.key_type)


class HashMapT(_SubscriptableT):
Expand Down Expand Up @@ -125,15 +125,15 @@ def count(self):

def validate_index_type(self, node):
# TODO break this cycle
from vyper.semantics.analysis.utils import validate_expected_type
from vyper.semantics.analysis.utils import infer_type

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
vyper.semantics.analysis.utils
begins an import cycle.

if isinstance(node, vy_ast.Int):
if node.value < 0:
raise ArrayIndexException("Vyper does not support negative indexing", node)
if node.value >= self.length:
raise ArrayIndexException("Index out of range", node)

validate_expected_type(node, IntegerT.any())
infer_type(node, IntegerT.any())

def get_subscripted_type(self, node):
return self.value_type
Expand Down
6 changes: 3 additions & 3 deletions vyper/semantics/types/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)
from vyper.semantics.analysis.base import Modifiability
from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions
from vyper.semantics.analysis.utils import check_modifiability, validate_expected_type
from vyper.semantics.analysis.utils import check_modifiability, infer_type

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
vyper.semantics.analysis.utils
begins an import cycle.
from vyper.semantics.data_locations import DataLocation
from vyper.semantics.types.base import VyperType
from vyper.semantics.types.subscriptable import HashMapT
Expand Down Expand Up @@ -270,7 +270,7 @@ def from_EventDef(cls, base_node: vy_ast.EventDef) -> "EventT":
def _ctor_call_return(self, node: vy_ast.Call) -> None:
validate_call_args(node, len(self.arguments))
for arg, expected in zip(node.args, self.arguments.values()):
validate_expected_type(arg, expected)
infer_type(arg, expected)

def to_toplevel_abi_dict(self) -> list[dict]:
return [
Expand Down Expand Up @@ -412,7 +412,7 @@ def _ctor_call_return(self, node: vy_ast.Call) -> "StructT":
key,
)

validate_expected_type(value, members.pop(key.id))
infer_type(value, members.pop(key.id))

if members:
raise VariableDeclarationException(
Expand Down

0 comments on commit 6b9fff2

Please sign in to comment.