Skip to content

Commit

Permalink
feat: allow constant and immutable variables to be declared public (#…
Browse files Browse the repository at this point in the history
…3024)

move more annotating to VariableDecl AST node to mirror what it would
look like if vyper's parser matched its grammar more closely. refactor
some downstream code. modify ast.expansion to handle constant and
immutable variables.

Co-authored-by: Charles Cooper <cooper.charles.m@gmail.com>
  • Loading branch information
benber86 and charles-cooper authored Sep 9, 2022
1 parent 963c8ed commit be2b7f4
Show file tree
Hide file tree
Showing 12 changed files with 89 additions and 80 deletions.
8 changes: 8 additions & 0 deletions tests/parser/syntax/test_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
NamespaceCollision,
StateAccessViolation,
StructureException,
SyntaxException,
VariableDeclarationException,
)

Expand Down Expand Up @@ -108,6 +109,13 @@ def foo() -> uint256:
""",
StateAccessViolation,
),
(
# constant(public()) banned
"""
S: constant(public(uint256)) = 3
""",
SyntaxException,
),
]


Expand Down
8 changes: 8 additions & 0 deletions tests/parser/syntax/test_immutables.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,14 @@ def __init__(_value: uint256):
VALUE = _value * 3
VALUE = VALUE + 1
""",
# immutable(public()) banned
"""
VALUE: immutable(public(uint256))
@external
def __init__(_value: uint256):
VALUE = _value * 3
""",
]


Expand Down
8 changes: 8 additions & 0 deletions tests/parser/syntax/test_public.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,14 @@
x: public(int128)
""",
"""
x: public(constant(int128)) = 0
y: public(immutable(int128))
@external
def __init__():
y = 0
""",
"""
x: public(int128)
y: public(int128)
z: public(int128)
Expand Down
26 changes: 15 additions & 11 deletions vyper/ast/expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,24 @@ def generate_public_variable_getters(vyper_module: vy_ast.Module) -> None:
Top-level Vyper AST node.
"""

for node in vyper_module.get_children(vy_ast.VariableDecl, {"annotation.func.id": "public"}):
for node in vyper_module.get_children(vy_ast.VariableDecl, {"is_public": True}):
func_type = node._metadata["func_type"]
input_types, return_type = func_type.get_signature()
input_nodes = []

# use the annotation node as a base to build the input args and return type
# starting with `args[0]` to remove the surrounding `public()` call`
annotation = copy.copy(node.annotation.args[0])

# the base return statement is an `Attribute` node, e.g. `self.<var_name>`
# for each input type we wrap it in a `Subscript` to access a specific member
return_stmt: vy_ast.VyperNode = vy_ast.Attribute(
value=vy_ast.Name(id="self"), attr=func_type.name
)
# use the annotation node to build the input args and return type
annotation = copy.copy(node.annotation)

return_stmt: vy_ast.VyperNode
# constants just return a value
if node.is_constant:
return_stmt = node.value
elif node.is_immutable:
return_stmt = vy_ast.Name(id=func_type.name)
else:
# the base return statement is an `Attribute` node, e.g. `self.<var_name>`
# for each input type we wrap it in a `Subscript` to access a specific member
return_stmt = vy_ast.Attribute(value=vy_ast.Name(id="self"), attr=func_type.name)
return_stmt._metadata["type"] = node._metadata["type"]

for i, type_ in enumerate(input_types):
Expand Down Expand Up @@ -100,7 +104,7 @@ def remove_unused_statements(vyper_module: vy_ast.Module) -> None:
"""

# constant declarations - values were substituted within the AST during folding
for node in vyper_module.get_children(vy_ast.VariableDecl, {"annotation.func.id": "constant"}):
for node in vyper_module.get_children(vy_ast.VariableDecl, {"is_constant": True}):
vyper_module.remove_from_body(node)

# `implements: interface` statements - validated during type checking
Expand Down
12 changes: 4 additions & 8 deletions vyper/ast/folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,23 +174,19 @@ def replace_user_defined_constants(vyper_module: vy_ast.Module) -> int:
if not isinstance(node.target, vy_ast.Name):
# left-hand-side of assignment is not a variable
continue
if node.get("annotation.func.id") != "constant":
if not node.is_constant:
# annotation is not wrapped in `constant(...)`
continue

