From 625b2375744e12eef358a564ac1d3afeae3e07a6 Mon Sep 17 00:00:00 2001 From: Ugur Bayindir Date: Wed, 2 Aug 2023 09:46:24 +0100 Subject: [PATCH] Refactored cell_type_dict initialization --- .../graph_generator/graph_generator.py | 79 +++++++++++-------- 1 file changed, 44 insertions(+), 35 deletions(-) diff --git a/pandasaurus_cxg/graph_generator/graph_generator.py b/pandasaurus_cxg/graph_generator/graph_generator.py index d1c06f2..e83ef1b 100644 --- a/pandasaurus_cxg/graph_generator/graph_generator.py +++ b/pandasaurus_cxg/graph_generator/graph_generator.py @@ -40,13 +40,16 @@ def __init__( raise MissingEnrichmentProcess(enrichment_methods) else: self.enriched_df = enricher.enriched_df - self.cell_type_dict = ( - enricher.get_anndata() - .obs[["cell_type_ontology_term_id", "cell_type"]] + self.cell_type_dict = { + **enricher.enriched_df[["s", "s_label"]] .drop_duplicates() - .set_index("cell_type_ontology_term_id")["cell_type"] - .to_dict() - ) + .set_index("s")["s_label"] + .to_dict(), + **enricher.enriched_df[["o", "o_label"]] + .drop_duplicates() + .set_index("o")["o_label"] + .to_dict(), + } self.ns = Namespace("http://example.org/") self.graph = Graph() @@ -143,7 +146,7 @@ def enrich_rdf_graph(self): def save_rdf_graph( self, - graph: Optional[Graph], + graph: Optional[Graph] = None, file_name: Optional[str] = "mygraph", _format: Optional[str] = "xml", ): @@ -178,7 +181,7 @@ def save_rdf_graph( def visualize_rdf_graph(self, start_node: List[str], predicate: str, file_path: str): # TODO visualize all graph, with parametric annotation properties to better visualize the nodes. - # TODO apply redundancy striping to owl directly + # TODO better handle format parameter graph = Graph().parse(file_path, format="ttl") if file_path else self.graph if predicate and not graph.query(f"ASK {{ ?s {self.ns[predicate].n3()} ?o }}"): raise ValueError(f"The {self.ns[predicate]} relation does not exist in the graph") @@ -204,19 +207,21 @@ def visualize_rdf_graph(self, start_node: List[str], predicate: str, file_path: ): if not isinstance(next_node, BNode): stack.append(next_node) - # else: - # subgraph.add( - # ( - # node, - # next(graph.objects(next_node, OWL.onProperty)), - # next(graph.objects(next_node, OWL.someValuesFrom)), - # ) - # ) + else: + _p = next(graph.objects(next_node, OWL.onProperty)) + _o = next(graph.objects(next_node, OWL.someValuesFrom)) + subgraph.add( + ( + node, + _p, + _o, + ) + ) + stack.append(_o) # TODO not sure if we need this else or not if not start_node: for s, p, o in graph.triples((None, self.ns[predicate] if predicate else None, None)): - # Add all outgoing edges of the current node subgraph.add((s, p, o)) nx_graph = nx.DiGraph() @@ -228,6 +233,9 @@ def visualize_rdf_graph(self, start_node: List[str], predicate: str, file_path: str(obj).split("/")[-1], **edge_data, ) + elif predicate != RDF.type: + nx_graph.add_node(str(subject).split("/")[-1], label=str(obj)) + # TODO not sure if we need this else or not, related with previous else # Apply transitive reduction to remove redundancy transitive_reduction_graph = nx.transitive_reduction(nx_graph) @@ -235,10 +243,8 @@ def visualize_rdf_graph(self, start_node: List[str], predicate: str, file_path: transitive_reduction_graph.add_edges_from( (u, v, nx_graph.edges[u, v]) for u, v in transitive_reduction_graph.edges ) - # Layout the graph as a hierarchical tree pos = nx.drawing.nx_agraph.graphviz_layout(transitive_reduction_graph, prog="dot") - # Plot the graph as a hierarchical tree node_labels = nx.get_node_attributes(transitive_reduction_graph, "label") plt.figure(figsize=(10, 10)) @@ -265,24 +271,31 @@ def visualize_rdf_graph(self, start_node: List[str], predicate: str, file_path: plt.show() def transitive_reduction(self, predicate: str, file_path: str, _format: str = "xml"): + # TODO better handle format parameter graph = Graph().parse(file_path, format="ttl") if file_path else self.graph - if predicate and not graph.query(f"ASK {{ ?s {self.ns[predicate].n3()} ?o }}"): + # SPARQL query to list all object properties in the RDF graph + query = """ + SELECT DISTINCT ?property + WHERE { + ?subject ?property ?object . + FILTER (isIRI(?object)) + } + """ + predicate_list = [str(r["property"]) for r in graph.query(query)] + if predicate and not graph.query(f"ASK {{ ?s <{predicate}> ?o }}"): raise ValueError(f"The {self.ns[predicate]} relation does not exist in the graph") subgraph = Graph() - for s, p, o in graph.triples((None, self.ns[predicate] if predicate else None, None)): - # Add all outgoing edges of the current node + for s, p, o in graph.triples((None, URIRef(predicate) if predicate else None, None)): subgraph.add((s, p, o)) nx_graph = nx.DiGraph() - for subject, _predicate, obj in subgraph: - if isinstance(obj, URIRef) and _predicate != RDF.type: - edge_data = { - "label": "is_a" if _predicate == RDF.type else str(predicate).split("/")[-1] - } + for s, p, o in subgraph: + if isinstance(o, URIRef) and p != RDF.type: + edge_data = {"label": str(predicate).split("/")[-1]} nx_graph.add_edge( - str(subject).split("/")[-1], - str(obj).split("/")[-1], + s, + o, **edge_data, ) @@ -291,15 +304,11 @@ def transitive_reduction(self, predicate: str, file_path: str, _format: str = "x transitive_reduction_graph.add_edges_from( (u, v, nx_graph.edges[u, v]) for u, v in transitive_reduction_graph.edges ) - # Remove redundant triples using nx graph edge_diff = list(set(nx_graph.edges) - set(transitive_reduction_graph.edges)) - for edge in edge_diff: - if graph.query( - f"ASK {{ {self.ns[edge[0]].n3()} {self.ns[predicate].n3()} {self.ns[edge[1]].n3()} }}" - ): - graph.remove((self.ns[edge[0]], self.ns[predicate], self.ns[edge[1]])) + if graph.query(f"ASK {{ <{edge[0]}> <{predicate}> <{edge[1]}> }}"): + graph.remove((edge[0], URIRef(predicate), edge[1])) self.save_rdf_graph(graph, f"{file_path.split('.')[0]}_non_redundant", _format)