Skip to content

Commit

Permalink
Add visit() function to source tree
Browse files Browse the repository at this point in the history
The new visit() function:
- Passes each node as an argument to the specified callable.
- Decides which node to visit next based on the return type.

The return types are an enum; the values in this enum may need to be
extended in order to support additional functionality.

Signed-off-by: John Pennycook <john.pennycook@intel.com>
  • Loading branch information
Pennycook committed Oct 10, 2024
1 parent 20d11b4 commit 491bff7
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 35 deletions.
36 changes: 35 additions & 1 deletion codebasin/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
import hashlib
import logging
import os
from collections.abc import Iterable
from collections.abc import Callable, Iterable
from copy import copy
from enum import Enum
from typing import Self

import numpy as np
Expand Down Expand Up @@ -539,6 +540,11 @@ class ParseError(ValueError):
"""


class Visit(Enum):
NEXT = (0,)
NEXT_SIBLING = (1,)


class Node:
"""
Base class for all other Node types.
Expand Down Expand Up @@ -597,6 +603,22 @@ def walk(self) -> Iterable[Self]:
for child in self.children:
yield from child.walk()

def visit(self, visitor: Callable[[Self], Visit]):
"""
Visit all descendants of this node via a preorder traversal, using the
supplied visitor.
Raises
------
TypeError
If `visitor` is not callable.
"""
if not callable(visitor):
raise TypeError("visitor is not callable.")
if visitor(self) != Visit.NEXT_SIBLING:
for child in self.children:
child.visit(visitor)


class FileNode(Node):
"""
Expand Down Expand Up @@ -2359,6 +2381,18 @@ def walk(self) -> Iterable[Node]:
"""
yield from self.root.walk()

def visit(self, visitor: Callable[[Node], Visit]):
"""
Visit each node in the tree via a preorder traversal, using the
supplied visitor.
Raises
------
TypeError
If `visitor` is not callable.
"""
self.root.visit(visitor)

def associate_file(self, filename):
self.root.filename = filename

Expand Down
123 changes: 89 additions & 34 deletions tests/source-tree/test_source_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import warnings

from codebasin.file_parser import FileParser
from codebasin.preprocessor import CodeNode, DirectiveNode, FileNode
from codebasin.preprocessor import CodeNode, DirectiveNode, FileNode, Visit


class TestSourceTree(unittest.TestCase):
Expand All @@ -19,9 +19,6 @@ def setUp(self):
logging.getLogger("codebasin").disabled = False
warnings.simplefilter("ignore", ResourceWarning)

def test_walk(self):
"""Check that walk() visits nodes in the expected order"""

# TODO: Revisit this when SourceTree can be built without a file.
with tempfile.NamedTemporaryFile(
mode="w",
Expand All @@ -43,36 +40,94 @@ def test_walk(self):
f.close()

# TODO: Revisit this when __str__() is more reliable.
tree = FileParser(f.name).parse_file(summarize_only=False)
expected_types = [
FileNode,
DirectiveNode,
CodeNode,
DirectiveNode,
CodeNode,
DirectiveNode,
CodeNode,
DirectiveNode,
CodeNode,
]
expected_contents = [
f.name,
"FOO",
"foo",
"BAR",
"bar",
"else",
"baz",
"endif",
"qux",
]
for i, node in enumerate(tree.walk()):
self.assertTrue(isinstance(node, expected_types[i]))
if isinstance(node, CodeNode):
contents = node.spelling()[0]
else:
contents = str(node)
self.assertTrue(expected_contents[i] in contents)
self.tree = FileParser(f.name).parse_file(summarize_only=False)
self.filename = f.name

def test_walk(self):
"""Check that walk() visits nodes in the expected order"""
expected_types = [
FileNode,
DirectiveNode,
CodeNode,
DirectiveNode,
CodeNode,
DirectiveNode,
CodeNode,
DirectiveNode,
CodeNode,
]
expected_contents = [
self.filename,
"FOO",
"foo",
"BAR",
"bar",
"else",
"baz",
"endif",
"qux",
]
for i, node in enumerate(self.tree.walk()):
self.assertTrue(isinstance(node, expected_types[i]))
if isinstance(node, CodeNode):
contents = node.spelling()[0]
else:
contents = str(node)
self.assertTrue(expected_contents[i] in contents)

def test_visit_types(self):
"""Check that visit() validates inputs"""

class valid_visitor:
def __call__(self, node):
return True

self.tree.visit(valid_visitor())

def visitor_function(node):
return True

self.tree.visit(visitor_function)

with self.assertRaises(TypeError):
self.tree.visit(1)

class invalid_visitor:
pass

with self.assertRaises(TypeError):
self.tree.visit(invalid_visitor())

def test_visit(self):
"""Check that visit() visits nodes as expected"""

# Check that a trivial visitor visits all nodes.
class NodeCounter:
def __init__(self):
self.count = 0

def __call__(self, node):
self.count += 1

node_counter = NodeCounter()
self.tree.visit(node_counter)
self.assertEqual(node_counter.count, 9)

# Check that returning NEXT_SIBLING prevents descent.
class TopLevelCounter:
def __init__(self):
self.count = 0

def __call__(self, node):
if not isinstance(node, FileNode):
self.count += 1
if isinstance(node, DirectiveNode):
return Visit.NEXT_SIBLING
return Visit.NEXT

top_level_counter = TopLevelCounter()
self.tree.visit(top_level_counter)
self.assertEqual(top_level_counter.count, 5)


if __name__ == "__main__":
Expand Down

0 comments on commit 491bff7

Please sign in to comment.