Skip to content

Commit

Permalink
fix: rewrite typechecker journal to handle nested commits (#3375)
Browse files Browse the repository at this point in the history
this commit fixes a bug which was introduced in 66930fd. when the
typechecker enters a nested loop, it can typecheck the inner loop
(committing it), and result in invalid state if the outer loop fails to
typecheck. this commit implements a checkpointing system in the
typechecker state committer so that it can handle nested changes. it
also cleans up the data structure used so that there is a single entry
point and users of the data structure do not need to think about
implementation details.

one drawback of this approach is that it *only* handles changes to the
metadata dict. non-idempotent changes to the AST during typechecking
(such as the use case we are using them for - caching) should then be
restricted to changes to the metadata dict so that they can register
automatically with the node metadata journal.
  • Loading branch information
charles-cooper authored May 4, 2023
1 parent 0c7066c commit 7c3cf61
Show file tree
Hide file tree
Showing 6 changed files with 128 additions and 37 deletions.
43 changes: 42 additions & 1 deletion tests/functional/semantics/analysis/test_for_loop.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest

from vyper.ast import parse_to_ast
from vyper.exceptions import ImmutableViolation
from vyper.exceptions import ImmutableViolation, TypeMismatch
from vyper.semantics.analysis import validate_semantics


Expand Down Expand Up @@ -99,3 +99,44 @@ def baz():
vyper_module = parse_to_ast(code)
with pytest.raises(ImmutableViolation):
validate_semantics(vyper_module, {})


iterator_inference_codes = [
"""
@external
def main():
for j in range(3):
x: uint256 = j
y: uint16 = j
""", # issue 3212
"""
@external
def foo():
for i in [1]:
a:uint256 = i
b:uint16 = i
""", # issue 3374
"""
@external
def foo():
for i in [1]:
for j in [1]:
a:uint256 = i
b:uint16 = i
""", # issue 3374
"""
@external
def foo():
for i in [1,2,3]:
for j in [1,2,3]:
b:uint256 = j + i
c:uint16 = i
""", # issue 3374
]


@pytest.mark.parametrize("code", iterator_inference_codes)
def test_iterator_type_inference_checker(namespace, code):
vyper_module = parse_to_ast(code)
with pytest.raises(TypeMismatch):
validate_semantics(vyper_module, {})
80 changes: 80 additions & 0 deletions vyper/ast/metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import contextlib
from typing import Any

from vyper.exceptions import VyperException


# a commit/rollback scheme for metadata caching. in the case that an
# exception is thrown and caught during type checking (currently, only
# during for loop iterator variable type inference), we can roll back
# any state updates due to type checking.
# this is implemented as a stack of changesets, because we need to
# handle nested rollbacks in the case of nested for loops
class _NodeMetadataJournal:
_NOT_FOUND = object()

def __init__(self):
self._node_updates: list[dict[tuple[int, str, Any], NodeMetadata]] = []

def register_update(self, metadata, k):
prev = metadata.get(k, self._NOT_FOUND)
self._node_updates[-1][(id(metadata), k)] = (metadata, prev)

@contextlib.contextmanager
def enter(self):
self._node_updates.append({})
try:
yield
except VyperException as e:
# note: would be better to only catch typechecker exceptions here.
self._rollback_inner()
raise e from e
else:
self._commit_inner()

def _rollback_inner(self):
for (_, k), (metadata, prev) in self._node_updates[-1].items():
if prev is self._NOT_FOUND:
metadata.pop(k, None)
else:
metadata[k] = prev
self._pop_inner()

def _commit_inner(self):
inner = self._pop_inner()

if len(self._node_updates) == 0:
return

outer = self._node_updates[-1]

# register with previous frame in case inner gets commited
# but outer needs to be rolled back
for (_, k), (metadata, prev) in inner.items():
if (id(metadata), k) not in outer:
outer[(id(metadata), k)] = (metadata, prev)

def _pop_inner(self):
return self._node_updates.pop()


class NodeMetadata(dict):
"""
A data structure which allows for journaling.
"""

_JOURNAL: _NodeMetadataJournal = _NodeMetadataJournal()

def __setitem__(self, k, v):
# if we are in a context where we need to journal, add
# this to the changeset.
if len(self._JOURNAL._node_updates) != 0:
self._JOURNAL.register_update(self, k)

super().__setitem__(k, v)

@classmethod
@contextlib.contextmanager
def enter_typechecker_speculation(cls):
with cls._JOURNAL.enter():
yield
3 changes: 2 additions & 1 deletion vyper/ast/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import sys
from typing import Any, Optional, Union

from vyper.ast.metadata import NodeMetadata
from vyper.compiler.settings import VYPER_ERROR_CONTEXT_LINES, VYPER_ERROR_LINE_NUMBERS
from vyper.exceptions import (
ArgumentException,
Expand Down Expand Up @@ -254,7 +255,7 @@ def __init__(self, parent: Optional["VyperNode"] = None, **kwargs: dict):
"""
self.set_parent(parent)
self._children: set = set()
self._metadata: dict = {}
self._metadata: NodeMetadata = NodeMetadata()

for field_name in NODE_SRC_ATTRIBUTES:
# when a source offset is not available, use the parent's source offset
Expand Down
3 changes: 0 additions & 3 deletions vyper/semantics/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,3 @@ def validate_semantics(vyper_ast, interface_codes):
with namespace.enter_scope():
add_module_namespace(vyper_ast, interface_codes)
validate_functions(vyper_ast)

# clean up. not sure if this is necessary, but do it for hygiene's sake.
_ExprAnalyser._reset_taint()
11 changes: 4 additions & 7 deletions vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Optional

from vyper import ast as vy_ast
from vyper.ast.metadata import NodeMetadata
from vyper.ast.validation import validate_call_args
from vyper.exceptions import (
ExceptionList,
Expand All @@ -21,7 +22,6 @@
from vyper.semantics.analysis.base import DataLocation, VarInfo
from vyper.semantics.analysis.common import VyperNodeVisitorBase
from vyper.semantics.analysis.utils import (
_ExprAnalyser,
get_common_types,
get_exact_type_from_node,
get_expr_info,
Expand Down Expand Up @@ -453,20 +453,17 @@ def visit_For(self, node):
raise exc.with_annotation(node) from None

try:
for n in node.body:
self.visit(n)
with NodeMetadata.enter_typechecker_speculation():
for n in node.body:
self.visit(n)
except (TypeMismatch, InvalidOperation) as exc:
for_loop_exceptions.append(exc)
# rollback any changes to the tree
_ExprAnalyser._rollback_taint()
else:
# type information is applied directly here because the
# scope is closed prior to the call to
# `StatementAnnotationVisitor`
node.target._metadata["type"] = type_

# perf - persist all calculated types
_ExprAnalyser._commit_taint()
# success -- bail out instead of error handling.
return

Expand Down
25 changes: 0 additions & 25 deletions vyper/semantics/analysis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,6 @@ class _ExprAnalyser:
class's method resolution order is examined to decide which method to call.
"""

# this allows for a very simple commit/rollback scheme for metadata
# caching. in the case that an exception is thrown and caught during
# type checking (currently, only during for loop iterator variable
# type inference), we can roll back any state updates due to type
# checking.
_tainted_nodes: set[tuple[vy_ast.VyperNode, str]] = set()

def __init__(self):
self.namespace = get_namespace()

Expand Down Expand Up @@ -171,27 +164,9 @@ def get_possible_types_from_node(self, node, include_type_exprs=False):
ret.sort(key=lambda k: (k.bits, not k.is_signed), reverse=True)

node._metadata[k] = ret
# register with list of tainted nodes, in case the cache
# needs to be invalidated in case of a state rollback
self._tainted_nodes.add((node, k))

return node._metadata[k].copy()

@classmethod
def _rollback_taint(cls):
for node, k in cls._tainted_nodes:
node._metadata.pop(k, None)
# taint has been rolled back, no need to track it anymore
cls._reset_taint()

@classmethod
def _commit_taint(cls):
cls._reset_taint()

@classmethod
def _reset_taint(cls):
cls._tainted_nodes.clear()

def _find_fn(self, node):
# look for a type-check method for each class in the given class mro
for name in [i.__name__ for i in type(node).mro()]:
Expand Down

0 comments on commit 7c3cf61

Please sign in to comment.