Skip to content

Commit

Permalink
Improve how a Visitor alters the original AST tree
Browse files Browse the repository at this point in the history
This is a less naïve way to support AST visitors that apply changes to
the tree, in particular handling the weird case when they alter/remove a
node contained in a (sub)list.

This hopefully fixes issue #107.
  • Loading branch information
lelit committed Jul 29, 2022
1 parent 1757c55 commit 954b3b6
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 32 deletions.
90 changes: 58 additions & 32 deletions pglast/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class Ancestor:
.. testsetup:: *
from pglast import parse_sql
from pglast import ast, parse_sql
from pglast.visitors import Ancestor
.. doctest::
Expand Down Expand Up @@ -133,12 +133,20 @@ class Ancestor:
<class 'tuple'>
>>> generate_series_path.node[0][0]
<FuncCall funcname=(<String val='generate_series'>,)...
As an aid to visitors that apply changes to the AST, there are two methods, :meth:`update`
and :meth:`apply`, that take care of the different cases, that is when `node` is an AST
instance or when it is instead a tuple (or subtuple); the former does not directly change the
AST tree, but postpones that until the latter is called.
"""

__slots__ = ('parent', 'node', 'member', 'pending_update')

def __init__(self, parent=None, node=None, member=None):
self.parent = parent
self.node = node
self.member = member
self.pending_update = None

def __iter__(self):
"Iterate over each step, yielding either an attribute name or a sequence index."
Expand Down Expand Up @@ -195,6 +203,43 @@ def __matmul__(self, root):
node = getattr(node, member)
return node

def update(self, new_value):
    """Stage `new_value` as the pending replacement of the tracked node.

    Nothing is written to the tree here: the change is only recorded, and
    becomes effective when :meth:`apply` is called on the returned instance.

    :param new_value: the value that shall take the place of the tracked node
    :returns: the instance holding the staged change, that is the parent when
              `member` is a sequence index, ``self`` otherwise
    """

    if not isinstance(self.member, int):
        # Not an item of a sequence: the change is recorded on this very instance
        self.pending_update = new_value
        return self

    # The tracked node is one particular item of a tuple, so the whole tuple must be
    # rebuilt and the change belongs to the parent. Other items of the same tuple may
    # change as well, so a mutable list copy is kept around in the meantime: apply()
    # turns it back into a tuple, and also deals with the "list of lists" cases.
    holder = self.parent
    if holder.pending_update is None:
        holder.pending_update = list(self.node)
    holder.pending_update[self.member] = new_value
    return holder

def apply(self):
    """Apply the pending change, if any, to the actual node.

    The value staged by :meth:`update` may be:

    * a list, the working copy of a tuple: the items marked with ``Delete`` are
      dropped and the remainder is coerced back to a tuple
    * the ``Delete`` action itself, meaning that the tracked value shall be
      removed: it is replaced by ``None``
    * any other value, used as is

    An empty result is always stored as ``None``. Once applied, the pending
    change is cleared, so calling this method again does nothing.
    """

    if self.pending_update is None:
        return

    if isinstance(self.pending_update, list):
        value = tuple(item for item in self.pending_update if item is not Delete)
    elif self.pending_update is Delete:
        # The action is a truthy sentinel, so without this branch it would get
        # past the ``or None`` below and be written into the tree as is
        value = None
    else:
        value = self.pending_update

    # An emptied sequence is stored as None, not as an empty tuple
    value = value or None

    if isinstance(self.member, int):
        # "List of lists" case: the rebuilt tuple is itself an item of an outer
        # sequence, so forward the change one level up and apply it there.
        # NOTE(review): pvalue starts from self.node every time; if two sibling
        # sublists both carry pending changes, the second apply() seems to restart
        # from the untouched outer container -- confirm
        pvalue = list(self.node)
        pvalue[self.member] = value
        self.parent.update(pvalue)
        self.parent.apply()
    elif self.member is None:
        # No member at all: the tracked node itself gets replaced
        self.node = value
    else:
        setattr(self.node, self.member, value)
    self.pending_update = None


class Visitor:
"""Base class implementing the `visitor pattern`__.
Expand Down Expand Up @@ -265,6 +310,8 @@ def iterate(self, node):
chain* as it finds them while traversing the tree.
"""

pending_updates = []

todo = deque()

if isinstance(node, (tuple, ast.Node)):
Expand All @@ -275,14 +322,15 @@ def iterate(self, node):
while todo:
ancestors, node = todo.popleft()

# Here `node` may be either one AST node, a tuple of AST nodes (e.g.
# SelectStmt.targetList), or even a tuple of tuples of AST nodes (e.g.
# SelectStmt.valuesList). To simplify code, coerce it to a sequence.

is_sequence = isinstance(node, tuple)
if is_sequence:
nodes = list(node)
new_nodes = []
else:
nodes = [node]
new_nodes = None
sequence_changed = False

index = 0
while nodes:
Expand All @@ -294,47 +342,25 @@ def iterate(self, node):
if isinstance(sub_node, ast.Node):
action = yield sub_ancestors, sub_node
if action is Continue:
if is_sequence:
new_nodes.append(sub_node)

for member in sub_node:
value = getattr(sub_node, member)
if isinstance(value, (tuple, ast.Node)):
todo.append((sub_ancestors / (sub_node, member), value))
elif action is Skip:
if is_sequence:
new_nodes.append(sub_node)
pass
else:
if action is Delete:
if is_sequence:
sequence_changed = True
new_node = None
elif action is not sub_node:
if is_sequence:
sequence_changed = True
new_nodes.append(action)
else:
new_node = action

if not is_sequence:
parent = ancestors[0]
if parent is not None:
setattr(parent, ancestors.member, new_node)
else:
self.root = new_node
pending_updates.append(sub_ancestors.update(action))
elif isinstance(sub_node, tuple):
for sub_index, value in enumerate(sub_node):
if isinstance(value, (tuple, ast.Node)):
todo.append((sub_ancestors / (sub_node, sub_index), value))

index += 1

if is_sequence and sequence_changed:
parent = ancestors[0]
if parent is not None:
setattr(parent, ancestors.member, tuple(new_nodes) if new_nodes else None)
else:
self.root = tuple(new_nodes) if new_nodes else None
for pending_update in pending_updates:
pending_update.apply()
if pending_update.member is None:
self.root = pending_update.node

visit = None
"""
Expand Down
27 changes: 27 additions & 0 deletions tests/test_visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,25 @@ def visit_Constraint(self, ancestors, node):
DropNullConstraint()(raw)
assert RawStream()(raw) == 'CREATE TABLE foo (a integer CHECK (a <> 0))'

