Skip to content

Commit

Permalink
feat: add method to determine sibling distances
Browse files Browse the repository at this point in the history
  • Loading branch information
mirkolenz committed Nov 2, 2023
1 parent 4aeddfc commit a3538de
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 8 deletions.
59 changes: 51 additions & 8 deletions arguebuf/model/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,22 +141,65 @@ def outgoing_atom_nodes(
child_nodes = incoming_nodes
parent_nodes = outgoing_nodes

def sibling_nodes(
self, node: t.Union[str, AbstractNode]
) -> t.AbstractSet[AbstractNode]:
def sibling_node_distances(
self,
node: t.Union[str, AbstractNode],
max_levels: t.Optional[int] = None,
node_type: t.Type[AbstractNode] = AbstractNode,
) -> dict[AbstractNode, int]:
"""Find all sibling nodes of a node and their distance in the graph"""

if isinstance(node, str):
node = self._nodes[node]

parent_nodes = self.parent_nodes(node)
sibling_nodes = set()
# visited: set[AbstractNode] = set()
sibling_nodes: dict[AbstractNode, int] = {}
parent_nodes: dict[AbstractNode, int] = {
parent: 1 for parent in self.parent_nodes(node)
}

while parent_nodes:
parent, level = parent_nodes.popitem()
child_nodes = self.child_nodes(parent)
grandparent_nodes = self.parent_nodes(parent)

parent_nodes.update(
{parent: level + 1 for parent in grandparent_nodes}
if (max_levels is None or level < max_levels)
and len(grandparent_nodes) > 0
else {}
)

for parent_node in parent_nodes:
sibling_nodes.update(self.child_nodes(parent_node))
for _ in range(level - 1):
if len(child_nodes) == 0:
break

sibling_nodes.remove(node)
next_child_nodes: set[AbstractNode] = set()

for child_node in child_nodes:
next_child_nodes.update(self.child_nodes(child_node))

child_nodes = next_child_nodes

sibling_nodes.update(
{
child_node: level
for child_node in child_nodes
if isinstance(child_node, node_type)
and child_node not in sibling_nodes
}
)

return sibling_nodes

def sibling_nodes(
self,
node: t.Union[str, AbstractNode],
max_levels: t.Optional[int] = 1,
node_type: t.Type[AbstractNode] = AbstractNode,
) -> t.AbstractSet[AbstractNode]:
return self.sibling_node_distances(node, max_levels, node_type).keys()

def incoming_edges(self, node: t.Union[str, AbstractNode]) -> t.AbstractSet[Edge]:
if isinstance(node, str):
node = self._nodes[node]
Expand Down
33 changes: 33 additions & 0 deletions tests/test_create_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,39 @@ def test_remove_branch():
assert len(g.edges) == 0


def test_sibling_nodes():
g = ag.Graph()

a1 = ag.AtomNode("", id="a1")
a2 = ag.AtomNode("", id="a2")
a3 = ag.AtomNode("", id="a3")
a4 = ag.AtomNode("", id="a4")
a5 = ag.AtomNode("", id="a5")
a6 = ag.AtomNode("", id="a6")
s1 = ag.SchemeNode(id="s1")
s2 = ag.SchemeNode(id="s2")
s3 = ag.SchemeNode(id="s3")
s4 = ag.SchemeNode(id="s4")
s5 = ag.SchemeNode(id="s5")

g.add_edge(ag.Edge(a1, s1))
g.add_edge(ag.Edge(a2, s2))
g.add_edge(ag.Edge(s1, a3))
g.add_edge(ag.Edge(s2, a3))
g.add_edge(ag.Edge(a3, s3))
g.add_edge(ag.Edge(s3, a4))
g.add_edge(ag.Edge(s4, a4))
g.add_edge(ag.Edge(a5, s4))
g.add_edge(ag.Edge(s5, a5))
g.add_edge(ag.Edge(a6, s5))

siblings = g.sibling_node_distances(a2)

assert len(siblings) == 3
assert siblings[a1] == 2
assert siblings[a6] == 4


def test_create_graph(tmp_path: Path):
g = ag.Graph("Test")

Expand Down

0 comments on commit a3538de

Please sign in to comment.