Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: frontend for dynamic array functions #7

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion vyper/semantics/types/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)
from vyper.semantics.namespace import get_namespace
from vyper.semantics.types.bases import BaseTypeDefinition, DataLocation, StorageSlot
from vyper.semantics.types.indexable.sequence import TupleDefinition
from vyper.semantics.types.indexable.sequence import DynamicArrayDefinition, TupleDefinition
from vyper.semantics.types.user.struct import StructDefinition
from vyper.semantics.types.utils import (
StringEnum,
Expand Down Expand Up @@ -503,6 +503,41 @@ def to_abi_dict(self) -> List[Dict]:
return [abi_dict]


class MemberFunctionDefinition(BaseTypeDefinition):
"""
Member function type definition.
This class has no corresponding primitive.
"""

_is_callable = True

def __init__(
self, underlying_type: BaseTypeDefinition, name: str, min_arg_count: int, max_arg_count: int
) -> None:
super().__init__(DataLocation.UNSET)
self.underlying_type = underlying_type
self.name = name
self.min_arg_count = min_arg_count
self.max_arg_count = max_arg_count

def __repr__(self):
return f"{self.underlying_type._id} member function '{self.name}'"

def fetch_call_return(self, node: vy_ast.Call) -> Optional[BaseTypeDefinition]:
validate_call_args(node, (self.min_arg_count, self.max_arg_count))

if isinstance(self.underlying_type, DynamicArrayDefinition):
if self.name == "append":
return None

elif self.name == "pop":
value_type = self.underlying_type.value_type
return value_type

raise CallViolation("Function does not exist on given type", node)


def _generate_abi_type(type_definition, name=""):
if isinstance(type_definition, StructDefinition):
return {
Expand Down
13 changes: 12 additions & 1 deletion vyper/semantics/types/indexable/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
BaseTypeDefinition,
DataLocation,
IndexableTypeDefinition,
MemberTypeDefinition,
)
from vyper.semantics.types.value.numeric import Uint256Definition

Expand Down Expand Up @@ -116,7 +117,7 @@ def compare_type(self, other):
return self.value_type.compare_type(other.value_type)


class DynamicArrayDefinition(_SequenceDefinition):
class DynamicArrayDefinition(_SequenceDefinition, MemberTypeDefinition):
"""
Dynamic array type definition.
"""
Expand All @@ -132,6 +133,13 @@ def __init__(
) -> None:
super().__init__(value_type, length, "DynArray", location, is_immutable, is_public)

# Adding members here as otherwise MemberFunctionDefinition is not yet defined
# if added as _type_members
from vyper.semantics.types.function import MemberFunctionDefinition

self.add_member("append", MemberFunctionDefinition(self, "append", 0, 1))
self.add_member("pop", MemberFunctionDefinition(self, "pop", 0, 0))

def __repr__(self):
return f"DynArray[{self.value_type}, {self.length}]"

Expand Down Expand Up @@ -167,6 +175,9 @@ def compare_type(self, other):
return False
return self.value_type.compare_type(other.value_type)

def fetch_call_return(self, node: vy_ast.Call) -> None:
pass


class DynamicArrayPrimitive(BasePrimitive):
_id = "DynArray"
Expand Down
10 changes: 7 additions & 3 deletions vyper/semantics/validation/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from vyper.exceptions import StructureException
from vyper.semantics.types import ArrayDefinition
from vyper.semantics.types.bases import BaseTypeDefinition
from vyper.semantics.types.function import ContractFunction
from vyper.semantics.types.function import ContractFunction, MemberFunctionDefinition
from vyper.semantics.types.user.event import Event
from vyper.semantics.types.user.struct import StructPrimitive
from vyper.semantics.validation.utils import (
Expand Down Expand Up @@ -125,9 +125,9 @@ def visit_BoolOp(self, node, type_):

def visit_Call(self, node, type_):
call_type = get_exact_type_from_node(node.func)
node._metadata["type"] = type_ or call_type.fetch_call_return(node)
node_type = type_ or call_type.fetch_call_return(node)
node._metadata["type"] = node_type
self.visit(node.func)

if isinstance(call_type, (Event, ContractFunction)):
# events and internal function calls
for arg, arg_type in zip(node.args, list(call_type.arguments.values())):
Expand All @@ -136,6 +136,10 @@ def visit_Call(self, node, type_):
# literal structs
for value, arg_type in zip(node.args[0].values, list(call_type.members.values())):
self.visit(value, arg_type)
elif isinstance(call_type, MemberFunctionDefinition):
if node_type:
for arg in node.args:
self.visit(arg, node_type.value_type)
elif node.func.id not in ("empty", "range"):
# builtin functions
for arg in node.args:
Expand Down
15 changes: 11 additions & 4 deletions vyper/semantics/validation/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,12 @@
from vyper.semantics.namespace import get_namespace
from vyper.semantics.types.abstract import IntegerAbstractType
from vyper.semantics.types.bases import DataLocation
from vyper.semantics.types.function import ContractFunction, FunctionVisibility, StateMutability
from vyper.semantics.types.function import (
ContractFunction,
FunctionVisibility,
MemberFunctionDefinition,
StateMutability,
)
from vyper.semantics.types.indexable.sequence import (
ArrayDefinition,
DynamicArrayDefinition,
Expand Down Expand Up @@ -444,7 +449,6 @@ def visit_For(self, node):
def visit_Expr(self, node):
if not isinstance(node.value, vy_ast.Call):
raise StructureException("Expressions without assignment are disallowed", node)

fn_type = get_exact_type_from_node(node.value.func)
if isinstance(fn_type, Event):
raise StructureException("To call an event you must use the `log` statement", node)
Expand All @@ -463,9 +467,12 @@ def visit_Expr(self, node):
raise StateAccessViolation(
f"Cannot call any function from a {self.func.mutability.value} function", node
)

return_value = fn_type.fetch_call_return(node.value)
if return_value and not isinstance(fn_type, ContractFunction):
if (
return_value
and not isinstance(fn_type, MemberFunctionDefinition)
and not isinstance(fn_type, ContractFunction)
):
raise StructureException(
f"Function '{fn_type._id}' cannot be called without assigning the result", node
)
Expand Down