Skip to content

Commit

Permalink
Implement edge distance on OntologyGraph.
Browse files Browse the repository at this point in the history
  • Loading branch information
ielis committed Dec 15, 2023
1 parent d6d421d commit 678519f
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 0 deletions.
56 changes: 56 additions & 0 deletions src/hpotk/graph/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,26 @@ def is_descendant_of(self, sub: typing.Union[str, NODE, Identified],
"""
return self._run_query(self.get_descendants, sub, obj)

def compute_edge_distance(self, left: typing.Union[str, NODE, Identified],
                          right: typing.Union[str, NODE, Identified]) -> int:
    """
    Count the edges on the shortest path that connects two graph nodes.

    Both inputs are first mapped to term IDs. Equal term IDs yield `0` right away. Otherwise,
    the ancestor distances of both nodes are gathered and the smallest sum of distances
    to a term present in both collections is returned.

    :param left: a graph node.
    :param right: other graph node.
    :return: a non-negative `int` of the edge distance.
    """
    left_id = OntologyGraph._map_to_term_id(left)
    right_id = OntologyGraph._map_to_term_id(right)

    # A node is zero edges away from itself, no traversal is needed.
    if left_id == right_id:
        return 0

    return find_minimum_distance(
        get_ancestor_distances(self, left_id),
        get_ancestor_distances(self, right_id),
    )

@staticmethod
def _run_query(func: typing.Callable[[NODE], typing.Iterator[NODE]],
sub: typing.Union[str, NODE, Identified],
Expand Down Expand Up @@ -170,6 +190,42 @@ def _map_to_term_id(item: typing.Union[str, NODE, Identified]) -> TermId:
raise ValueError(f'Expected `str`, `TermId` or `Identified` but got `{type(item)}`')


def find_minimum_distance(left_dist: typing.Mapping[TermId, int],
                          right_dist: typing.Mapping[TermId, int]) -> int:
    """
    Find the smallest combined distance over the terms present in both mappings.

    For each term that is a key of both `left_dist` and `right_dist`, the two distances are summed,
    and the smallest sum is returned.

    :param left_dist: a mapping from a term to its distance from the first node.
    :param right_dist: a mapping from a term to its distance from the second node.
    :return: the smallest sum of distances or `None` if the mappings share no term.
    """
    shared_terms = left_dist.keys() & right_dist.keys()
    return min(
        (left_dist[term] + right_dist[term] for term in shared_terms),
        default=None,  # No shared term -> no distance to report.
    )


def get_ancestor_distances(graph, src: TermId) -> typing.Mapping[TermId, int]:
    """
    Compute the number of edges on the shortest path from `src` to each of its ancestors.

    The graph is traversed breadth-first using `graph.get_parents`. Breadth-first order guarantees
    that a term is first reached along a path with the fewest edges, so the distance recorded
    at the first visit is final. A depth-first traversal does not have this property if a term
    can be reached through several paths of different lengths.

    :param graph: an object whose `get_parents` method provides an iterable with parents of a term.
    :param src: the term where the traversal starts.
    :return: a mapping with `src` (distance `0`) and all its ancestors as keys
      and the corresponding edge distances from `src` as values.
    """
    # Local import keeps the function self-contained.
    from collections import deque

    # `distances` also serves as the set of the terms that have already been reached.
    distances = {src: 0}
    queue = deque((src,))
    while queue:
        term_id = queue.popleft()
        parent_distance = distances[term_id] + 1
        for parent in graph.get_parents(term_id):
            if parent not in distances:
                # The first visit in breadth-first order is along a shortest path.
                distances[parent] = parent_distance
                queue.append(parent)

    return distances


class GraphAware(typing.Generic[NODE], metaclass=abc.ABCMeta):
"""
A mixin class for entities that have an :class:`OntologyGraph`.
Expand Down
30 changes: 30 additions & 0 deletions src/hpotk/graph/_test__api.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import pytest
import typing
import unittest

import ddt
import numpy as np

from hpotk.model import TermId, Identified
from ._api import OntologyGraph
from ._csr_graph import BisectPoweredCsrOntologyGraph
from ._test_data import get_toy_graph
from .csr import ImmutableCsrMatrix


Expand Down Expand Up @@ -201,6 +204,33 @@ def test_traversal_methods_produce_iterators(self):
self.assertIsInstance(self.GRAPH.get_descendants(whatever), typing.Iterator)


class TestNeo:
    """
    Tests of `OntologyGraph.compute_edge_distance` run on the toy graph.
    """

    @pytest.fixture
    def toy_og(self) -> OntologyGraph:
        # `get_toy_graph` returns a pair, only the graph (the 2nd item) is needed here.
        return get_toy_graph()[1]

    @pytest.mark.parametrize(
        ('left', 'right', 'expected'),
        (
            # A term is zero edges away from itself.
            ('HP:1', 'HP:1', 0),
            ('HP:01', 'HP:01', 0),
            # Terms that are one or more edges apart.
            ('HP:010', 'HP:0110', 1),
            ('HP:011', 'HP:0110', 1),
            ('HP:01', 'HP:0110', 2),
            ('HP:1', 'HP:0110', 3),
            ('HP:03', 'HP:01', 2),
            ('HP:03', 'HP:02', 2),
            ('HP:0110', 'HP:022', 5),
        ),
    )
    def test_compute_edge_distance(self, left: str, right: str, expected: int, toy_og: OntologyGraph):
        left_id = TermId.from_curie(left)
        right_id = TermId.from_curie(right)

        # The distance must not depend on the order of the arguments.
        for first, second in ((left_id, right_id), (right_id, left_id)):
            assert toy_og.compute_edge_distance(first, second) == expected


class SimpleIdentified(Identified):

@staticmethod
Expand Down
26 changes: 26 additions & 0 deletions src/hpotk/graph/_test_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import typing

import numpy as np

from hpotk.model import TermId

from ._api import OntologyGraph

from ._csr_graph import BisectPoweredCsrOntologyGraph
from .csr import ImmutableCsrMatrix


def get_toy_graph() -> typing.Tuple[typing.Sequence[TermId], OntologyGraph]:
    """
    Build a small ontology graph with ten terms for testing.

    The adjacency matrix is provided in the CSR format (`row` pointers, `col` indices, `data` values).
    NOTE(review): judging by the row of the root `HP:1`, `-1` presumably marks a child
    and `1` a parent of the row's term - confirm against the CSR graph implementation.

    :return: a tuple with the graph nodes and the ontology graph rooted in `HP:1`.
    """
    root_id = TermId.from_curie('HP:1')
    node_curies = (
        'HP:01', 'HP:010', 'HP:011', 'HP:0110',
        'HP:02', 'HP:020', 'HP:021', 'HP:022',
        'HP:03', 'HP:1',
    )
    nodes = np.fromiter(map(TermId.from_curie, node_curies), dtype=object)

    # The CSR arrays; `row[i]:row[i + 1]` delimits the entries of the i-th node in `col` and `data`.
    row = [0, 3, 5, 7, 9, 13, 14, 15, 16, 17, 20]
    col = [1, 2, 9, 0, 3, 0, 3, 1, 2, 5, 6, 7, 9, 4, 4, 4, 9, 0, 4, 8]
    data = [-1, -1, 1, 1, -1, 1, -1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1]

    n_nodes = len(nodes)
    adjacency = ImmutableCsrMatrix(row, col, data, shape=(n_nodes, n_nodes), dtype=int)

    return nodes, BisectPoweredCsrOntologyGraph(root_id, nodes, adjacency)

0 comments on commit 678519f

Please sign in to comment.