Skip to content

Commit

Permalink
feat: replace AnnAssign with VariableDecl (#2881)
Browse files Browse the repository at this point in the history
Add a dedicated AST type for contract variable declarations

Co-authored-by: Charles Cooper <cooper.charles.m@gmail.com>
  • Loading branch information
tserg and charles-cooper authored Jul 18, 2022
1 parent f2623ba commit 1d1ef5d
Show file tree
Hide file tree
Showing 18 changed files with 161 additions and 87 deletions.
6 changes: 4 additions & 2 deletions tests/parser/ast_utils/test_ast_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,12 @@ def test_basic_ast():
"node_id": 4,
"src": "4:6:0",
},
"ast_type": "AnnAssign",
"ast_type": "VariableDef",
"col_offset": 0,
"end_col_offset": 9,
"end_lineno": 2,
"lineno": 2,
"node_id": 1,
"simple": 1,
"src": "1:9:0",
"target": {
"ast_type": "Name",
Expand All @@ -71,6 +70,9 @@ def test_basic_ast():
"src": "1:1:0",
},
"value": None,
"is_constant": False,
"is_immutable": False,
"is_public": False,
}


Expand Down
3 changes: 3 additions & 0 deletions tests/parser/exceptions/test_syntax_exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ def foo():
def foo():
x: address = 0x123456789012345678901234567890123456789
""",
"""
a: internal(uint256)
""",
]


Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from vyper.exceptions import StructureException, UnknownType
from vyper.exceptions import StructureException, SyntaxException, UnknownType


def test_external_contract_call_declaration_expr(get_contract, assert_tx_failed):
Expand Down Expand Up @@ -234,7 +234,7 @@ def set_lucky(_lucky: int128): nonpayable
modifiable_bar_contract: trusted(Bar)
"""
assert_compile_failed(lambda: get_contract(code), UnknownType)
assert_compile_failed(lambda: get_contract(code), SyntaxException)


def test_invalid_if_have_modifiability_not_declared(
Expand Down
5 changes: 3 additions & 2 deletions tests/parser/syntax/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
ArgumentException,
InvalidReference,
StructureException,
SyntaxException,
TypeMismatch,
UnknownAttribute,
)
Expand Down Expand Up @@ -36,7 +37,7 @@ def test():
a: address(ERC20) # invalid syntax now.
""",
StructureException,
SyntaxException,
),
(
"""
Expand All @@ -46,7 +47,7 @@ def test():
def test():
a: address(ERC20) = ZERO_ADDRESS
""",
StructureException,
(StructureException, SyntaxException),
),
(
"""
Expand Down
4 changes: 2 additions & 2 deletions tests/parser/syntax/utils/test_event_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from pytest import raises

from vyper import compiler
from vyper.exceptions import NamespaceCollision, StructureException, UnknownType
from vyper.exceptions import NamespaceCollision, StructureException, SyntaxException, UnknownType

