diff --git a/tests/functional/builtins/folding/test_bitwise.py b/tests/functional/builtins/folding/test_bitwise.py index c1ff7674bb..892f0bcabc 100644 --- a/tests/functional/builtins/folding/test_bitwise.py +++ b/tests/functional/builtins/folding/test_bitwise.py @@ -4,7 +4,7 @@ from tests.utils import parse_and_fold from vyper.exceptions import InvalidType, OverflowException -from vyper.semantics.analysis.utils import validate_expected_type +from vyper.semantics.analysis.utils import infer_type from vyper.semantics.types.shortcuts import INT256_T, UINT256_T from vyper.utils import unsigned_to_signed @@ -55,7 +55,7 @@ def foo(a: uint256, b: uint256) -> uint256: # force bounds check, no-op because validate_numeric_bounds # already does this, but leave in for hygiene (in case # more types are added). - validate_expected_type(new_node, UINT256_T) + _ = infer_type(new_node, UINT256_T) # compile time behavior does not match runtime behavior. # compile-time will throw on OOB, runtime will wrap. except OverflowException: # here: check the wrapped value matches runtime @@ -81,7 +81,7 @@ def foo(a: int256, b: uint256) -> int256: vyper_ast = parse_and_fold(f"{a} {op} {b}") old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() - validate_expected_type(new_node, INT256_T) # force bounds check + _ = infer_type(new_node, INT256_T) # force bounds check # compile time behavior does not match runtime behavior. # compile-time will throw on OOB, runtime will wrap. except (InvalidType, OverflowException): diff --git a/vyper/builtins/_signatures.py b/vyper/builtins/_signatures.py index 6e6cf4c662..3d25b435da 100644 --- a/vyper/builtins/_signatures.py +++ b/vyper/builtins/_signatures.py @@ -10,7 +10,7 @@ from vyper.semantics.analysis.utils import ( check_modifiability, get_exact_type_from_node, - validate_expected_type, + infer_type, ) from vyper.semantics.types import TYPE_T, KwargSettings, VyperType from vyper.semantics.types.utils import type_from_annotation @@ -99,7 +99,7 @@ def _validate_single(self, arg: vy_ast.VyperNode, expected_type: VyperType) -> N # for its side effects (will throw if is not a type) type_from_annotation(arg) else: - validate_expected_type(arg, expected_type) + infer_type(arg, expected_type) def _validate_arg_types(self, node: vy_ast.Call) -> None: num_args = len(self._inputs) # the number of args the signature indicates diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 7575f4d77e..345b59197a 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -54,7 +54,7 @@ get_common_types, get_exact_type_from_node, get_possible_types_from_node, - validate_expected_type, + infer_type, ) from vyper.semantics.types import ( TYPE_T, @@ -508,8 +508,7 @@ def infer_arg_types(self, node, expected_return_typ=None): ret = [] prev_typeclass = None for arg in node.args: - validate_expected_type(arg, (BytesT.any(), StringT.any(), BytesM_T.any())) - arg_t = get_possible_types_from_node(arg).pop() + arg_t = infer_type(arg, (BytesT.any(), StringT.any(), BytesM_T.any())) current_typeclass = "String" if isinstance(arg_t, StringT) else "Bytes" if prev_typeclass and current_typeclass != prev_typeclass: raise TypeMismatch( @@ -865,7 +864,7 @@ def infer_kwarg_types(self, node): "Output type must be one of integer, bytes32 or address", node.keywords[0].value ) output_typedef = TYPE_T(output_type) - node.keywords[0].value._metadata["type"] = output_typedef + #node.keywords[0].value._metadata["type"] = output_typedef else: output_typedef = TYPE_T(BYTES32_T) @@ -2376,8 +2375,8 @@ def infer_kwarg_types(self, node): ret = {} for kwarg in node.keywords: kwarg_name = kwarg.arg - validate_expected_type(kwarg.value, self._kwargs[kwarg_name].typ) - ret[kwarg_name] = get_exact_type_from_node(kwarg.value) + typ = infer_type(kwarg.value, self._kwargs[kwarg_name].typ) + ret[kwarg_name] = typ return ret def fetch_call_return(self, node): diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index d96215ede0..77fa57c074 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -26,7 +26,7 @@ get_exact_type_from_node, get_expr_info, get_possible_types_from_node, - validate_expected_type, + infer_type, ) from vyper.semantics.data_locations import DataLocation @@ -254,7 +254,7 @@ def _validate_revert_reason(self, msg_node: vy_ast.VyperNode) -> None: self.expr_visitor.visit(msg_node, get_exact_type_from_node(msg_node)) elif not (isinstance(msg_node, vy_ast.Name) and msg_node.id == "UNREACHABLE"): try: - validate_expected_type(msg_node, StringT(1024)) + _ = infer_type(msg_node, StringT(1024)) except TypeMismatch as e: raise InvalidType("revert reason must fit within String[1024]") from e self.expr_visitor.visit(msg_node, get_exact_type_from_node(msg_node)) @@ -563,15 +563,10 @@ def scope_name(self): def visit(self, node, typ): if typ is not VOID_TYPE and not isinstance(typ, TYPE_T): - validate_expected_type(node, typ) + infer_type(node, expected_type=typ) - # recurse and typecheck in case we are being fed the wrong type for - # some reason. super().visit(node, typ) - # annotate - node._metadata["type"] = typ - if not isinstance(typ, TYPE_T): info = get_expr_info(node) # get_expr_info fills in node._expr_info @@ -793,7 +788,7 @@ def visit_Tuple(self, node: vy_ast.Tuple, typ: VyperType) -> None: # don't recurse; can't annotate AST children of type definition return - # these guarantees should be provided by validate_expected_type + # these guarantees should be provided by infer_type assert isinstance(typ, TupleT) assert len(node.elements) == len(typ.member_types) diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index e50c3e6d6f..787ec82c15 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -526,7 +526,7 @@ def _validate_self_namespace(): if node.is_constant: assert node.value is not None # checked in VariableDecl.validate() - ExprVisitor().visit(node.value, type_) # performs validate_expected_type + ExprVisitor().visit(node.value, type_) # performs type validation if not check_modifiability(node.value, Modifiability.CONSTANT): raise StateAccessViolation("Value must be a literal", node.value) diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index f1f0f48a86..c889e6ab75 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -224,7 +224,7 @@ def types_from_BinOp(self, node): # can be different types types_list = get_possible_types_from_node(node.left) # check rhs is unsigned integer - validate_expected_type(node.right, IntegerT.unsigneds()) + _ = infer_type(node.right, IntegerT.unsigneds()) else: types_list = get_common_types(node.left, node.right) @@ -319,7 +319,7 @@ def types_from_Constant(self, node): raise InvalidLiteral(f"Could not determine type for literal value '{node.value}'", node) def types_from_IfExp(self, node): - validate_expected_type(node.test, BoolT()) + _ = infer_type(node.test, expected_type=BoolT()) types_list = get_common_types(node.body, node.orelse) if not types_list: @@ -529,14 +529,14 @@ def _validate_literal_array(node, expected): for item in node.elements: try: - validate_expected_type(item, expected.value_type) + _ = infer_type(item, expected.value_type) except (InvalidType, TypeMismatch): return False return True -def validate_expected_type(node, expected_type): +def infer_type(node, expected_type): """ Validate that the given node matches the expected type(s) @@ -551,8 +551,15 @@ def validate_expected_type(node, expected_type): Returns ------- - None + The inferred type. The inferred type must be a concrete type which + is compatible with the expected type (although the expected type may + be generic). """ + ret = _infer_type_helper(node, expected_type) + node._metadata["type"] = ret + return ret + +def _infer_type_helper(node, expected_type): if not isinstance(expected_type, tuple): expected_type = (expected_type,) @@ -561,15 +568,15 @@ def validate_expected_type(node, expected_type): for t in possible_tuple_types: if len(t.member_types) != len(node.elements): continue - for item_ast, item_type in zip(node.elements, t.member_types): + ret = [] + for item_ast, expected_item_type in zip(node.elements, t.member_types): try: - validate_expected_type(item_ast, item_type) - return + item_t = infer_type(item_ast, expected_type=expected_item_type) + ret.append(item_t) except VyperException: - pass - else: - # fail block - pass + break # go to fail block + else: + return TupleT(tuple(ret)) given_types = _ExprAnalyser().get_possible_types_from_node(node) @@ -579,11 +586,11 @@ def validate_expected_type(node, expected_type): if not isinstance(expected, (DArrayT, SArrayT)): continue if _validate_literal_array(node, expected): - return + return expected else: for given, expected in itertools.product(given_types, expected_type): if expected.compare_type(given): - return + return given # validation failed, prepare a meaningful error message if len(expected_type) > 1: diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 62f9c60585..4f4fc82e5c 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -27,7 +27,7 @@ from vyper.semantics.analysis.utils import ( check_modifiability, get_exact_type_from_node, - validate_expected_type, + infer_type, ) from vyper.semantics.data_locations import DataLocation from vyper.semantics.types.base import KwargSettings, VyperType @@ -542,7 +542,7 @@ def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]: raise CallViolation("Cannot send ether to nonpayable function", kwarg_node) for arg, expected in zip(node.args, self.argument_types): - validate_expected_type(arg, expected) + infer_type(arg, expected) # TODO this should be moved to validate_call_args for kwarg in node.keywords: @@ -553,7 +553,7 @@ def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]: f"`{kwarg.arg}=` specified but {self.name}() does not return anything", kwarg.value, ) - validate_expected_type(kwarg.value, kwarg_settings.typ) + infer_type(kwarg.value, kwarg_settings.typ) if kwarg_settings.require_literal: if not isinstance(kwarg.value, vy_ast.Constant): raise InvalidType( @@ -730,7 +730,7 @@ def _parse_args( value = funcdef.args.defaults[i - n_positional_args] if not check_modifiability(value, Modifiability.RUNTIME_CONSTANT): raise StateAccessViolation("Value must be literal or environment variable", value) - validate_expected_type(value, type_) + infer_type(value, expected_type=type_) keyword_args.append(KeywordArg(argname, type_, value, ast_source=arg)) argnames.add(argname) @@ -788,7 +788,7 @@ def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]: assert len(node.args) == len(self.arg_types) # validate_call_args postcondition for arg, expected_type in zip(node.args, self.arg_types): # CMC 2022-04-01 this should probably be in the validation module - validate_expected_type(arg, expected_type) + infer_type(arg, expected_type=expected_type) return self.return_type diff --git a/vyper/semantics/types/module.py b/vyper/semantics/types/module.py index 86840f4f91..0ef052a3da 100644 --- a/vyper/semantics/types/module.py +++ b/vyper/semantics/types/module.py @@ -13,7 +13,7 @@ from vyper.semantics.analysis.base import Modifiability, VarInfo from vyper.semantics.analysis.utils import ( check_modifiability, - validate_expected_type, + infer_type, validate_unique_method_ids, ) from vyper.semantics.data_locations import DataLocation @@ -83,8 +83,8 @@ def _ctor_call_return(self, node: vy_ast.Call) -> "InterfaceT": def _ctor_arg_types(self, node): validate_call_args(node, 1) - validate_expected_type(node.args[0], AddressT()) - return [AddressT()] + typ = infer_type(node.args[0], AddressT()) + return [typ] def _ctor_kwarg_types(self, node): return {} diff --git a/vyper/semantics/types/subscriptable.py b/vyper/semantics/types/subscriptable.py index 635a1631a2..9dec62e136 100644 --- a/vyper/semantics/types/subscriptable.py +++ b/vyper/semantics/types/subscriptable.py @@ -35,9 +35,9 @@ def getter_signature(self) -> Tuple[Tuple, Optional[VyperType]]: def validate_index_type(self, node): # TODO: break this cycle - from vyper.semantics.analysis.utils import validate_expected_type + from vyper.semantics.analysis.utils import infer_type - validate_expected_type(node, self.key_type) + infer_type(node, self.key_type) class HashMapT(_SubscriptableT): @@ -125,7 +125,7 @@ def count(self): def validate_index_type(self, node): # TODO break this cycle - from vyper.semantics.analysis.utils import validate_expected_type + from vyper.semantics.analysis.utils import infer_type if isinstance(node, vy_ast.Int): if node.value < 0: @@ -133,7 +133,7 @@ def validate_index_type(self, node): if node.value >= self.length: raise ArrayIndexException("Index out of range", node) - validate_expected_type(node, IntegerT.any()) + infer_type(node, IntegerT.any()) def get_subscripted_type(self, node): return self.value_type diff --git a/vyper/semantics/types/user.py b/vyper/semantics/types/user.py index 92a455e3d8..c3f169ac8d 100644 --- a/vyper/semantics/types/user.py +++ b/vyper/semantics/types/user.py @@ -16,7 +16,7 @@ ) from vyper.semantics.analysis.base import Modifiability from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions -from vyper.semantics.analysis.utils import check_modifiability, validate_expected_type +from vyper.semantics.analysis.utils import check_modifiability, infer_type from vyper.semantics.data_locations import DataLocation from vyper.semantics.types.base import VyperType from vyper.semantics.types.subscriptable import HashMapT @@ -270,7 +270,7 @@ def from_EventDef(cls, base_node: vy_ast.EventDef) -> "EventT": def _ctor_call_return(self, node: vy_ast.Call) -> None: validate_call_args(node, len(self.arguments)) for arg, expected in zip(node.args, self.arguments.values()): - validate_expected_type(arg, expected) + infer_type(arg, expected) def to_toplevel_abi_dict(self) -> list[dict]: return [ @@ -412,7 +412,7 @@ def _ctor_call_return(self, node: vy_ast.Call) -> "StructT": key, ) - validate_expected_type(value, members.pop(key.id)) + infer_type(value, members.pop(key.id)) if members: raise VariableDeclarationException(