Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement edge distance on OntologyGraph. #51

Merged
merged 1 commit into from
Dec 15, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions src/hpotk/graph/_api.py
Original file line number Diff line number Diff line change
@@ -142,6 +142,26 @@ def is_descendant_of(self, sub: typing.Union[str, NODE, Identified],
"""
return self._run_query(self.get_descendants, sub, obj)

def compute_edge_distance(self, left: typing.Union[str, NODE, Identified],
right: typing.Union[str, NODE, Identified]) -> int:
"""
Calculate the edge distance as the number of edges in the shortest path between the graph nodes.

Distance of a node to itself is `0`.

:param left: a graph node.
:param right: other graph node.
:return: a non-negative `int` of the edge distance.
"""
left = OntologyGraph._map_to_term_id(left)
right = OntologyGraph._map_to_term_id(right)
if left == right:
return 0 # Distance to self is `0`.

left_dist = get_ancestor_distances(self, left)
right_dist = get_ancestor_distances(self, right)
return find_minimum_distance(left_dist, right_dist)

@staticmethod
def _run_query(func: typing.Callable[[NODE], typing.Iterator[NODE]],
sub: typing.Union[str, NODE, Identified],
@@ -170,6 +190,42 @@ def _map_to_term_id(item: typing.Union[str, NODE, Identified]) -> TermId:
raise ValueError(f'Expected `str`, `TermId` or `Identified` but got `{type(item)}`')


def find_minimum_distance(left_dist: typing.Mapping[TermId, int],
right_dist: typing.Mapping[TermId, int]) -> int:
dist = None
for shared in left_dist.keys() & right_dist.keys():
current = left_dist[shared] + right_dist[shared]
if dist is None:
dist = current
else:
dist = min(dist, current)

return dist


def get_ancestor_distances(graph, src: TermId) -> typing.Mapping[TermId, int]:
distances = {}
seen = set()

stack = [(0, src)]
while stack:
distance, term_id = stack.pop()

current = distance + 1
for parent in graph.get_parents(term_id):
if parent not in seen:
stack.append((current, parent))
seen.add(parent)

if term_id in distances:
# We must keep the shortest distance
distances[term_id] = min(distances[term_id], distance)
else:
distances[term_id] = distance

return distances


class GraphAware(typing.Generic[NODE], metaclass=abc.ABCMeta):
"""
A mixin class for entities that have an :class:`OntologyGraph`.
30 changes: 30 additions & 0 deletions src/hpotk/graph/_test__api.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import pytest
import typing
import unittest

import ddt
import numpy as np

from hpotk.model import TermId, Identified
from ._api import OntologyGraph
from ._csr_graph import BisectPoweredCsrOntologyGraph
from ._test_data import get_toy_graph
from .csr import ImmutableCsrMatrix


@@ -201,6 +204,33 @@ def test_traversal_methods_produce_iterators(self):
self.assertIsInstance(self.GRAPH.get_descendants(whatever), typing.Iterator)


class TestNeo:

@pytest.fixture
def toy_og(self) -> OntologyGraph:
_, og = get_toy_graph()
return og

@pytest.mark.parametrize('left, right, expected',
[
('HP:1', 'HP:1', 0),
('HP:01', 'HP:01', 0),
('HP:010', 'HP:0110', 1),
('HP:011', 'HP:0110', 1),
('HP:01', 'HP:0110', 2),
('HP:1', 'HP:0110', 3),
('HP:03', 'HP:01', 2),
('HP:03', 'HP:02', 2),
('HP:0110', 'HP:022', 5),
])
def test_compute_edge_distance(self, left: str, right: str, expected: int, toy_og: OntologyGraph):
left = TermId.from_curie(left)
right = TermId.from_curie(right)

assert toy_og.compute_edge_distance(left, right) == expected
assert toy_og.compute_edge_distance(right, left) == expected


class SimpleIdentified(Identified):

@staticmethod
26 changes: 26 additions & 0 deletions src/hpotk/graph/_test_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import typing

import numpy as np

from hpotk.model import TermId

from ._api import OntologyGraph

from ._csr_graph import BisectPoweredCsrOntologyGraph
from .csr import ImmutableCsrMatrix


def get_toy_graph() -> typing.Tuple[typing.Sequence[TermId], OntologyGraph]:
root = TermId.from_curie('HP:1')
curies = [
'HP:01', 'HP:010', 'HP:011', 'HP:0110',
'HP:02', 'HP:020', 'HP:021', 'HP:022',
'HP:03', 'HP:1'
]
nodes = np.fromiter(map(TermId.from_curie, curies), dtype=object)
row = [0, 3, 5, 7, 9, 13, 14, 15, 16, 17, 20]
col = [1, 2, 9, 0, 3, 0, 3, 1, 2, 5, 6, 7, 9, 4, 4, 4, 9, 0, 4, 8]
data = [-1, -1, 1, 1, -1, 1, -1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1]
am = ImmutableCsrMatrix(row, col, data, shape=(len(nodes), len(nodes)), dtype=int)

return nodes, BisectPoweredCsrOntologyGraph(root, nodes, am)
Loading