diff --git a/arguebuf/model/graph.py b/arguebuf/model/graph.py index e94b9bb..7c41342 100644 --- a/arguebuf/model/graph.py +++ b/arguebuf/model/graph.py @@ -141,22 +141,65 @@ def outgoing_atom_nodes( child_nodes = incoming_nodes parent_nodes = outgoing_nodes - def sibling_nodes( - self, node: t.Union[str, AbstractNode] - ) -> t.AbstractSet[AbstractNode]: + def sibling_node_distances( + self, + node: t.Union[str, AbstractNode], + max_levels: t.Optional[int] = None, + node_type: t.Type[AbstractNode] = AbstractNode, + ) -> dict[AbstractNode, int]: + """Find all sibling nodes of a node and their distance in the graph""" + if isinstance(node, str): node = self._nodes[node] - parent_nodes = self.parent_nodes(node) - sibling_nodes = set() + # visited: set[AbstractNode] = set() + sibling_nodes: dict[AbstractNode, int] = {} + parent_nodes: dict[AbstractNode, int] = { + parent: 1 for parent in self.parent_nodes(node) + } + + while parent_nodes: + parent, level = parent_nodes.popitem() + child_nodes = self.child_nodes(parent) + grandparent_nodes = self.parent_nodes(parent) + + parent_nodes.update( + {parent: level + 1 for parent in grandparent_nodes} + if (max_levels is None or level < max_levels) + and len(grandparent_nodes) > 0 + else {} + ) - for parent_node in parent_nodes: - sibling_nodes.update(self.child_nodes(parent_node)) + for _ in range(level - 1): + if len(child_nodes) == 0: + break - sibling_nodes.remove(node) + next_child_nodes: set[AbstractNode] = set() + + for child_node in child_nodes: + next_child_nodes.update(self.child_nodes(child_node)) + + child_nodes = next_child_nodes + + sibling_nodes.update( + { + child_node: level + for child_node in child_nodes + if isinstance(child_node, node_type) + and child_node not in sibling_nodes + } + ) return sibling_nodes + def sibling_nodes( + self, + node: t.Union[str, AbstractNode], + max_levels: t.Optional[int] = 1, + node_type: t.Type[AbstractNode] = AbstractNode, + ) -> t.AbstractSet[AbstractNode]: + return self.sibling_node_distances(node, max_levels, node_type).keys() + def incoming_edges(self, node: t.Union[str, AbstractNode]) -> t.AbstractSet[Edge]: if isinstance(node, str): node = self._nodes[node] diff --git a/tests/test_create_graph.py b/tests/test_create_graph.py index 4bb3d05..757021f 100644 --- a/tests/test_create_graph.py +++ b/tests/test_create_graph.py @@ -41,6 +41,39 @@ def test_remove_branch(): assert len(g.edges) == 0 +def test_sibling_nodes(): + g = ag.Graph() + + a1 = ag.AtomNode("", id="a1") + a2 = ag.AtomNode("", id="a2") + a3 = ag.AtomNode("", id="a3") + a4 = ag.AtomNode("", id="a4") + a5 = ag.AtomNode("", id="a5") + a6 = ag.AtomNode("", id="a6") + s1 = ag.SchemeNode(id="s1") + s2 = ag.SchemeNode(id="s2") + s3 = ag.SchemeNode(id="s3") + s4 = ag.SchemeNode(id="s4") + s5 = ag.SchemeNode(id="s5") + + g.add_edge(ag.Edge(a1, s1)) + g.add_edge(ag.Edge(a2, s2)) + g.add_edge(ag.Edge(s1, a3)) + g.add_edge(ag.Edge(s2, a3)) + g.add_edge(ag.Edge(a3, s3)) + g.add_edge(ag.Edge(s3, a4)) + g.add_edge(ag.Edge(s4, a4)) + g.add_edge(ag.Edge(a5, s4)) + g.add_edge(ag.Edge(s5, a5)) + g.add_edge(ag.Edge(a6, s5)) + + siblings = g.sibling_node_distances(a2) + + assert len(siblings) == 3 + assert siblings[a1] == 2 + assert siblings[a6] == 4 + + def test_create_graph(tmp_path: Path): g = ag.Graph("Test")