Skip to content

Commit

Permalink
fix type annotation in folding; lint
Browse files Browse the repository at this point in the history
  • Loading branch information
tserg committed Sep 23, 2023
1 parent 250e479 commit 8d77b05
Show file tree
Hide file tree
Showing 9 changed files with 44 additions and 51 deletions.
5 changes: 0 additions & 5 deletions vyper/ast/expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,6 @@ def remove_unused_statements(vyper_module: vy_ast.Module) -> None:
vyper_module : Module
Top-level Vyper AST node.
"""

# constant declarations - values were substituted within the AST during folding
#for node in vyper_module.get_children(vy_ast.VariableDecl, {"is_constant": True}):
# vyper_module.remove_from_body(node)

# `implements: interface` statements - validated during type checking
for node in vyper_module.get_children(vy_ast.ImplementsDecl):
vyper_module.remove_from_body(node)
8 changes: 3 additions & 5 deletions vyper/ast/folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,9 @@ def replace_literal_ops(vyper_module: vy_ast.Module) -> int:
try:
new_node = node.evaluate()
typ = node._metadata.get("type")
typ.validate_literal(new_node)
new_node._metadata["type"] = typ
if typ:
typ.validate_literal(new_node)
new_node._metadata["type"] = typ
except UnfoldableNode:
continue

Expand Down Expand Up @@ -115,7 +116,6 @@ def replace_builtin_functions(vyper_module: vy_ast.Module) -> int:
try:
new_node = func.evaluate(node) # type: ignore
new_node._metadata["type"] = func.fetch_call_return(node)
print("replaced builtin fn: ", node.func.id)
except UnfoldableNode:
continue

Expand All @@ -140,11 +140,9 @@ def replace_user_defined_constants(vyper_module: vy_ast.Module) -> int:
int
Number of nodes that were replaced.
"""
print("replace user defined constants")
changed_nodes = 0

for node in vyper_module.get_children(vy_ast.VariableDecl):
print("node: ", node)
if not isinstance(node.target, vy_ast.Name):
# left-hand-side of assignment is not a variable
continue
Expand Down
44 changes: 21 additions & 23 deletions vyper/ast/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -921,21 +921,17 @@ class Name(ExprNode):
__slots__ = ("id",)

def derive(self, constants: dict):
try:
val = constants[self.id]
return val
except:
return None
return constants.get(self.id, None)


class UnaryOp(ExprNode):
__slots__ = ("op", "operand")

def derive(self, constants: dict):
try:
return self.op._op(self.operand.derive(constants))
except:
operand = self.operand.derive(constants)
if operand is None:
return None
return self.op._op(operand)

def evaluate(self) -> ExprNode:
"""
Expand Down Expand Up @@ -986,11 +982,11 @@ class BinOp(ExprNode):
__slots__ = ("left", "op", "right")

def derive(self, constants: dict):
try:
left, right = self.left, self.right
return self.op._op(left.derive(constants), right.derive(constants))
except:
left = self.left.derive(constants)
right = self.right.derive(constants)
if left is None or right is None:
return None
return self.op._op(left, right)

def evaluate(self) -> ExprNode:
"""
Expand Down Expand Up @@ -1143,11 +1139,10 @@ class BoolOp(ExprNode):
__slots__ = ("op", "values")

def derive(self, constants: dict):
try:
values = [i.derive(constants) for i in self.values]
return self.op._op(values)
except:
values = [i.derive(constants) for i in self.values]
if any(v is None for v in values):
return None
return self.op._op(values)

def evaluate(self) -> ExprNode:
"""
Expand Down Expand Up @@ -1206,15 +1201,18 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def derive(self, constants: dict):
try:
left, right = self.left, self.right
if isinstance(self.op, (In, NotIn)):
value = self.op._op(left.derive(constants), [i.derive(constants) for i in right.elements])
return value
left = self.left.derive(constants)

if isinstance(self.op, (In, NotIn)):
right = [i.derive(constants) for i in self.right.elements]
if left is None or any(v is None for v in right):
return None
return self.op._op(left, right)

return self.op._op(left.derive(constants), right.derive(constants))
except:
right = self.right.derive(constants)
if left is None or right is None:
return None
return self.op._op(left, right)

def evaluate(self) -> ExprNode:
"""
Expand Down
1 change: 0 additions & 1 deletion vyper/codegen/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,6 @@ def _signed_to_unsigned_comparision_op(op):

