Skip to content

Commit

Permalink
more refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
charles-cooper committed Jan 3, 2024
1 parent 495e2fa commit 02bdf27
Show file tree
Hide file tree
Showing 9 changed files with 193 additions and 252 deletions.
14 changes: 6 additions & 8 deletions vyper/ast/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,14 +398,10 @@ def get_folded_value(self) -> "VyperNode":
For constant/literal nodes, the node should be directly returned
without caching to the metadata.
"""
if self.is_literal_value:
return self

if "folded_value" not in self._metadata:
res = self._try_fold() # possibly throws UnfoldableNode
self._set_folded_value(res)

return self._metadata["folded_value"]
try:
return self._metadata["folded_value"]
except KeyError:
raise UnfoldableNode("not foldable", self)

def _set_folded_value(self, node: "VyperNode") -> None:
# sanity check this is only called once
Expand Down Expand Up @@ -1089,6 +1085,7 @@ class RShift(Operator):
class BoolOp(ExprNode):
__slots__ = ("op", "values")


class And(Operator):
__slots__ = ()
_description = "logical and"
Expand Down Expand Up @@ -1193,6 +1190,7 @@ class Attribute(ExprNode):
class Subscript(ExprNode):
__slots__ = ("slice", "value")


class Index(VyperNode):
__slots__ = ("value",)

Expand Down
1 change: 0 additions & 1 deletion vyper/ast/nodes.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ class VyperNode:
@classmethod
def get_fields(cls: Any) -> set: ...
def get_folded_value(self) -> VyperNode: ...
def _try_fold(self) -> VyperNode: ...
def _set_folded_value(self, node: VyperNode) -> None: ...
@classmethod
def from_node(cls, node: VyperNode, **kwargs: Any) -> Any: ...
Expand Down
160 changes: 160 additions & 0 deletions vyper/semantics/analysis/constant_folding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
from vyper import ast as vy_ast
from vyper.exceptions import InvalidLiteral, UndeclaredDefinition, UnfoldableNode
from vyper.semantics.analysis.common import VyperNodeVisitorBase
from vyper.semantics.namespace import get_namespace


class ConstantFolder(VyperNodeVisitorBase):
def visit(self, node):
for c in node.get_children():
try:
self.visit(c)
except UnfoldableNode:
# ignore bubbled up exceptions
pass

try:
for class_ in node.__class__.mro():
ast_type = class_.__name__

visitor_fn = getattr(self, f"visit_{ast_type}", None)
if visitor_fn:
folded_value = visitor_fn(node)
node._set_folded_value(folded_value)
return folded_value
else:
raise UnfoldableNode
except UnfoldableNode:
# ignore bubbled up exceptions
pass

def visit_Constant(self, node) -> vy_ast.ExprNode:
return node

def visit_Name(self, node) -> vy_ast.ExprNode:
namespace = get_namespace()
try:
ret = namespace[node]
except UndeclaredDefinition:
raise UnfoldableNode("unknown name", node)

if not isinstance(ret, vy_ast.VariableDecl) and not ret.is_constant:
raise UnfoldableNode("not a constant", node)

return ret.value.get_folded_value()

def visit_UnaryOp(self, node):
operand = node.operand.get_folded_value()

if isinstance(node.op, vy_ast.Not) and not isinstance(operand, vy_ast.NameConstant):
raise UnfoldableNode("not a boolean!", node.operand)
if isinstance(node.op, vy_ast.USub) and not isinstance(operand, vy_ast.Num):
raise UnfoldableNode("not a number!", node.operand)
if isinstance(node.op, vy_ast.Invert) and not isinstance(operand, vy_ast.Int):
raise UnfoldableNode("not an int!", node.operand)

value = node.op._op(operand.value)
return type(operand).from_node(node, value=value)

def visit_BinOp(self, node):
left, right = [i.get_folded_value() for i in (node.left, node.right)]
if type(left) is not type(right):
raise UnfoldableNode("invalid operation", node)
if not isinstance(left, vy_ast.Num):
raise UnfoldableNode("not a number!", node.left)

# this validation is performed to prevent the compiler from hanging
# on very large shifts and improve the error message for negative
# values.
if isinstance(node.op, (vy_ast.LShift, vy_ast.RShift)) and not (0 <= right.value <= 256):
raise InvalidLiteral("Shift bits must be between 0 and 256", node.right)

value = node.op._op(left.value, right.value)
return type(left).from_node(node, value=value)

def visit_BoolOp(self, node):
values = [v.get_folded_value() for v in node.values]

if any(not isinstance(v, vy_ast.NameConstant) for v in values):
raise UnfoldableNode("Node contains invalid field(s) for evaluation")

values = [v.value for v in values]
value = node.op._op(values)
return vy_ast.NameConstant.from_node(node, value=value)

def visit_Compare(self, node):
left, right = [i.get_folded_value() for i in (node.left, node.right)]
if not isinstance(left, vy_ast.Constant):
raise UnfoldableNode("Node contains invalid field(s) for evaluation")

# CMC 2022-08-04 we could probably remove these evaluation rules as they
# are taken care of in the IR optimizer now.
if isinstance(node.op, (vy_ast.In, vy_ast.NotIn)):
if not isinstance(right, vy_ast.List):
raise UnfoldableNode("Node contains invalid field(s) for evaluation")
if next((i for i in right.elements if not isinstance(i, vy_ast.Constant)), None):
raise UnfoldableNode("Node contains invalid field(s) for evaluation")
if len(set([type(i) for i in right.elements])) > 1:
raise UnfoldableNode("List contains multiple literal types")
value = node.op._op(left.value, [i.value for i in right.elements])
return vy_ast.NameConstant.from_node(node, value=value)

if not isinstance(left, type(right)):
raise UnfoldableNode("Cannot compare different literal types")

# this is maybe just handled in the type checker.
if not isinstance(node.op, (vy_ast.Eq, vy_ast.NotEq)) and not isinstance(left, vy_ast.Num):
raise UnfoldableNode(
f"Invalid literal types for {node.op.description} comparison", node
)

value = node.op._op(left.value, right.value)
return vy_ast.NameConstant.from_node(node, value=value)

def visit_List(self, node) -> vy_ast.ExprNode:
elements = [e.get_folded_value() for e in node.elements]
return type(node).from_node(node, elements=elements)

def visit_Tuple(self, node) -> vy_ast.ExprNode:
elements = [e.get_folded_value() for e in node.elements]
return type(node).from_node(node, elements=elements)

def visit_Dict(self, node) -> vy_ast.ExprNode:
values = [v.get_folded_value() for v in node.values]
return type(node).from_node(node, values=values)

def visit_Call(self, node) -> vy_ast.ExprNode:
if not isinstance(node.func, vy_ast.Name):
raise UnfoldableNode("not a builtin", node)

namespace = get_namespace()

func_name = node.func.id
if func_name not in namespace:
raise UnfoldableNode("unknown", node)

typ = namespace[func_name]
# TODO: rename to vyper_type.try_fold_call_expr
if not hasattr(typ, "_try_fold"):
raise UnfoldableNode("unfoldable", node)
return typ._try_fold(node)

def visit_Subscript(self, node) -> vy_ast.ExprNode:
slice_ = node.slice.value.get_folded_value()
value = node.value.get_folded_value()

if not isinstance(value, vy_ast.List):
raise UnfoldableNode("Subscript object is not a literal list")

elements = value.elements
if len(set([type(i) for i in elements])) > 1:
raise UnfoldableNode("List contains multiple node types")

if not isinstance(slice_, vy_ast.Int):
raise UnfoldableNode("invalid index type", slice_)

idx = slice_.value
if idx < 0 or idx >= len(elements):
raise UnfoldableNode("invalid index value")

return elements[idx]
124 changes: 0 additions & 124 deletions vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,130 +529,6 @@ def visit_Return(self, node):
self.expr_visitor.visit(node.value, self.func.return_type)


class ConstantFolder(VyperNodeVisitorBase):
scope_name = "function"

def visit(self, node):
folded_value = super().visit(node)

if folded_value is not None:
node._set_folded_value(folded_value)

def visit_Constant(self, node) -> ExprNode:
return node

def visit_UnaryOp(self, node) -> ExprNode:
operand = node.operand.get_folded_value()

if isinstance(node.op, Not) and not isinstance(operand, NameConstant):
raise UnfoldableNode("not a boolean!", node.operand)
if isinstance(node.op, USub) and not isinstance(operand, Num):
raise UnfoldableNode("not a number!", node.operand)
if isinstance(node.op, Invert) and not isinstance(operand, Int):
raise UnfoldableNode("not an int!", node.operand)

value = node.op._op(operand.value)
return type(operand).from_node(node, value=value)

def visit_BinOp(self, node):
left, right = [i.get_folded_value() for i in (node.left, node.right)]
if type(left) is not type(right):
raise UnfoldableNode("invalid operation", node)
if not isinstance(left, Num):
raise UnfoldableNode("not a number!", node.left)

# this validation is performed to prevent the compiler from hanging
# on very large shifts and improve the error message for negative
# values.
if isinstance(node.op, (LShift, RShift)) and not (0 <= right.value <= 256):
raise InvalidLiteral("Shift bits must be between 0 and 256", node.right)

value = node.op._op(left.value, right.value)
return type(left).from_node(node, value=value)

def visit_BoolOp(self, node):
values = [v.get_folded_value() for v in node.values]

if any(not isinstance(v, NameConstant) for v in values):
raise UnfoldableNode("Node contains invalid field(s) for evaluation")

values = [v.value for v in values]
value = node.op._op(values)
return NameConstant.from_node(node, value=value)

def visit_Compare(self, node):
left, right = [i.get_folded_value() for i in (node.left, node.right)]
if not isinstance(left, Constant):
raise UnfoldableNode("Node contains invalid field(s) for evaluation")

# CMC 2022-08-04 we could probably remove these evaluation rules as they
# are taken care of in the IR optimizer now.
if isinstance(node.op, (In, NotIn)):
if not isinstance(right, List):
raise UnfoldableNode("Node contains invalid field(s) for evaluation")
if next((i for i in right.elements if not isinstance(i, Constant)), None):
raise UnfoldableNode("Node contains invalid field(s) for evaluation")
if len(set([type(i) for i in right.elements])) > 1:
raise UnfoldableNode("List contains multiple literal types")
value = node.op._op(left.value, [i.value for i in right.elements])
return NameConstant.from_node(node, value=value)

if not isinstance(left, type(right)):
raise UnfoldableNode("Cannot compare different literal types")

if not isinstance(node.op, (Eq, NotEq)) and not isinstance(left, (Int, Decimal)):
raise TypeMismatch(f"Invalid literal types for {node.op.description} comparison", node)

value = node.op._op(left.value, right.value)
return NameConstant.from_node(node, value=value)

def visit_List(self, node) -> ExprNode:
elements = [e.get_folded_value() for e in node.elements]
return type(node).from_node(node, elements=elements)

def visit_Tuple(self, node) -> ExprNode:
elements = [e.get_folded_value() for e in node.elements]
return type(node).from_node(node, elements=elements)

def visit_Dict(self, node) -> ExprNode:
values = [v.get_folded_value() for v in self.values]
return type(node).from_node(node, values=values)

def visit_Call(self, node) -> ExprNode:
if not isinstance(node.func, Name):
raise UnfoldableNode("not a builtin", node)

# cursed import cycle!
from vyper.builtins.functions import DISPATCH_TABLE

func_name = node.func.id
if func_name not in DISPATCH_TABLE:
raise UnfoldableNode("not a builtin", node)

builtin_t = DISPATCH_TABLE[func_name]
return builtin_t._try_fold(node)

def visit_Subscript(self, node) -> ExprNode:
slice_ = self.slice.value.get_folded_value()
value = self.value.get_folded_value()

if not isinstance(value, List):
raise UnfoldableNode("Subscript object is not a literal list")

elements = value.elements
if len(set([type(i) for i in elements])) > 1:
raise UnfoldableNode("List contains multiple node types")

if not isinstance(slice_, Int):
raise UnfoldableNode("invalid index type", slice_)

idx = slice_.value
if idx < 0 or idx >= len(elements):
raise UnfoldableNode("invalid index value")

return elements[idx]


class ExprVisitor(VyperNodeVisitorBase):
scope_name = "function"

Expand Down
11 changes: 0 additions & 11 deletions vyper/semantics/analysis/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from vyper.semantics.analysis.common import VyperNodeVisitorBase
from vyper.semantics.analysis.import_graph import ImportGraph
from vyper.semantics.analysis.local import ExprVisitor, validate_functions
from vyper.semantics.analysis.pre_typecheck import pre_typecheck
from vyper.semantics.analysis.utils import check_modifiability, get_exact_type_from_node
from vyper.semantics.data_locations import DataLocation
from vyper.semantics.namespace import Namespace, get_namespace, override_global_namespace
Expand Down Expand Up @@ -138,8 +137,6 @@ def analyze(self) -> ModuleT:
self.visit(node)
to_visit.remove(node)

self.constants = {}

# keep trying to process all the nodes until we finish or can
# no longer progress. this makes it so we don't need to
# calculate a dependency tree between top-level items.
Expand All @@ -162,9 +159,6 @@ def analyze(self) -> ModuleT:
if count == len(to_visit):
err_list.raise_if_not_empty()

for n in self.ast.get_descendants(reverse=True):
_fold_with_constants(n, self.constants)

self.module_t = ModuleT(self.ast)
self.ast._metadata["type"] = self.module_t

Expand Down Expand Up @@ -314,11 +308,6 @@ def _validate_self_namespace():
if node.is_constant:
assert node.value is not None # checked in VariableDecl.validate()

for n in node.get_descendants(reverse=True):
_fold_with_constants(n, self.constants)

node.value.get_folded_value()

ExprVisitor().visit(node.value, type_) # performs validate_expected_type

if not check_modifiability(node.value, Modifiability.CONSTANT):
Expand Down
Loading

0 comments on commit 02bdf27

Please sign in to comment.