From be2b7f427bf980a0baf52cdd010d83231824ad3a Mon Sep 17 00:00:00 2001 From: Benny Date: Fri, 9 Sep 2022 16:12:02 +1000 Subject: [PATCH] feat: allow constant and immutable variables to be declared public (#3024) move more annotating to VariableDecl AST node to mirror what it would look like if vyper's parser matched its grammar more closely. refactor some downstream code. modify ast.expansion to handle constant and immutable variables. Co-authored-by: Charles Cooper --- tests/parser/syntax/test_constants.py | 8 ++++ tests/parser/syntax/test_immutables.py | 8 ++++ tests/parser/syntax/test_public.py | 8 ++++ vyper/ast/expansion.py | 26 +++++++------ vyper/ast/folding.py | 12 ++---- vyper/ast/nodes.py | 40 ++++++++++++-------- vyper/codegen/global_context.py | 36 +++++------------- vyper/semantics/types/function.py | 7 ++-- vyper/semantics/types/user/interface.py | 2 +- vyper/semantics/types/utils.py | 2 +- vyper/semantics/validation/data_positions.py | 7 ++-- vyper/semantics/validation/module.py | 13 ++----- 12 files changed, 89 insertions(+), 80 deletions(-) diff --git a/tests/parser/syntax/test_constants.py b/tests/parser/syntax/test_constants.py index 7897c0805c..374f93e68b 100644 --- a/tests/parser/syntax/test_constants.py +++ b/tests/parser/syntax/test_constants.py @@ -8,6 +8,7 @@ NamespaceCollision, StateAccessViolation, StructureException, + SyntaxException, VariableDeclarationException, ) @@ -108,6 +109,13 @@ def foo() -> uint256: """, StateAccessViolation, ), + ( + # constant(public()) banned + """ +S: constant(public(uint256)) = 3 + """, + SyntaxException, + ), ] diff --git a/tests/parser/syntax/test_immutables.py b/tests/parser/syntax/test_immutables.py index 855df921b4..3bed282644 100644 --- a/tests/parser/syntax/test_immutables.py +++ b/tests/parser/syntax/test_immutables.py @@ -50,6 +50,14 @@ def __init__(_value: uint256): VALUE = _value * 3 VALUE = VALUE + 1 """, + # immutable(public()) banned + """ +VALUE: immutable(public(uint256)) + +@external +def __init__(_value: uint256): + VALUE = _value * 3 + """, ] diff --git a/tests/parser/syntax/test_public.py b/tests/parser/syntax/test_public.py index a756f211ef..fd0058cab8 100644 --- a/tests/parser/syntax/test_public.py +++ b/tests/parser/syntax/test_public.py @@ -7,6 +7,14 @@ x: public(int128) """, """ +x: public(constant(int128)) = 0 +y: public(immutable(int128)) + +@external +def __init__(): + y = 0 + """, + """ x: public(int128) y: public(int128) z: public(int128) diff --git a/vyper/ast/expansion.py b/vyper/ast/expansion.py index 721a66145f..812aa68f1e 100644 --- a/vyper/ast/expansion.py +++ b/vyper/ast/expansion.py @@ -30,20 +30,24 @@ def generate_public_variable_getters(vyper_module: vy_ast.Module) -> None: Top-level Vyper AST node. """ - for node in vyper_module.get_children(vy_ast.VariableDecl, {"annotation.func.id": "public"}): + for node in vyper_module.get_children(vy_ast.VariableDecl, {"is_public": True}): func_type = node._metadata["func_type"] input_types, return_type = func_type.get_signature() input_nodes = [] - # use the annotation node as a base to build the input args and return type - # starting with `args[0]` to remove the surrounding `public()` call` - annotation = copy.copy(node.annotation.args[0]) - - # the base return statement is an `Attribute` node, e.g. `self.` - # for each input type we wrap it in a `Subscript` to access a specific member - return_stmt: vy_ast.VyperNode = vy_ast.Attribute( - value=vy_ast.Name(id="self"), attr=func_type.name - ) + # use the annotation node to build the input args and return type + annotation = copy.copy(node.annotation) + + return_stmt: vy_ast.VyperNode + # constants just return a value + if node.is_constant: + return_stmt = node.value + elif node.is_immutable: + return_stmt = vy_ast.Name(id=func_type.name) + else: + # the base return statement is an `Attribute` node, e.g. `self.` + # for each input type we wrap it in a `Subscript` to access a specific member + return_stmt = vy_ast.Attribute(value=vy_ast.Name(id="self"), attr=func_type.name) return_stmt._metadata["type"] = node._metadata["type"] for i, type_ in enumerate(input_types): @@ -100,7 +104,7 @@ def remove_unused_statements(vyper_module: vy_ast.Module) -> None: """ # constant declarations - values were substituted within the AST during folding - for node in vyper_module.get_children(vy_ast.VariableDecl, {"annotation.func.id": "constant"}): + for node in vyper_module.get_children(vy_ast.VariableDecl, {"is_constant": True}): vyper_module.remove_from_body(node) # `implements: interface` statements - validated during type checking diff --git a/vyper/ast/folding.py b/vyper/ast/folding.py index c1c58a04f7..edcec476c0 100644 --- a/vyper/ast/folding.py +++ b/vyper/ast/folding.py @@ -174,23 +174,19 @@ def replace_user_defined_constants(vyper_module: vy_ast.Module) -> int: if not isinstance(node.target, vy_ast.Name): # left-hand-side of assignment is not a variable continue - if node.get("annotation.func.id") != "constant": + if not node.is_constant: # annotation is not wrapped in `constant(...)` continue # Extract type definition from propagated annotation - constant_annotation = node.get("annotation.args")[0] + type_ = None try: - type_ = ( - get_type_from_annotation(constant_annotation, DataLocation.UNSET) - if constant_annotation - else None - ) + type_ = get_type_from_annotation(node.annotation, DataLocation.UNSET) except UnknownType: # handle user-defined types e.g. structs - it's OK to not # propagate the type annotation here because user-defined # types can be unambiguously inferred at typechecking time - type_ = None + pass changed_nodes += replace_constant( vyper_module, node.target.id, node.value, False, type_=type_ diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index 5e29789ace..a06dbe30de 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -7,6 +7,7 @@ from vyper.compiler.settings import VYPER_ERROR_CONTEXT_LINES, VYPER_ERROR_LINE_NUMBERS from vyper.exceptions import ( + ArgumentException, CompilerPanic, InvalidLiteral, InvalidOperation, @@ -1282,23 +1283,30 @@ def __init__(self, *args, **kwargs): self.is_public = False self.is_immutable = False - if isinstance(self.annotation, Call): - # the annotation is a function call, e.g. `foo: constant(uint256)` - call_name = self.annotation.get("func.id") - if call_name == "constant": - # declaring a constant - self.is_constant = True - - elif call_name == "public": - # declaring a public variable - self.is_public = True - - elif call_name == "immutable": - # declaring an immutable variable - self.is_immutable = True + def _check_args(annotation, call_name): + # do the same thing as `validate_call_args` + # (can't be imported due to cyclic dependency) + if len(annotation.args) != 1: + raise ArgumentException("Invalid number of arguments to `{call_name}`:", self) + + # the annotation is a "function" call, e.g. + # `foo: public(constant(uint256))` + # pretend we were parsing actual Vyper AST. annotation would be + # TYPE | PUBLIC "(" TYPE | ((IMMUTABLE | CONSTANT) "(" TYPE ")") ")" + if self.annotation.get("func.id") == "public": + _check_args(self.annotation, "public") + self.is_public = True + # unwrap one layer + self.annotation = self.annotation.args[0] + + if self.annotation.get("func.id") in ("immutable", "constant"): + _check_args(self.annotation, self.annotation.func.id) + setattr(self, f"is_{self.annotation.func.id}", True) + # unwrap one layer + self.annotation = self.annotation.args[0] - else: - _raise_syntax_exc("Invalid scope for variable declaration", self.annotation) + if isinstance(self.annotation, Call): + _raise_syntax_exc("Invalid scope for variable declaration", self.annotation) class AugAssign(VyperNode): diff --git a/vyper/codegen/global_context.py b/vyper/codegen/global_context.py index 690e14a76b..d06c5b1b22 100644 --- a/vyper/codegen/global_context.py +++ b/vyper/codegen/global_context.py @@ -140,13 +140,6 @@ def make_contract(node: "vy_ast.InterfaceDef") -> list: raise StructureException("Invalid contract reference", item) return _defs - @staticmethod - def get_call_func_name(item): - if isinstance(item.annotation, vy_ast.Call) and isinstance( - item.annotation.func, vy_ast.Name - ): - return item.annotation.func.id - def add_globals_and_events(self, item): if self._nonrentrant_counter: @@ -157,30 +150,21 @@ def add_globals_and_events(self, item): raise StructureException("Invalid global variable name", item.target) # Handle constants. - if self.get_call_func_name(item) == "constant": + if item.is_constant: return # references to `len(self._globals)` are remnants of deprecated code, retained # to preserve existing interfaces while we complete a larger refactor. location # and size of storage vars is handled in `vyper.context.validation.data_positions` - if self.get_call_func_name(item) == "public": - typ = self.parse_type(item.annotation.args[0]) - self._globals[item.target.id] = VariableRecord( - item.target.id, len(self._globals), typ, True - ) - elif self.get_call_func_name(item) == "immutable": - typ = self.parse_type(item.annotation.args[0]) - self._globals[item.target.id] = VariableRecord( - item.target.id, len(self._globals), typ, False, is_immutable=True - ) - - elif isinstance(item.annotation, (vy_ast.Name, vy_ast.Call, vy_ast.Subscript)): - typ = self.parse_type(item.annotation) - self._globals[item.target.id] = VariableRecord( - item.target.id, len(self._globals), typ, True - ) - else: - raise InvalidType("Invalid global type specified", item) + typ = self.parse_type(item.annotation) + is_immutable = item.is_immutable + self._globals[item.target.id] = VariableRecord( + item.target.id, + len(self._globals), + typ, + mutable=not is_immutable, + is_immutable=is_immutable, + ) @property def interface_names(self): diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index f1e9a593a0..764479d4b8 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -382,9 +382,10 @@ def getter_from_VariableDecl(cls, node: vy_ast.VariableDecl) -> "ContractFunctio ------- ContractFunction """ - if not isinstance(node.annotation, vy_ast.Call): - raise CompilerPanic("Annotation must be a call to public()") - type_ = get_type_from_annotation(node.annotation.args[0], location=DataLocation.STORAGE) + if not node.is_public: + raise CompilerPanic("getter generated for non-public function") + location = DataLocation.CODE if node.is_immutable else DataLocation.STORAGE + type_ = get_type_from_annotation(node.annotation, location=location) arguments, return_type = type_.get_signature() args_dict: OrderedDict = OrderedDict() for item in arguments: diff --git a/vyper/semantics/types/user/interface.py b/vyper/semantics/types/user/interface.py index 8fb33e1ed4..fb16c50b4f 100644 --- a/vyper/semantics/types/user/interface.py +++ b/vyper/semantics/types/user/interface.py @@ -201,7 +201,7 @@ def _get_module_definitions(base_node: vy_ast.Module) -> Tuple[OrderedDict, Dict # only keep the `ContractFunction` with the longest set of input args continue functions[node.name] = func - for node in base_node.get_children(vy_ast.VariableDecl, {"annotation.func.id": "public"}): + for node in base_node.get_children(vy_ast.VariableDecl, {"is_public": True}): name = node.target.id if name in functions: raise NamespaceCollision( diff --git a/vyper/semantics/types/utils.py b/vyper/semantics/types/utils.py index 1aa2c0e316..6d2a6bdec6 100644 --- a/vyper/semantics/types/utils.py +++ b/vyper/semantics/types/utils.py @@ -156,7 +156,7 @@ def get_type_from_annotation( Arguments --------- node : VyperNode - Vyper ast node from the `annotation` member of a `VariableDef` or `AnnAssign` node. + Vyper ast node from the `annotation` member of a `VariableDecl` or `AnnAssign` node. Returns ------- diff --git a/vyper/semantics/validation/data_positions.py b/vyper/semantics/validation/data_positions.py index deec697c2f..5bece514eb 100644 --- a/vyper/semantics/validation/data_positions.py +++ b/vyper/semantics/validation/data_positions.py @@ -179,7 +179,8 @@ def set_storage_slots(vyper_module: vy_ast.Module) -> StorageLayout: for node in vyper_module.get_children(vy_ast.VariableDecl): - if node.get("annotation.func.id") == "immutable": + # skip non-storage variables + if node.is_constant or node.is_immutable: continue type_ = node.target._metadata["type"] @@ -210,9 +211,7 @@ def set_code_offsets(vyper_module: vy_ast.Module) -> Dict: ret = {} offset = 0 - for node in vyper_module.get_children( - vy_ast.VariableDecl, filters={"annotation.func.id": "immutable"} - ): + for node in vyper_module.get_children(vy_ast.VariableDecl, filters={"is_immutable": True}): type_ = node._metadata["type"] type_.set_position(CodeOffset(offset)) diff --git a/vyper/semantics/validation/module.py b/vyper/semantics/validation/module.py index 52d897e840..be1bf4c316 100644 --- a/vyper/semantics/validation/module.py +++ b/vyper/semantics/validation/module.py @@ -4,7 +4,6 @@ import vyper.builtin_interfaces from vyper import ast as vy_ast -from vyper.ast.validation import validate_call_args from vyper.exceptions import ( CallViolation, CompilerPanic, @@ -153,18 +152,12 @@ def visit_VariableDecl(self, node): if name is None: raise VariableDeclarationException("Invalid module-level assignment", node) - annotation = node.annotation - # remove the outer call node, to handle cases such as `public(map(..))` - if node.is_public or node.is_immutable or node.is_constant: - validate_call_args(annotation, 1) - annotation = annotation.args[0] - if node.is_public: # generate function type and add to metadata - # we need this when builing the public getter + # we need this when building the public getter node._metadata["func_type"] = ContractFunction.getter_from_VariableDecl(node) - elif node.is_immutable: + if node.is_immutable: # mutability is checked automatically preventing assignment # outside of the constructor, here we just check a value is assigned, # not necessarily where @@ -185,7 +178,7 @@ def visit_VariableDecl(self, node): data_loc = DataLocation.CODE if node.is_immutable else DataLocation.STORAGE type_definition = get_type_from_annotation( - annotation, data_loc, node.is_constant, node.is_public, node.is_immutable + node.annotation, data_loc, node.is_constant, node.is_public, node.is_immutable ) node._metadata["type"] = type_definition