Skip to content

Commit

Permalink
fix some small bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
charles-cooper committed Jan 7, 2024
1 parent d73caa7 commit a13086f
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 6 deletions.
14 changes: 8 additions & 6 deletions vyper/semantics/analysis/constant_folding.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from vyper import ast as vy_ast
from vyper.exceptions import InvalidLiteral, UndeclaredDefinition, UnfoldableNode
from vyper.semantics.analysis.common import VyperNodeVisitorBase
from vyper.semantics.analysis.base import VarInfo
from vyper.semantics.namespace import get_namespace


Expand All @@ -13,6 +14,9 @@ def visit(self, node):
# ignore bubbled up exceptions
pass

if node.has_folded_value:
return node.get_folded_value()

try:
for class_ in node.__class__.mro():
ast_type = class_.__name__
Expand All @@ -22,26 +26,24 @@ def visit(self, node):
folded_value = visitor_fn(node)
node._set_folded_value(folded_value)
return folded_value
else:
raise UnfoldableNode
except UnfoldableNode:
# ignore bubbled up exceptions
pass
return node

def visit_Constant(self, node) -> vy_ast.ExprNode:
return node

def visit_Name(self, node) -> vy_ast.ExprNode:
namespace = get_namespace()
try:
ret = namespace[node]
varinfo = namespace[node.id]
except UndeclaredDefinition:
raise UnfoldableNode("unknown name", node)

if not isinstance(ret, vy_ast.VariableDecl) and not ret.is_constant:
if not isinstance(varinfo, VarInfo) or not varinfo.is_constant:
raise UnfoldableNode("not a constant", node)

return ret.value.get_folded_value()
return varinfo.decl_node.value.get_folded_value()

def visit_UnaryOp(self, node):
operand = node.operand.get_folded_value()
Expand Down
6 changes: 6 additions & 0 deletions vyper/semantics/analysis/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import vyper.builtins.interfaces
from vyper import ast as vy_ast
from vyper.semantics.analysis.constant_folding import ConstantFolder
from vyper.ast.validation import validate_literal_nodes
from vyper.compiler.input_bundle import ABIInput, FileInput, FilesystemInputBundle, InputBundle
from vyper.evm.opcodes import version_check
Expand Down Expand Up @@ -159,6 +160,9 @@ def analyze(self) -> ModuleT:
if count == len(to_visit):
err_list.raise_if_not_empty()

# run constant folding recursively on all nodes
ConstantFolder().visit(self.ast)

self.module_t = ModuleT(self.ast)
self.ast._metadata["type"] = self.module_t

Expand Down Expand Up @@ -308,6 +312,8 @@ def _validate_self_namespace():
if node.is_constant:
assert node.value is not None # checked in VariableDecl.validate()

ConstantFolder().visit(node.value)

ExprVisitor().visit(node.value, type_) # performs validate_expected_type

if not check_modifiability(node.value, Modifiability.CONSTANT):
Expand Down

0 comments on commit a13086f

Please sign in to comment.