fail_list = [ # noqa: E122
(
Expand Down Expand Up @@ -74,7 +74,7 @@ def foo(i: int128) -> int128:
"""
Transfer: eve.t({_from: indexed(address)})
""",
UnknownType,
SyntaxException,
),
(
"""
Expand Down
4 changes: 2 additions & 2 deletions vyper/ast/expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def generate_public_variable_getters(vyper_module: vy_ast.Module) -> None:
Top-level Vyper AST node.
"""

for node in vyper_module.get_children(vy_ast.AnnAssign, {"annotation.func.id": "public"}):
for node in vyper_module.get_children(vy_ast.VariableDef, {"annotation.func.id": "public"}):
func_type = node._metadata["func_type"]
input_types, return_type = func_type.get_signature()
input_nodes = []
Expand Down Expand Up @@ -100,7 +100,7 @@ def remove_unused_statements(vyper_module: vy_ast.Module) -> None:
"""

# constant declarations - values were substituted within the AST during folding
for node in vyper_module.get_children(vy_ast.AnnAssign, {"annotation.func.id": "constant"}):
for node in vyper_module.get_children(vy_ast.VariableDef, {"annotation.func.id": "constant"}):
vyper_module.remove_from_body(node)

# `implements: interface` statements - validated during type checking
Expand Down
6 changes: 4 additions & 2 deletions vyper/ast/folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def replace_user_defined_constants(vyper_module: vy_ast.Module) -> int:
"""
changed_nodes = 0

for node in vyper_module.get_children(vy_ast.AnnAssign):
for node in vyper_module.get_children(vy_ast.VariableDef):
if not isinstance(node.target, vy_ast.Name):
# left-hand-side of assignment is not a variable
continue
Expand Down Expand Up @@ -278,7 +278,9 @@ def replace_constant(

if not node.get_ancestor(vy_ast.Index):
# do not replace left-hand side of assignments
assign = node.get_ancestor((vy_ast.Assign, vy_ast.AnnAssign, vy_ast.AugAssign))
assign = node.get_ancestor(
(vy_ast.Assign, vy_ast.AnnAssign, vy_ast.AugAssign, vy_ast.VariableDef)
)

if assign and node in assign.target.get_descendants(include_self=True):
continue
Expand Down
60 changes: 60 additions & 0 deletions vyper/ast/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,15 @@ def get_node(
ast_struct = copy.copy(ast_struct)
del ast_struct["parent"]

# Replace state and local variable declarations `AnnAssign` with `VariableDef`
# Parent node is required for context to determine whether replacement should happen.
if (
ast_struct["ast_type"] == "AnnAssign"
and isinstance(parent, Module)
and not getattr(ast_struct["target"], "id", None) in ("implements",)
):
ast_struct["ast_type"] = "VariableDef"

vy_class = getattr(sys.modules[__name__], ast_struct["ast_type"], None)
if not vy_class:
if ast_struct["ast_type"] == "Delete":
Expand Down Expand Up @@ -524,6 +533,7 @@ def get_descendants(

def get(self, field_str: str) -> Any:
"""
Recursive getter function for node attributes.
Parameters
Expand Down Expand Up @@ -1238,6 +1248,56 @@ class AnnAssign(VyperNode):
__slots__ = ("target", "annotation", "value", "simple")


class VariableDef(VyperNode):
"""
A contract variable declaration.
Excludes `simple` attribute from Python `AnnAssign` node.
Attributes
----------
target : VyperNode
Left-hand side of the assignment.
value : VyperNode
Right-hand side of the assignment.
annotation : VyperNode
Type of variable.
is_constant : bool, optional
If true, indicates that the variable is a constant variable.
is_public : bool, optional
If true, indicates that the variable is a public state variable.
is_immutable : bool, optional
If true, indicates that the variable is an immutable variable.
"""

__slots__ = ("target", "annotation", "value", "is_constant", "is_public", "is_immutable")

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

self.is_constant = False
self.is_public = False
self.is_immutable = False

if isinstance(self.annotation, Call):
# the annotation is a function call, e.g. `foo: constant(uint256)`
call_name = self.annotation.get("func.id")
if call_name == "constant":
# declaring a constant
self.is_constant = True

elif call_name == "public":
# declaring a public variable
self.is_public = True

elif call_name == "immutable":
# declaring an immutable variable
self.is_immutable = True

else:
_raise_syntax_exc("Invalid scope for variable declaration", self.annotation)


class AugAssign(VyperNode):
__slots__ = ("op", "target", "value")

Expand Down
8 changes: 8 additions & 0 deletions vyper/ast/nodes.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,14 @@ class AnnAssign(VyperNode):
value: VyperNode = ...
annotation: VyperNode = ...

class VariableDef(VyperNode):
target: Name = ...
value: VyperNode = ...
annotation: VyperNode = ...
is_constant: bool = ...
is_public: bool = ...
is_immutable: bool = ...

class AugAssign(VyperNode):
op: VyperNode = ...
target: VyperNode = ...
Expand Down
17 changes: 16 additions & 1 deletion vyper/ast/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@


def parse_to_ast(
source_code: str, source_id: int = 0, contract_name: Optional[str] = None
source_code: str,
source_id: int = 0,
contract_name: Optional[str] = None,
add_fn_node: Optional[str] = None,
) -> vy_ast.Module:
"""
Parses a Vyper source string and generates basic Vyper AST nodes.
Expand All @@ -19,6 +22,10 @@ def parse_to_ast(
The Vyper source code to parse.
source_id : int, optional
Source id to use in the `src` member of each node.
contract_name: str, optional
Name of contract.
add_fn_node: str, optional
If not None, adds a dummy Python AST FunctionDef wrapper node.
Returns
-------
Expand All @@ -33,6 +40,14 @@ def parse_to_ast(
except SyntaxError as e:
# TODO: Ensure 1-to-1 match of source_code:reformatted_code SyntaxErrors
raise SyntaxException(str(e), source_code, e.lineno, e.offset) from e

# Add dummy function node to ensure local variables are treated as `AnnAssign`
# instead of state variables (`VariableDef`)
if add_fn_node:
fn_node = python_ast.FunctionDef(add_fn_node, py_ast.body, [], [])
fn_node.body = py_ast.body
fn_node.args = python_ast.arguments(defaults=[])
py_ast.body = [fn_node]
annotate_python_ast(py_ast, source_code, class_types, source_id, contract_name)

# Convert to Vyper AST.
Expand Down
14 changes: 5 additions & 9 deletions vyper/builtin_functions/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from vyper import ast as vy_ast
from vyper.ast import parse_to_ast
from vyper.codegen.context import Context
from vyper.codegen.global_context import GlobalContext
Expand All @@ -15,28 +14,25 @@ def _strip_source_pos(ir_node):


def generate_inline_function(code, variables, variables_2, memory_allocator):
ast_code = parse_to_ast(code)
ast_code = parse_to_ast(code, add_fn_node="dummy_fn")
# Annotate the AST with a temporary old (i.e. typecheck) namespace
namespace = Namespace()
namespace.update(variables_2)
with override_global_namespace(namespace):
# Initialise a placeholder `FunctionDef` AST node and corresponding
# `ContractFunction` type to rely on the annotation visitors in semantics
# module.
fn_node = vy_ast.FunctionDef()
fn_node.body = []
fn_node.args = vy_ast.arguments(defaults=[])
fn_node._metadata["type"] = ContractFunction(
ast_code.body[0]._metadata["type"] = ContractFunction(
"sqrt_builtin", {}, 0, 0, None, FunctionVisibility.INTERNAL, StateMutability.NONPAYABLE
)
sv = FunctionNodeVisitor(ast_code, fn_node, namespace)
for n in ast_code.body:
sv = FunctionNodeVisitor(ast_code, ast_code.body[0], namespace)
for n in ast_code.body[0].body:
sv.visit(n)

new_context = Context(
vars_=variables, global_ctx=GlobalContext(), memory_allocator=memory_allocator
)
generated_ir = parse_body(ast_code.body, new_context)
generated_ir = parse_body(ast_code.body[0].body, new_context)
# strip source position info from the generated_ir since
# it doesn't make any sense (e.g. the line numbers will start from 0
# instead of where we are in the code)
Expand Down
2 changes: 1 addition & 1 deletion vyper/codegen/global_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def get_global_context(

# Statements of the form:
# variable_name: type
elif isinstance(item, vy_ast.AnnAssign):
elif isinstance(item, vy_ast.VariableDef):
global_ctx.add_globals_and_events(item)
# Function definitions
elif isinstance(item, vy_ast.FunctionDef):
Expand Down
5 changes: 3 additions & 2 deletions vyper/semantics/types/bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,13 @@ def from_annotation(
is_immutable: bool = False,
) -> "BaseTypeDefinition":
"""
Generate a `BaseTypeDefinition` instance of this type from `AnnAssign.annotation`
Generate a `BaseTypeDefinition` instance of this type from `VariableDef.annotation`
or `AnnAssign.annotation`
Arguments
---------
node : VyperNode
Vyper ast node from the `annotation` member of an `AnnAssign` node.
Vyper ast node from the `annotation` member of a `VariableDef` or `AnnAssign` node.
Returns
-------
Expand Down
6 changes: 3 additions & 3 deletions vyper/semantics/types/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,15 +367,15 @@ def set_reentrancy_key_position(self, position: StorageSlot) -> None:
self.reentrancy_key_position = position

@classmethod
def from_AnnAssign(cls, node: vy_ast.AnnAssign) -> "ContractFunction":
def getter_from_VariableDef(cls, node: vy_ast.VariableDef) -> "ContractFunction":
"""
Generate a `ContractFunction` object from an `AnnAssign` node.
Generate a `ContractFunction` object from an `VariableDef` node.
Used to create getter functions for public variables.
Arguments
---------
node : AnnAssign
node : VariableDef
Vyper ast node to generate the function definition from.
Returns
Expand Down
4 changes: 2 additions & 2 deletions vyper/semantics/types/user/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,13 +201,13 @@ def _get_module_definitions(base_node: vy_ast.Module) -> Tuple[OrderedDict, Dict
# only keep the `ContractFunction` with the longest set of input args
continue
functions[node.name] = func
for node in base_node.get_children(vy_ast.AnnAssign, {"annotation.func.id": "public"}):
for node in base_node.get_children(vy_ast.VariableDef, {"annotation.func.id": "public"}):
name = node.target.id
if name in functions:
raise NamespaceCollision(
f"Interface contains multiple functions named '{name}'", base_node
)
functions[name] = ContractFunction.from_AnnAssign(node)
functions[name] = ContractFunction.getter_from_VariableDef(node)
for node in base_node.get_children(vy_ast.EventDef):
name = node.name
if name in functions or name in events:
Expand Down
2 changes: 1 addition & 1 deletion vyper/semantics/types/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def get_type_from_annotation(
Arguments
---------
node : VyperNode
Vyper ast node from the `annotation` member of an `AnnAssign` node.
Vyper ast node from the `annotation` member of a `VariableDef` or `AnnAssign` node.
Returns
-------
Expand Down
Loading

0 comments on commit 1d1ef5d

Please sign in to comment.