diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 26170591ff..d385eaf3b2 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -16,7 +16,7 @@ ) from vyper.semantics.namespace import get_namespace from vyper.semantics.types.bases import BaseTypeDefinition, DataLocation, StorageSlot -from vyper.semantics.types.indexable.sequence import TupleDefinition +from vyper.semantics.types.indexable.sequence import DynamicArrayDefinition, TupleDefinition from vyper.semantics.types.user.struct import StructDefinition from vyper.semantics.types.utils import ( StringEnum, @@ -503,6 +503,43 @@ def to_abi_dict(self) -> List[Dict]: return [abi_dict] +class MemberFunctionDefinition(BaseTypeDefinition): + """ + Member function type definition. + + This class has no corresponding primitive. + """ + + _is_callable = True + + def __init__( + self, underlying_type: BaseTypeDefinition, name: str, min_arg_count: int, max_arg_count: int + ) -> None: + super().__init__(DataLocation.UNSET) + print("Underlying type: " + str(underlying_type)) + self.underlying_type = underlying_type + self.name = name + self.min_arg_count = min_arg_count + self.max_arg_count = max_arg_count + + def __repr__(self): + return f"{self.underlying_type._id} member function '{self.name}'" + + def fetch_call_return(self, node: vy_ast.Call) -> Optional[BaseTypeDefinition]: + print(self.__repr__()) + validate_call_args(node, (self.min_arg_count, self.max_arg_count)) + + if isinstance(self.underlying_type, DynamicArrayDefinition): + if self.name == "append": + return None + + elif self.name == "pop": + value_type = self.underlying_type.value_type + return value_type + + raise CallViolation("Function does not exist on given type", node) + + def _generate_abi_type(type_definition, name=""): if isinstance(type_definition, StructDefinition): return { diff --git a/vyper/semantics/types/indexable/sequence.py b/vyper/semantics/types/indexable/sequence.py index 890209999b..89a32fae00 100644 --- a/vyper/semantics/types/indexable/sequence.py +++ b/vyper/semantics/types/indexable/sequence.py @@ -2,8 +2,7 @@ from vyper import ast as vy_ast from vyper.abi_types import ABI_DynamicArray, ABI_StaticArray, ABI_Tuple, ABIType -from vyper.ast.validation import validate_call_args -from vyper.exceptions import ArrayIndexException, CallViolation, InvalidType, StructureException +from vyper.exceptions import ArrayIndexException, InvalidType, StructureException from vyper.semantics import validation from vyper.semantics.types.abstract import IntegerAbstractType from vyper.semantics.types.bases import ( @@ -134,10 +133,12 @@ def __init__( ) -> None: super().__init__(value_type, length, "DynArray", location, is_immutable, is_public) - # Adding members here as otherwise DynamicArrayFunctionDefinition is not yet defined + # Adding members here as otherwise MemberFunctionDefinition is not yet defined # if added as _type_members - self.add_member("append", DynamicArrayFunctionDefinition("append", 0, 1)) - self.add_member("pop", DynamicArrayFunctionDefinition("pop", 0, 0)) + from vyper.semantics.types.function import MemberFunctionDefinition + + self.add_member("append", MemberFunctionDefinition(self, "append", 0, 1)) + self.add_member("pop", MemberFunctionDefinition(self, "pop", 0, 0)) def __repr__(self): return f"DynArray[{self.value_type}, {self.length}]" @@ -218,41 +219,6 @@ def from_annotation( ) -class DynamicArrayFunctionDefinition(BaseTypeDefinition): - """ - Dynamic array function type definition. - - This class has no corresponding primitive. - """ - - _is_callable = True - - def __init__(self, name: str, min_arg_count: int, max_arg_count: int) -> None: - super().__init__(DataLocation.UNSET) - self.name = name - self.min_arg_count = min_arg_count - self.max_arg_count = max_arg_count - - def __repr__(self): - return f"dynamic array function '{self.name}'" - - def fetch_call_return(self, node: vy_ast.Call) -> Optional[BaseTypeDefinition]: - validate_call_args(node, (self.min_arg_count, self.max_arg_count)) - - dynarray_definition = validation.utils.get_exact_type_from_node(node.func.value) - value_type = dynarray_definition.value_type - dynarray_length = dynarray_definition.length - - if self.name == "append": - validation.utils.validate_expected_type(node.args[0], value_type) - return DynamicArrayDefinition(value_type, dynarray_length + 1) - - elif self.name == "pop": - return value_type - - raise CallViolation("Function on dynamic array does not exist", node) - - class TupleDefinition(_SequenceDefinition): """ Tuple type definition. diff --git a/vyper/semantics/validation/annotation.py b/vyper/semantics/validation/annotation.py index 990f0805ff..966a0c8ed1 100644 --- a/vyper/semantics/validation/annotation.py +++ b/vyper/semantics/validation/annotation.py @@ -2,8 +2,7 @@ from vyper.exceptions import StructureException from vyper.semantics.types import ArrayDefinition from vyper.semantics.types.bases import BaseTypeDefinition -from vyper.semantics.types.indexable.sequence import DynamicArrayFunctionDefinition -from vyper.semantics.types.function import ContractFunction +from vyper.semantics.types.function import ContractFunction, MemberFunctionDefinition from vyper.semantics.types.user.event import Event from vyper.semantics.types.user.struct import StructPrimitive from vyper.semantics.validation.utils import ( @@ -137,9 +136,10 @@ def visit_Call(self, node, type_): # literal structs for value, arg_type in zip(node.args[0].values, list(call_type.members.values())): self.visit(value, arg_type) - elif isinstance(call_type, DynamicArrayFunctionDefinition): - for arg in node.args: - self.visit(arg, node_type.value_type) + elif isinstance(call_type, MemberFunctionDefinition): + if node_type: + for arg in node.args: + self.visit(arg, node_type.value_type) elif node.func.id not in ("empty", "range"): # builtin functions for arg in node.args: diff --git a/vyper/semantics/validation/local.py b/vyper/semantics/validation/local.py index f52f8a84c8..854b50ed6f 100644 --- a/vyper/semantics/validation/local.py +++ b/vyper/semantics/validation/local.py @@ -25,11 +25,15 @@ from vyper.semantics.namespace import get_namespace from vyper.semantics.types.abstract import IntegerAbstractType from vyper.semantics.types.bases import DataLocation -from vyper.semantics.types.function import ContractFunction, FunctionVisibility, StateMutability +from vyper.semantics.types.function import ( + ContractFunction, + FunctionVisibility, + MemberFunctionDefinition, + StateMutability, +) from vyper.semantics.types.indexable.sequence import ( ArrayDefinition, DynamicArrayDefinition, - DynamicArrayFunctionDefinition, TupleDefinition, ) from vyper.semantics.types.user.event import Event @@ -464,8 +468,11 @@ def visit_Expr(self, node): f"Cannot call any function from a {self.func.mutability.value} function", node ) return_value = fn_type.fetch_call_return(node.value) - if return_value and not isinstance(fn_type, DynamicArrayFunctionDefinition) and \ - not isinstance(fn_type, ContractFunction): + if ( + return_value + and not isinstance(fn_type, MemberFunctionDefinition) + and not isinstance(fn_type, ContractFunction) + ): raise StructureException( f"Function '{fn_type._id}' cannot be called without assigning the result", node )