class DeleteOddsInList(visitors.Visitor):
def visit_A_Const(self, ancestors, node):
if isinstance(node.val, ast.Integer):
if node.val.val % 2:
return visitors.Delete

raw = parse_sql('INSERT INTO foo VALUES (1, 2, 3, 42, 43)')
DeleteOddsInList()(raw)
assert RawStream()(raw) == 'INSERT INTO foo VALUES (2, 42)'

raw = parse_sql('INSERT INTO foo VALUES ((1, 2, 3, 42, 43),'
' (2, 1, 4, 3, 5))')
DeleteOddsInList()(raw)
assert RawStream()(raw) == 'INSERT INTO foo VALUES ((2, 42), (2, 4))'

raw = parse_sql('select true from foo where a in (1, 2, 3)')
DeleteOddsInList()(raw)
assert RawStream()(raw) == 'SELECT TRUE FROM foo WHERE a IN (2)'


def test_alter_node():
class AddNullConstraint(visitors.Visitor):
Expand Down Expand Up @@ -266,6 +285,14 @@ def visit_Integer(self, ancestors, node):
DoubleAllIntegers()(raw)
assert RawStream()(raw) == 'SELECT 42'

class ReplaceConstantInList(visitors.Visitor):
def visit_A_Const(self, ancestors, node):
return ast.A_Const(val=ast.Integer(0))

raw = parse_sql('INSERT INTO foo VALUES (42)')
ReplaceConstantInList()(raw)
assert RawStream()(raw) == 'INSERT INTO foo VALUES (0)'


def test_replace_root_node():
class AndNowForSomethingCompletelyDifferent(visitors.Visitor):
Expand Down

0 comments on commit 954b3b6

Please sign in to comment.