Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(common): add Node.find_below() methods to exclude the root node from filtering #8861

Merged
merged 1 commit into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 41 additions & 25 deletions ibis/common/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from ibis.common.collections import frozendict
from ibis.common.patterns import NoMatch, Pattern
from ibis.common.typing import _ClassInfo
from ibis.util import experimental
from ibis.util import experimental, promote_list

if TYPE_CHECKING:
from typing_extensions import Self
Expand Down Expand Up @@ -340,9 +340,39 @@ def find(
determined by a breadth-first search.

"""
nodes = Graph.from_bfs(self, filter=filter, context=context).nodes()
graph = Graph.from_bfs(self, filter=filter, context=context)
finder = _coerce_finder(finder, context)
return [node for node in nodes if finder(node)]
return [node for node in graph.nodes() if finder(node)]

@experimental
def find_below(
kszucs marked this conversation as resolved.
Show resolved Hide resolved
self,
finder: FinderLike,
filter: Optional[FinderLike] = None,
context: Optional[dict] = None,
) -> list[Node]:
"""Find all nodes below the current node matching a given pattern in the graph.

A variant of find() that only returns nodes below the current node in the graph.

Parameters
----------
finder
A type, tuple of types, a pattern or a callable to match upon.
filter
A type, tuple of types, a pattern or a callable to filter out nodes
from the traversal. The traversal will only visit nodes that match
the given filter and stop otherwise.
context
Optional context to use if `finder` or `filter` is a pattern.

Returns
-------
The list of nodes matching the given pattern.
"""
graph = Graph.from_bfs(self.__children__, filter=filter, context=context)
finder = _coerce_finder(finder, context)
return [node for node in graph.nodes() if finder(node)]

@experimental
def find_topmost(
Expand Down Expand Up @@ -620,10 +650,8 @@ def bfs(root: Node) -> Graph:
"""
# fast path for the default no filter case, according to benchmarks
# this is gives a 10% speedup compared to the filtered version
if not isinstance(root, Node):
raise TypeError("node must be an instance of ibis.common.graph.Node")

queue = deque([root])
nodes = _flatten_collections(promote_list(root))
queue = deque(nodes)
graph = Graph()

while queue:
Expand Down Expand Up @@ -651,15 +679,10 @@ def bfs_while(root: Node, filter: Finder) -> Graph:
A graph constructed from the root node.

"""
if not isinstance(root, Node):
raise TypeError("node must be an instance of ibis.common.graph.Node")

queue = deque()
nodes = _flatten_collections(promote_list(root))
queue = deque(node for node in nodes if filter(node))
graph = Graph()

if filter(root):
queue.append(root)

while queue:
if (node := queue.popleft()) not in graph:
children = tuple(child for child in node.__children__ if filter(child))
Expand All @@ -684,10 +707,8 @@ def dfs(root: Node) -> Graph:
"""
# fast path for the default no filter case, according to benchmarks
# this is gives a 10% speedup compared to the filtered version
if not isinstance(root, Node):
raise TypeError("node must be an instance of ibis.common.graph.Node")

stack = deque([root])
nodes = _flatten_collections(promote_list(root))
stack = deque(nodes)
graph = {}

while stack:
Expand Down Expand Up @@ -715,15 +736,10 @@ def dfs_while(root: Node, filter: Finder) -> Graph:
A graph constructed from the root node.

"""
if not isinstance(root, Node):
raise TypeError("node must be an instance of ibis.common.graph.Node")

stack = deque()
nodes = _flatten_collections(promote_list(root))
stack = deque(node for node in nodes if filter(node))
graph = {}

if filter(root):
stack.append(root)

while stack:
if (node := stack.pop()) not in graph:
children = tuple(child for child in node.__children__ if filter(child))
Expand Down
24 changes: 14 additions & 10 deletions ibis/common/tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,8 @@ def copy(self, name=None, children=None):

def test_bfs():
assert list(bfs(A).keys()) == [A, B, C, D, E]

with pytest.raises(
TypeError, match="must be an instance of ibis.common.graph.Node"
):
bfs(1)
assert list(bfs([D, E, B])) == [D, E, B]
assert bfs(1) == {}


def test_construction():
Expand All @@ -82,11 +79,8 @@ def test_graph_repr():

def test_dfs():
assert list(dfs(A).keys()) == [D, E, B, C, A]

with pytest.raises(
TypeError, match="must be an instance of ibis.common.graph.Node"
):
dfs(1)
assert list(dfs([D, E, B])) == [D, E, B]
assert dfs(1) == {}


def test_invert():
Expand Down Expand Up @@ -393,6 +387,16 @@ def test_node_find_using_pattern():
assert result == [A, B]


def test_node_find_below():
lowercase = MyNode(name="lowercase", children=[])
root = MyNode(name="root", children=[A, B, lowercase])
result = root.find_below(MyNode)
assert result == [A, B, lowercase, C, D, E]

result = root.find_below(lambda x: x.name.islower(), filter=lambda x: x != root)
assert result == [lowercase]


def test_node_find_topmost_using_type():
class FooNode(MyNode):
pass
Expand Down
Loading