Skip to content

Commit

Permalink
Refactored cell_type_dict initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
ubyndr committed Aug 2, 2023
1 parent 6f8d4fa commit 625b237
Showing 1 changed file with 44 additions and 35 deletions.
79 changes: 44 additions & 35 deletions pandasaurus_cxg/graph_generator/graph_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,16 @@ def __init__(
raise MissingEnrichmentProcess(enrichment_methods)
else:
self.enriched_df = enricher.enriched_df
self.cell_type_dict = (
enricher.get_anndata()
.obs[["cell_type_ontology_term_id", "cell_type"]]
self.cell_type_dict = {
**enricher.enriched_df[["s", "s_label"]]
.drop_duplicates()
.set_index("cell_type_ontology_term_id")["cell_type"]
.to_dict()
)
.set_index("s")["s_label"]
.to_dict(),
**enricher.enriched_df[["o", "o_label"]]
.drop_duplicates()
.set_index("o")["o_label"]
.to_dict(),
}
self.ns = Namespace("http://example.org/")
self.graph = Graph()

Expand Down Expand Up @@ -143,7 +146,7 @@ def enrich_rdf_graph(self):

def save_rdf_graph(
self,
graph: Optional[Graph],
graph: Optional[Graph] = None,
file_name: Optional[str] = "mygraph",
_format: Optional[str] = "xml",
):
Expand Down Expand Up @@ -178,7 +181,7 @@ def save_rdf_graph(

def visualize_rdf_graph(self, start_node: List[str], predicate: str, file_path: str):
# TODO visualize all graph, with parametric annotation properties to better visualize the nodes.
# TODO apply redundancy striping to owl directly
# TODO better handle format parameter
graph = Graph().parse(file_path, format="ttl") if file_path else self.graph
if predicate and not graph.query(f"ASK {{ ?s {self.ns[predicate].n3()} ?o }}"):
raise ValueError(f"The {self.ns[predicate]} relation does not exist in the graph")
Expand All @@ -204,19 +207,21 @@ def visualize_rdf_graph(self, start_node: List[str], predicate: str, file_path:
):
if not isinstance(next_node, BNode):
stack.append(next_node)
# else:
# subgraph.add(
# (
# node,
# next(graph.objects(next_node, OWL.onProperty)),
# next(graph.objects(next_node, OWL.someValuesFrom)),
# )
# )
else:
_p = next(graph.objects(next_node, OWL.onProperty))
_o = next(graph.objects(next_node, OWL.someValuesFrom))
subgraph.add(
(
node,
_p,
_o,
)
)
stack.append(_o)
# TODO not sure if we need this else or not

if not start_node:
for s, p, o in graph.triples((None, self.ns[predicate] if predicate else None, None)):
# Add all outgoing edges of the current node
subgraph.add((s, p, o))

nx_graph = nx.DiGraph()
Expand All @@ -228,17 +233,18 @@ def visualize_rdf_graph(self, start_node: List[str], predicate: str, file_path:
str(obj).split("/")[-1],
**edge_data,
)
elif predicate != RDF.type:
nx_graph.add_node(str(subject).split("/")[-1], label=str(obj))
# TODO not sure if we need this else or not, related with previous else

# Apply transitive reduction to remove redundancy
transitive_reduction_graph = nx.transitive_reduction(nx_graph)
transitive_reduction_graph.add_nodes_from(nx_graph.nodes(data=True))
transitive_reduction_graph.add_edges_from(
(u, v, nx_graph.edges[u, v]) for u, v in transitive_reduction_graph.edges
)

# Layout the graph as a hierarchical tree
pos = nx.drawing.nx_agraph.graphviz_layout(transitive_reduction_graph, prog="dot")

# Plot the graph as a hierarchical tree
node_labels = nx.get_node_attributes(transitive_reduction_graph, "label")
plt.figure(figsize=(10, 10))
Expand All @@ -265,24 +271,31 @@ def visualize_rdf_graph(self, start_node: List[str], predicate: str, file_path:
plt.show()

def transitive_reduction(self, predicate: str, file_path: str, _format: str = "xml"):
# TODO better handle format parameter
graph = Graph().parse(file_path, format="ttl") if file_path else self.graph
if predicate and not graph.query(f"ASK {{ ?s {self.ns[predicate].n3()} ?o }}"):
# SPARQL query to list all object properties in the RDF graph
query = """
SELECT DISTINCT ?property
WHERE {
?subject ?property ?object .
FILTER (isIRI(?object))
}
"""
predicate_list = [str(r["property"]) for r in graph.query(query)]
if predicate and not graph.query(f"ASK {{ ?s <{predicate}> ?o }}"):
raise ValueError(f"The {self.ns[predicate]} relation does not exist in the graph")

subgraph = Graph()
for s, p, o in graph.triples((None, self.ns[predicate] if predicate else None, None)):
# Add all outgoing edges of the current node
for s, p, o in graph.triples((None, URIRef(predicate) if predicate else None, None)):
subgraph.add((s, p, o))

nx_graph = nx.DiGraph()
for subject, _predicate, obj in subgraph:
if isinstance(obj, URIRef) and _predicate != RDF.type:
edge_data = {
"label": "is_a" if _predicate == RDF.type else str(predicate).split("/")[-1]
}
for s, p, o in subgraph:
if isinstance(o, URIRef) and p != RDF.type:
edge_data = {"label": str(predicate).split("/")[-1]}
nx_graph.add_edge(
str(subject).split("/")[-1],
str(obj).split("/")[-1],
s,
o,
**edge_data,
)

Expand All @@ -291,15 +304,11 @@ def transitive_reduction(self, predicate: str, file_path: str, _format: str = "x
transitive_reduction_graph.add_edges_from(
(u, v, nx_graph.edges[u, v]) for u, v in transitive_reduction_graph.edges
)

# Remove redundant triples using nx graph
edge_diff = list(set(nx_graph.edges) - set(transitive_reduction_graph.edges))

for edge in edge_diff:
if graph.query(
f"ASK {{ {self.ns[edge[0]].n3()} {self.ns[predicate].n3()} {self.ns[edge[1]].n3()} }}"
):
graph.remove((self.ns[edge[0]], self.ns[predicate], self.ns[edge[1]]))
if graph.query(f"ASK {{ <{edge[0]}> <{predicate}> <{edge[1]}> }}"):
graph.remove((edge[0], URIRef(predicate), edge[1]))

self.save_rdf_graph(graph, f"{file_path.split('.')[0]}_non_redundant", _format)

Expand Down

0 comments on commit 625b237

Please sign in to comment.