From 786c68191e007f5fd23dd2de25b6e1af6775e65c Mon Sep 17 00:00:00 2001
From: Ismail Ugur Bayindir
Date: Tue, 1 Aug 2023 08:40:03 +0100
Subject: [PATCH] Added transitive_reduction method (#29)

* Merged from main

* Updated anndata_analyzer.py

* Removed state and state.l2 from free-text annotations

* Refactored visualize_rdf_graph method

* Refactored save_rdf_graph and visualize_rdf_graph methods and added transitive_reduction method

* Format changes in co_annotation_report

* Added state and state.l2 to free-text annotations
---
 .gitignore                                     |   3 +-
 pandasaurus_cxg/anndata_analyzer.py            | 150 ++++++++++--------
 .../graph_generator/graph_generator.py         | 133 ++++++++++++----
 3 files changed, 188 insertions(+), 98 deletions(-)

diff --git a/.gitignore b/.gitignore
index 760d150..51c4f3a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -133,8 +133,7 @@ dmypy.json
 *.ttl
 *.nt
 
-# pycharm/mac1
+# pycharm/mac
 .DS_Store
 .idea/
 .DS_Store
-.DS_Store
\ No newline at end of file
diff --git a/pandasaurus_cxg/anndata_analyzer.py b/pandasaurus_cxg/anndata_analyzer.py
index 03cfe4d..e419ae2 100644
--- a/pandasaurus_cxg/anndata_analyzer.py
+++ b/pandasaurus_cxg/anndata_analyzer.py
@@ -1,9 +1,11 @@
+import itertools
 import os
 from enum import Enum
-from typing import List
+from typing import List, Optional
 
 import pandas as pd
 
+from pandasaurus_cxg.anndata_enricher import AnndataEnricher
 from pandasaurus_cxg.anndata_loader import AnndataLoader
 from pandasaurus_cxg.schema.schema_loader import read_json_file
 
@@ -20,7 +22,7 @@ class AnndataAnalyzer:
         schema_path (str): The path to the schema file.
 
     Attributes:
-        _anndata_obs (pd.DataFrame): The observation data from the AnnData object.
+        _anndata (AnnData): The AnnData object loaded from the given file.
         _schema (dict): The schema data loaded from the schema file.
 
     """
@@ -34,17 +36,21 @@ def __init__(self, file_path: str, schema_path: str):
             schema_path (str): The path to the schema file.
 
         """
-        self._anndata_obs = AnndataLoader.load_from_file(file_path).obs
+        self.file_path = file_path
+        self._anndata = AnndataLoader.load_from_file(file_path)
         self._schema = read_json_file(schema_path)
 
-    def co_annotation_report(self):
+    def co_annotation_report(self, disease: Optional[str] = None, enrich: bool = False):
         """
         Generates a co-annotation report based on the provided schema.
 
-        Examples:
-            | subclass.l3, dPT, cluster_matches, subclass.full, Degenerative Proximal Tubule Epithelial Cell
-            | subclass.l3, aTAL1, subcluster_of, subclass.full, Adaptive / Maladaptive / Repairing Thick Ascending Limb Cell
-            | class, epithelial cells, cluster_matches, cell_type, kidney collecting duct intercalated cell
+        Args:
+            disease (Optional[str]): A valid disease CURIE used to filter the rows based on the
+                given disease. If provided, only the rows matching the specified disease will be
+                included in the filtering process. Defaults to None if no disease filtering is
+                desired.
+            enrich (bool): Flag to enable or disable enrichment in the co-annotation report.
+                Defaults to False.
 
         Returns:
             pd.DataFrame: The co-annotation report.
@@ -56,86 +62,106 @@ def co_annotation_report(self): for field_name_1 in free_text_cell_type: if ( field_name_1 != field_name_2 - and field_name_1 in self._anndata_obs.columns - and field_name_2 in self._anndata_obs.columns + and field_name_1 in self._anndata.obs.columns + and field_name_2 in self._anndata.obs.columns ): - co_oc = ( - self._anndata_obs[[field_name_1, field_name_2]] - .drop_duplicates() - .reset_index(drop=True) - ) - field_name_2_dict = ( - co_oc.groupby(field_name_2)[field_name_1].apply(list).to_dict() - ) - field_name_1_dict = ( - co_oc.groupby(field_name_1)[field_name_2].apply(list).to_dict() - ) - co_oc["predicate"] = co_oc.apply( - self._assign_predicate, - args=( - field_name_1, - field_name_2, - field_name_1_dict, - field_name_2_dict, - debug_mode, - ), - axis=1, + co_oc = self._filter_data_and_drop_duplicates( + field_name_1, field_name_2, disease ) + if enrich: + co_oc = self._enrich_co_annotation(co_oc, field_name_1, field_name_2) + + AnndataAnalyzer._assign_predicate_column(co_oc, field_name_1, field_name_2) temp_result.extend(co_oc.to_dict(orient="records")) result = [ [item for sublist in [[k, v] for k, v in record.items()] for item in sublist] for record in temp_result ] - unique_result = self._remove_duplicates(result) + unique_result = AnndataAnalyzer._remove_duplicates(result) return pd.DataFrame( [inner_list[:2] + inner_list[5:6] + inner_list[2:4] for inner_list in unique_result], columns=["field_name1", "value1", "predicate", "field_name2", "value2"], ) + def enriched_co_annotation_report(self, disease: Optional[str] = None): + """ + Generates an enriched co-annotation report based on the provided schema. The enrichment + process will be performed by checking if any of the CL terms in the initial seed + (the set of CL terms used to initialize the Pandasaurus object) are also present in the + object column of the enrichment table. If a match is found, the co-annotation analysis + will be repeated, including everything that maps to this term, either directly or via + the enrichment table. + + Args: + disease (Optional[str]): A valid disease CURIE used to filter the rows based on the + given disease. If provided, only the rows matching the specified disease will be + included in the filtering process. Defaults to None if no disease filtering is + desired. + + Returns: + pd.DataFrame: The co-annotation report. 
+ + """ + return self.co_annotation_report(self, disease, True) + + def _enrich_co_annotation(self, co_oc, field_name_1, field_name_2): + enricher = AnndataEnricher(self._anndata) + simple = enricher.simple_enrichment() + df = simple[simple["o"].isin(enricher.get_seed_list())][["s_label", "o_label"]].rename( + columns={"s_label": field_name_1, "o_label": field_name_2} + ) + co_oc = pd.concat([co_oc, df], axis=0).reset_index(drop=True) + return co_oc + + def _filter_data_and_drop_duplicates(self, field_name_1, field_name_2, disease): + # Filter the data based on the disease condition + co_oc = ( + self._anndata.obs[ + (self._anndata.obs["disease_ontology_term_id"].str.lower() == disease.lower()) + ][[field_name_1, field_name_2]] + if disease + else self._anndata.obs[[field_name_1, field_name_2]] + ) + # Drop duplicates + co_oc = co_oc.drop_duplicates().reset_index(drop=True) + return co_oc + @staticmethod def _remove_duplicates(data: List[List[str]]): - # TODO do a clean up if it is necessary + # TODO do a clean up/rename if it is necessary + # Currently used only to clean up supercluster_of relations unique_data = [] - unique_set = set() for sublist in data: if Predicate.SUPERCLUSTER_OF.value in sublist: continue unique_data.append(sublist) - # sorted_sublist = tuple(sorted(set(sublist))) - # if sorted_sublist not in unique_set: - # unique_data.append(sublist) - # unique_set.add(sorted_sublist) return unique_data @staticmethod - def _assign_predicate( - row: dict, - field_name_1: str, - field_name_2: str, - field_name_1_dict: dict, - field_name_2_dict: dict, - debug: bool, - ) -> str: - """ - Assigns a predicate based on the values of two fields in a row and dictionaries of field values. - - Used to determine if a cluster matches with the other cluster, or if it is its subclass, or if they overlap with each other. - - Args: - row (dict): The row containing the fields. - field_name_1 (str): The name of the first field. - field_name_2 (str): The name of the second field. - field_name_1_dict (dict): A dictionary mapping field_name_1 values to associated values. - field_name_2_dict (dict): A dictionary mapping field_name_2 values to associated values. - debug (bool): Whether to print debugging information. - - Returns: - str: The assigned predicate. 
+    def _assign_predicate_column(co_oc, field_name_1, field_name_2):
+        # Group by field_name_2 and field_name_1 to create dictionaries
+        field_name_2_dict = co_oc.groupby(field_name_2)[field_name_1].apply(list).to_dict()
+        field_name_1_dict = co_oc.groupby(field_name_1)[field_name_2].apply(list).to_dict()
+        # Assign the "predicate" column using the AnndataAnalyzer._assign_predicate method
+        co_oc["predicate"] = co_oc.apply(
+            AnndataAnalyzer._assign_predicate,
+            args=(
+                field_name_1,
+                field_name_2,
+                field_name_1_dict,
+                field_name_2_dict,
+                debug_mode,
+            ),
+            axis=1,
+        )
 
-        """
+    @staticmethod
+    def _assign_predicate(
+        row, field_name_1, field_name_2, field_name_1_dict, field_name_2_dict, debug
+    ):
         if debug:
             print("Debugging row:", row)
             print("Value of field_name_1:", row[field_name_1])
diff --git a/pandasaurus_cxg/graph_generator/graph_generator.py b/pandasaurus_cxg/graph_generator/graph_generator.py
index 9656086..8743f36 100644
--- a/pandasaurus_cxg/graph_generator/graph_generator.py
+++ b/pandasaurus_cxg/graph_generator/graph_generator.py
@@ -142,19 +142,28 @@ def enrich_rdf_graph(self):
             for o, _, _ in self.graph.triples((None, RDFS.label, Literal(row["o_label"]))):
                 self.graph.add((s, RDFS.subClassOf, o))
 
-    def save_rdf_graph(self, file_name: str = "mygraph", _format: str = "xml"):
+    def save_rdf_graph(
+        self,
+        graph: Optional[Graph] = None,
+        file_name: str = "mygraph",
+        _format: str = "xml",
+    ):
         """
         Serializes and saves the RDF graph to a file.
 
         Args:
-            file_name (str, optional): The name of the output file without the extension.
+            graph: An optional RDF graph that will be serialized.
+                If provided, this graph will be used for serialization.
+                If not provided, the graph inside the GraphGenerator instance will be used.
+            file_name: The name of the output file without the extension.
                 Defaults to "mygraph".
-            _format (str, optional): The format of the RDF serialization. Defaults to "xml".
+            _format: The format of the RDF serialization. Defaults to "xml".
 
         Raises:
             InvalidGraphFormat: If the provided _format is not valid.
 
         """
+        graph = graph if graph else self.graph
         format_extension = {
             RDFFormat.RDF_XML.value: "owl",
             RDFFormat.TURTLE.value: "ttl",
@@ -163,50 +172,65 @@ def save_rdf_graph(self, file_name: str = "mygraph", _format: str = "xml"):
 
         if _format in format_extension:
             file_extension = format_extension[_format]
-            self.graph.serialize(f"{file_name}.{file_extension}", format=_format)
+            graph.serialize(f"{file_name}.{file_extension}", format=_format)
         else:
             valid_formats = [valid_format.value for valid_format in RDFFormat]
             raise InvalidGraphFormat(RDFFormat, valid_formats)
 
-    def visualize_rdf_graph(self, start_node: List[str], file_path: str):
+    def visualize_rdf_graph(self, start_node: List[str], predicate: str, file_path: str):
+        # TODO visualize the whole graph, with parametric annotation properties to better visualize the nodes
+        # TODO apply redundancy stripping to owl directly
         graph = Graph().parse(file_path, format="ttl") if file_path else self.graph
-        for node in start_node:
-            if not URIRef(node) in graph.subjects() or URIRef(node) in graph.objects():
-                raise ValueError(f"None of the nodes in the list {node} exist in the RDF graph.")
-
+        if predicate and not graph.query(f"ASK {{ ?s {self.ns[predicate].n3()} ?o }}"):
+            raise ValueError(f"The {self.ns[predicate]} relation does not exist in the graph")
+        if start_node:
+            for node in start_node:
+                if URIRef(node) not in graph.subjects():
+                    raise ValueError(
+                        f"Node {node} does not exist in the RDF graph."
+                    )
         visited = set()
-        stack = [URIRef(node) for node in start_node]
         subgraph = Graph()
+        stack = [URIRef(node) for node in start_node] if start_node else []
         while stack:
             node = stack.pop()
             if node not in visited:
                 visited.add(node)
-                for subject, predicate, obj in graph.triples((node, None, None)):
-                    if predicate != RDF.type:
-                        subgraph.add(
-                            (subject, predicate, obj)
-                        )  # Add all outgoing edges of the current node
-                for s, p, next_node in graph.triples((node, None, None)):
-                    # stack.append(next_node)
-                    if not isinstance(next_node, BNode):
-                        stack.append(next_node)
-                    else:
-                        subgraph.add(
-                            (
-                                node,
-                                next(graph.objects(next_node, OWL.onProperty)),
-                                next(graph.objects(next_node, OWL.someValuesFrom)),
-                            )
-                        )
+                for s, p, o in graph.triples((node, self.ns[predicate] if predicate else None, None)):
+                    # Add all outgoing edges of the current node
+                    subgraph.add((s, p, o))
+                for s, p, next_node in graph.triples(
+                    (node, self.ns[predicate] if predicate else None, None)
+                ):
+                    if not isinstance(next_node, BNode):
+                        stack.append(next_node)
+                    # else:
+                    #     subgraph.add(
+                    #         (
+                    #             node,
+                    #             next(graph.objects(next_node, OWL.onProperty)),
+                    #             next(graph.objects(next_node, OWL.someValuesFrom)),
+                    #         )
+                    #     )
+                    # TODO not sure if we need this else or not
+
+        if not start_node:
+            for s, p, o in graph.triples((None, self.ns[predicate] if predicate else None, None)):
+                # Add all triples for the given predicate when no start node is provided
+                subgraph.add((s, p, o))
 
         nx_graph = nx.DiGraph()
         for subject, predicate, obj in subgraph:
-            if isinstance(obj, URIRef):
-                edge_data = {
-                    "label": "is_a" if predicate == RDF.type else str(predicate).split("/")[-1]
-                }
-                nx_graph.add_edge(str(subject).split("/")[-1], str(obj).split("/")[-1], **edge_data)
+            if isinstance(obj, URIRef) and predicate != RDF.type:
+                edge_data = {"label": str(predicate).split("/")[-1]}
+                nx_graph.add_edge(
+                    str(subject).split("/")[-1],
+                    str(obj).split("/")[-1],
+                    **edge_data,
+                )
+            elif predicate != RDF.type:
+                nx_graph.add_node(str(subject).split("/")[-1], label=str(obj))
 
         # Apply transitive reduction to remove redundancy
         transitive_reduction_graph = nx.transitive_reduction(nx_graph)
@@ -219,12 +243,14 @@ def visualize_rdf_graph(self, start_node: List[str], file_path: str):
         pos = nx.drawing.nx_agraph.graphviz_layout(transitive_reduction_graph, prog="dot")
 
         # Plot the graph as a hierarchical tree
-        plt.figure(figsize=(10, 8))
+        node_labels = nx.get_node_attributes(transitive_reduction_graph, "label")
+        plt.figure(figsize=(10, 10))
         nx.draw(
             transitive_reduction_graph,
             pos,
             with_labels=True,
-            node_size=1500,
+            labels=node_labels,
+            node_size=1000,
             node_color="skyblue",
             font_size=8,
             font_weight="bold",
@@ -241,6 +267,45 @@ def visualize_rdf_graph(self, start_node: List[str], file_path: str):
         )
         plt.show()
 
+    def transitive_reduction(self, predicate: str, file_path: str, _format: str = "xml"):
+        graph = Graph().parse(file_path, format="ttl") if file_path else self.graph
+        if predicate and not graph.query(f"ASK {{ ?s {self.ns[predicate].n3()} ?o }}"):
+            raise ValueError(f"The {self.ns[predicate]} relation does not exist in the graph")
+
+        subgraph = Graph()
+        for s, p, o in graph.triples((None, self.ns[predicate] if predicate else None, None)):
+            # Add all triples for the given predicate
+            subgraph.add((s, p, o))
+
+        nx_graph = nx.DiGraph()
+        for subject, _predicate, obj in subgraph:
+            if isinstance(obj, URIRef) and _predicate != RDF.type:
+                edge_data = {
+                    "label": str(_predicate).split("/")[-1]
+                }
+                nx_graph.add_edge(
+                    str(subject).split("/")[-1],
+                    str(obj).split("/")[-1],
+                    **edge_data,
+                )
+
+        # Apply transitive reduction to remove redundancy
+        transitive_reduction_graph = nx.transitive_reduction(nx_graph)
+        transitive_reduction_graph.add_edges_from(
+            (u, v, nx_graph.edges[u, v]) for u, v in transitive_reduction_graph.edges
+        )
+
+        # Remove redundant triples using nx graph
+        edge_diff = list(set(nx_graph.edges) - set(transitive_reduction_graph.edges))
+
+        for edge in edge_diff:
+            if graph.query(
+                f"ASK {{ {self.ns[edge[0]].n3()} {self.ns[predicate].n3()} {self.ns[edge[1]].n3()} }}"
+            ):
+                graph.remove((self.ns[edge[0]], self.ns[predicate], self.ns[edge[1]]))
+
+        self.save_rdf_graph(graph, f"{file_path.split('.')[0]}_non_redundant" if file_path else "mygraph_non_redundant", _format)
+
 
 class RDFFormat(Enum):
     RDF_XML = "xml"
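
Example usage (a minimal sketch for reviewers; the file names, the disease CURIE,
the "subcluster_of" predicate, and the GraphGenerator construction below are
illustrative assumptions, not part of this patch):

    from pandasaurus_cxg.anndata_analyzer import AnndataAnalyzer
    from pandasaurus_cxg.graph_generator.graph_generator import GraphGenerator

    # Hypothetical inputs: any CELLxGENE-schema .h5ad plus the annotation schema JSON.
    analyzer = AnndataAnalyzer("human_kidney.h5ad", "schema.json")

    # Plain report, then a report filtered to one disease CURIE and enriched
    # with the CL terms that the enrichment table maps onto the free-text labels.
    report = analyzer.co_annotation_report()
    enriched = analyzer.enriched_co_annotation_report(disease="MONDO:0005300")

    gg = GraphGenerator(analyzer)  # assumed constructor, unchanged by this patch
    gg.save_rdf_graph(file_name="mygraph", _format="ttl")  # graph=None falls back to gg.graph
    gg.visualize_rdf_graph(None, "subcluster_of", "mygraph.ttl")
    gg.transitive_reduction("subcluster_of", "mygraph.ttl", "ttl")
    # Writes mygraph_non_redundant.ttl with redundant subcluster_of edges removed.

Note that nx.transitive_reduction only accepts a directed acyclic graph, so a cyclic
predicate graph will raise a NetworkXError before anything is written.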