def parse_Compare(self):
left = Expr.parse_value_expr(self.expr.left, self.context)
print("parse_Compare - right: ", self.expr.right)
right = Expr.parse_value_expr(self.expr.right, self.context)

if right.value is None:
Expand Down
2 changes: 1 addition & 1 deletion vyper/compiler/phases.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def generate_folded_ast(
vyper_module_folded = copy.deepcopy(vyper_module)
vy_ast.folding.fold(vyper_module_folded)

#for node in vyper_module_folded.get_children(vy_ast.VariableDecl, {"is_constant": True}):
# for node in vyper_module_folded.get_children(vy_ast.VariableDecl, {"is_constant": True}):
# vyper_module.remove_from_body(node)
symbol_tables = set_data_positions(vyper_module_folded, storage_layout_overrides)

Expand Down
6 changes: 4 additions & 2 deletions vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,9 @@ def visit_For(self, node):
)

right_val = args[1].right.derive(self.namespace._constants)
if not isinstance(args[1].right, vy_ast.Int) and not (isinstance(args[1].right, vy_ast.Name) and right_val):
if not isinstance(args[1].right, vy_ast.Int) and not (
isinstance(args[1].right, vy_ast.Name) and right_val
):
raise InvalidLiteral("Literal must be an integer", args[1].right)
if right_val < 1:
raise StructureException(
Expand Down Expand Up @@ -669,7 +671,7 @@ def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None:
for arg, arg_type in zip(node.args, call_type.arg_types):
self.visit(arg, arg_type)
else:
# Skip annotation of builtin functions that are always folded
# Skip annotation of builtin functions that are always folded
# because they will be annotated during folding.
if getattr(call_type, "_is_folded", False):
return
Expand Down
23 changes: 11 additions & 12 deletions vyper/semantics/analysis/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,23 +66,22 @@ def __init__(

# TODO: Move computation out of constructor
module_nodes = module_node.body.copy()
const_var_decls = [n for n in module_nodes if isinstance(n, vy_ast.VariableDecl) and n.is_constant]
const_var_decls = [
n for n in module_nodes if isinstance(n, vy_ast.VariableDecl) and n.is_constant
]

while const_var_decls:
derived_nodes = 0

for c in const_var_decls:
try:
name = c.get("target.id")
val = c.value.derive(self.namespace._constants)
self.namespace.add_constant(name, val)

if val:
derived_nodes += 1
const_var_decls.remove(c)
except:
pass

name = c.get("target.id")
val = c.value.derive(self.namespace._constants)
self.namespace.add_constant(name, val)

if val is not None:
derived_nodes += 1
const_var_decls.remove(c)

if not derived_nodes:
break

Expand Down
1 change: 1 addition & 0 deletions vyper/semantics/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def validate_assignment(self, attr):
def add_constant(self, name, value):
self._constants[name] = value


def get_namespace():
"""
Get the active namespace object.
Expand Down
5 changes: 3 additions & 2 deletions vyper/semantics/types/subscriptable.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,9 @@ def get_subscripted_type(self, node):
return self.value_type

@classmethod
def from_annotation(cls, node: Union[vy_ast.Name, vy_ast.Call, vy_ast.Subscript], constants: dict) -> "HashMapT":
def from_annotation(
cls, node: Union[vy_ast.Name, vy_ast.Call, vy_ast.Subscript], constants: dict
) -> "HashMapT":
if (
not isinstance(node, vy_ast.Subscript)
or not isinstance(node.slice, vy_ast.Index)
Expand Down Expand Up @@ -275,7 +277,6 @@ def compare_type(self, other):
@classmethod
def from_annotation(cls, node: vy_ast.Subscript, constants: dict) -> "DArrayT":
max_length = node.slice.value.elements[1].derive(constants)
print("max length: ", max_length)
if (
not isinstance(node, vy_ast.Subscript)
or not isinstance(node.slice, vy_ast.Index)
Expand Down

0 comments on commit 8d77b05

Please sign in to comment.