# Extract type definition from propagated annotation
constant_annotation = node.get("annotation.args")[0]
type_ = None
try:
type_ = (
get_type_from_annotation(constant_annotation, DataLocation.UNSET)
if constant_annotation
else None
)
type_ = get_type_from_annotation(node.annotation, DataLocation.UNSET)
except UnknownType:
# handle user-defined types e.g. structs - it's OK to not
# propagate the type annotation here because user-defined
# types can be unambiguously inferred at typechecking time
type_ = None
pass

changed_nodes += replace_constant(
vyper_module, node.target.id, node.value, False, type_=type_
Expand Down
40 changes: 24 additions & 16 deletions vyper/ast/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from vyper.compiler.settings import VYPER_ERROR_CONTEXT_LINES, VYPER_ERROR_LINE_NUMBERS
from vyper.exceptions import (
ArgumentException,
CompilerPanic,
InvalidLiteral,
InvalidOperation,
Expand Down Expand Up @@ -1282,23 +1283,30 @@ def __init__(self, *args, **kwargs):
self.is_public = False
self.is_immutable = False

if isinstance(self.annotation, Call):
# the annotation is a function call, e.g. `foo: constant(uint256)`
call_name = self.annotation.get("func.id")
if call_name == "constant":
# declaring a constant
self.is_constant = True

elif call_name == "public":
# declaring a public variable
self.is_public = True

elif call_name == "immutable":
# declaring an immutable variable
self.is_immutable = True
def _check_args(annotation, call_name):
# do the same thing as `validate_call_args`
# (can't be imported due to cyclic dependency)
if len(annotation.args) != 1:
raise ArgumentException("Invalid number of arguments to `{call_name}`:", self)

# the annotation is a "function" call, e.g.
# `foo: public(constant(uint256))`
# pretend we were parsing actual Vyper AST. annotation would be
# TYPE | PUBLIC "(" TYPE | ((IMMUTABLE | CONSTANT) "(" TYPE ")") ")"
if self.annotation.get("func.id") == "public":
_check_args(self.annotation, "public")
self.is_public = True
# unwrap one layer
self.annotation = self.annotation.args[0]

if self.annotation.get("func.id") in ("immutable", "constant"):
_check_args(self.annotation, self.annotation.func.id)
setattr(self, f"is_{self.annotation.func.id}", True)
# unwrap one layer
self.annotation = self.annotation.args[0]

else:
_raise_syntax_exc("Invalid scope for variable declaration", self.annotation)
if isinstance(self.annotation, Call):
_raise_syntax_exc("Invalid scope for variable declaration", self.annotation)


class AugAssign(VyperNode):
Expand Down
36 changes: 10 additions & 26 deletions vyper/codegen/global_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,13 +140,6 @@ def make_contract(node: "vy_ast.InterfaceDef") -> list:
raise StructureException("Invalid contract reference", item)
return _defs

@staticmethod
def get_call_func_name(item):
if isinstance(item.annotation, vy_ast.Call) and isinstance(
item.annotation.func, vy_ast.Name
):
return item.annotation.func.id

def add_globals_and_events(self, item):

if self._nonrentrant_counter:
Expand All @@ -157,30 +150,21 @@ def add_globals_and_events(self, item):
raise StructureException("Invalid global variable name", item.target)

# Handle constants.
if self.get_call_func_name(item) == "constant":
if item.is_constant:
return

# references to `len(self._globals)` are remnants of deprecated code, retained
# to preserve existing interfaces while we complete a larger refactor. location
# and size of storage vars is handled in `vyper.context.validation.data_positions`
if self.get_call_func_name(item) == "public":
typ = self.parse_type(item.annotation.args[0])
self._globals[item.target.id] = VariableRecord(
item.target.id, len(self._globals), typ, True
)
elif self.get_call_func_name(item) == "immutable":
typ = self.parse_type(item.annotation.args[0])
self._globals[item.target.id] = VariableRecord(
item.target.id, len(self._globals), typ, False, is_immutable=True
)

elif isinstance(item.annotation, (vy_ast.Name, vy_ast.Call, vy_ast.Subscript)):
typ = self.parse_type(item.annotation)
self._globals[item.target.id] = VariableRecord(
item.target.id, len(self._globals), typ, True
)
else:
raise InvalidType("Invalid global type specified", item)
typ = self.parse_type(item.annotation)
is_immutable = item.is_immutable
self._globals[item.target.id] = VariableRecord(
item.target.id,
len(self._globals),
typ,
mutable=not is_immutable,
is_immutable=is_immutable,
)

@property
def interface_names(self):
Expand Down
7 changes: 4 additions & 3 deletions vyper/semantics/types/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,9 +382,10 @@ def getter_from_VariableDecl(cls, node: vy_ast.VariableDecl) -> "ContractFunctio
-------
ContractFunction
"""
if not isinstance(node.annotation, vy_ast.Call):
raise CompilerPanic("Annotation must be a call to public()")
type_ = get_type_from_annotation(node.annotation.args[0], location=DataLocation.STORAGE)
if not node.is_public:
raise CompilerPanic("getter generated for non-public function")
location = DataLocation.CODE if node.is_immutable else DataLocation.STORAGE
type_ = get_type_from_annotation(node.annotation, location=location)
arguments, return_type = type_.get_signature()
args_dict: OrderedDict = OrderedDict()
for item in arguments:
Expand Down
2 changes: 1 addition & 1 deletion vyper/semantics/types/user/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def _get_module_definitions(base_node: vy_ast.Module) -> Tuple[OrderedDict, Dict
# only keep the `ContractFunction` with the longest set of input args
continue
functions[node.name] = func
for node in base_node.get_children(vy_ast.VariableDecl, {"annotation.func.id": "public"}):
for node in base_node.get_children(vy_ast.VariableDecl, {"is_public": True}):
name = node.target.id
if name in functions:
raise NamespaceCollision(
Expand Down
2 changes: 1 addition & 1 deletion vyper/semantics/types/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def get_type_from_annotation(
Arguments
---------
node : VyperNode
Vyper ast node from the `annotation` member of a `VariableDef` or `AnnAssign` node.
Vyper ast node from the `annotation` member of a `VariableDecl` or `AnnAssign` node.
Returns
-------
Expand Down
7 changes: 3 additions & 4 deletions vyper/semantics/validation/data_positions.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,8 @@ def set_storage_slots(vyper_module: vy_ast.Module) -> StorageLayout:

for node in vyper_module.get_children(vy_ast.VariableDecl):

if node.get("annotation.func.id") == "immutable":
# skip non-storage variables
if node.is_constant or node.is_immutable:
continue

type_ = node.target._metadata["type"]
Expand Down Expand Up @@ -210,9 +211,7 @@ def set_code_offsets(vyper_module: vy_ast.Module) -> Dict:

ret = {}
offset = 0
for node in vyper_module.get_children(
vy_ast.VariableDecl, filters={"annotation.func.id": "immutable"}
):
for node in vyper_module.get_children(vy_ast.VariableDecl, filters={"is_immutable": True}):
type_ = node._metadata["type"]
type_.set_position(CodeOffset(offset))

Expand Down
13 changes: 3 additions & 10 deletions vyper/semantics/validation/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import vyper.builtin_interfaces
from vyper import ast as vy_ast
from vyper.ast.validation import validate_call_args
from vyper.exceptions import (
CallViolation,
CompilerPanic,
Expand Down Expand Up @@ -153,18 +152,12 @@ def visit_VariableDecl(self, node):
if name is None:
raise VariableDeclarationException("Invalid module-level assignment", node)

annotation = node.annotation
# remove the outer call node, to handle cases such as `public(map(..))`
if node.is_public or node.is_immutable or node.is_constant:
validate_call_args(annotation, 1)
annotation = annotation.args[0]

if node.is_public:
# generate function type and add to metadata
# we need this when builing the public getter
# we need this when building the public getter
node._metadata["func_type"] = ContractFunction.getter_from_VariableDecl(node)

elif node.is_immutable:
if node.is_immutable:
# mutability is checked automatically preventing assignment
# outside of the constructor, here we just check a value is assigned,
# not necessarily where
Expand All @@ -185,7 +178,7 @@ def visit_VariableDecl(self, node):

data_loc = DataLocation.CODE if node.is_immutable else DataLocation.STORAGE
type_definition = get_type_from_annotation(
annotation, data_loc, node.is_constant, node.is_public, node.is_immutable
node.annotation, data_loc, node.is_constant, node.is_public, node.is_immutable
)
node._metadata["type"] = type_definition

Expand Down

0 comments on commit be2b7f4

Please sign in to comment.