Test case implementation #52

Merged: 10 commits, Sep 29, 2023
5 changes: 5 additions & 0 deletions .coveragerc
@@ -0,0 +1,5 @@
[report]
exclude_lines =
@abstractmethod
raise NotImplementedError
print\(.*\)
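A minimal sketch of the kind of lines these three patterns omit from the coverage report; the class and methods below are illustrative only and are not taken from the codebase:

from abc import ABC, abstractmethod

class Reporter(ABC):
    @abstractmethod                # matches the '@abstractmethod' pattern
    def report(self):
        raise NotImplementedError  # matches the 'raise NotImplementedError' pattern

    def debug(self):
        print("debug output")      # matches the 'print\(.*\)' pattern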
33 changes: 23 additions & 10 deletions pandasaurus_cxg/anndata_analyzer.py
@@ -52,7 +52,7 @@ def __init__(self, anndata: AnnData, author_cell_type_list: Optional[List[str]]
for meta in obs_meta
if meta.get("field_type") == "author_cell_type_label"
] + ["cell_type"]
except KeyError as e:
except KeyError:
if author_cell_type_list:
self.all_cell_type_identifiers = author_cell_type_list + ["cell_type"]
self._anndata.uns["obs_meta"] = json.dumps(
@@ -105,6 +105,13 @@ def co_annotation_report(self, disease: Optional[str] = None, enrich: bool = Fal
pd.DataFrame: The co-annotation report.

"""
# TODO needs a refactoring to decide which enrichment method to use. Or would it be better to accept
# enriched_df as a parameter, so users get to decide?
enriched_co_oc = None
if enrich:
enricher = AnndataEnricher(self._anndata)
enricher.simple_enrichment()
enriched_co_oc = AnndataAnalyzer._enrich_co_annotation(enricher)
temp_result = []
for field_name_2 in self.all_cell_type_identifiers:
for field_name_1 in self.all_cell_type_identifiers:
@@ -118,7 +125,15 @@ def co_annotation_report(self, disease: Optional[str] = None, enrich: bool = Fal
)

if enrich:
co_oc = self._enrich_co_annotation(co_oc, field_name_1, field_name_2)
co_oc = pd.concat(
[
co_oc,
enriched_co_oc.rename(
columns={"s_label": field_name_1, "o_label": field_name_2}
),
],
axis=0,
).reset_index(drop=True)

AnndataAnalyzer._assign_predicate_column(co_oc, field_name_1, field_name_2)
temp_result.extend(co_oc.to_dict(orient="records"))
@@ -155,14 +170,12 @@ def enriched_co_annotation_report(self, disease: Optional[str] = None):
"""
return self.co_annotation_report(disease, True)

def _enrich_co_annotation(self, co_oc, field_name_1, field_name_2):
enricher = AnndataEnricher(self._anndata)
simple = enricher.simple_enrichment()
df = simple[simple["o"].isin(enricher.get_seed_list())][["s_label", "o_label"]].rename(
columns={"s_label": field_name_1, "o_label": field_name_2}
)
co_oc = pd.concat([co_oc, df], axis=0).reset_index(drop=True)
return co_oc
@staticmethod
def _enrich_co_annotation(enricher: AnndataEnricher):
enriched_df = enricher.enricher.enriched_df
if enriched_df.empty:
return pd.DataFrame()
return enriched_df[enriched_df["o"].isin(enricher.seed_list)][["s_label", "o_label"]]

def _filter_data_and_drop_duplicates(self, field_name_1, field_name_2, disease):
# Filter the data based on the disease condition
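With this refactor, enrichment is prepared once before the field-pair loop instead of being rebuilt inside _enrich_co_annotation for every pair. A hedged usage sketch of the public entry point; the dataset path and the assumption that the optional constructor argument defaults to None are not guaranteed by this diff:

import anndata
from pandasaurus_cxg.anndata_analyzer import AnndataAnalyzer

adata = anndata.read_h5ad("human_kidney.h5ad")  # hypothetical file
analyzer = AnndataAnalyzer(adata)

# enrich=True now builds one AnndataEnricher, runs simple_enrichment once, and
# concatenates the enriched s_label/o_label pairs into each field pair's
# co-occurrence table before the predicate column is assigned.
report = analyzer.co_annotation_report(enrich=True)
print(report.head())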
77 changes: 48 additions & 29 deletions pandasaurus_cxg/anndata_enricher.py
@@ -10,6 +10,7 @@
from pandasaurus_cxg.utils.exceptions import (
CellTypeNotFoundError,
InvalidSlimName,
MissingEnrichmentProcess,
SubclassWarning,
)

@@ -44,15 +45,22 @@ def __init__(
if ontology_list_for_slims is None:
ontology_list_for_slims = ["Cell Ontology"]
# TODO Do we need to keep whole anndata? Would it be enough to keep the obs only?
self._anndata = anndata
self._seed_list = self._anndata.obs[cell_type_field].unique().tolist()
self.enricher = Query(self._seed_list)
unique_context = self._anndata.obs[[context_field, context_field_label]].drop_duplicates()
self._context_list = (
None
if context_field not in self._anndata.obs.keys()
else dict(zip(unique_context[context_field], unique_context[context_field_label]))
)
self.anndata = anndata
self.seed_list = self.anndata.obs[cell_type_field].unique().tolist()
self.enricher = Query(self.seed_list)
try:
unique_context = self.anndata.obs[
[context_field, context_field_label]
].drop_duplicates()
self._context_list = (
None
if context_field not in self.anndata.obs.keys()
else dict(zip(unique_context[context_field], unique_context[context_field_label]))
)
except KeyError as e:
raise KeyError(
"Please use a valid 'context_field' and 'context_field_label' that exist in your anndata file."
)
self.slim_list = [
slim
for ontology in ontology_list_for_slims
@@ -132,6 +140,7 @@ def contextual_slim_enrichment(self) -> Optional[pd.DataFrame]:
otherwise None.
"""
# TODO Better handle datasets without tissue field
# TODO self._context_list is refactored and cannot be None in any case. 'else' needs an update
return (
self.enricher.contextual_slim_enrichment(list(self._context_list.keys()))
if self._context_list
@@ -152,7 +161,6 @@ def filter_anndata_with_enriched_cell_type(self, cell_type: str) -> pd.DataFrame
CellTypeNotFoundError: If the provided cell_type is not found in the enriched cell types.

"""
# TODO Add empty dataframe exception
cell_type_dict = self.create_cell_type_dict()
if cell_type not in cell_type_dict:
raise CellTypeNotFoundError([cell_type], cell_type_dict.keys())
Expand All @@ -162,8 +170,8 @@ def filter_anndata_with_enriched_cell_type(self, cell_type: str) -> pd.DataFrame
].tolist()
cell_type_group.append(cell_type)

return self._anndata.obs[
self._anndata.obs["cell_type_ontology_term_id"].isin(cell_type_group)
return self.anndata.obs[
self.anndata.obs["cell_type_ontology_term_id"].isin(cell_type_group)
]

def annotate_anndata_with_cell_type(
Expand Down Expand Up @@ -195,7 +203,6 @@ def annotate_anndata_with_cell_type(
one cell type is a subclass of another, indicating a potential issue with the
provided annotations.
"""
# TODO Add empty dataframe exception
cell_type_dict = self.create_cell_type_dict()
# Check if any cell_type in cell_type_list is not in cell_type_dict
missing_cell_types = set(cell_type_list) - set(cell_type_dict.keys())
@@ -208,18 +215,18 @@
raise SubclassWarning(subclass_relation)

# annotation phase
self._anndata.obs[field_name] = ""
condition = self._anndata.obs["cell_type_ontology_term_id"].isin(cell_type_list)
self._anndata.obs.loc[condition, field_name] = field_value
return self._anndata.obs[self._anndata.obs[field_name] == field_value]
self.anndata.obs[field_name] = ""
condition = self.anndata.obs["cell_type_ontology_term_id"].isin(cell_type_list)
self.anndata.obs.loc[condition, field_name] = field_value
return self.anndata.obs[self.anndata.obs[field_name] == field_value]

def set_enricher_property_list(self, property_list: List[str]):
"""Set the property list for the enricher.

Args:
property_list (List[str]): The list of properties to include in the enrichment analysis.
"""
self.enricher = Query(self._seed_list, property_list)
self.enricher = Query(self.seed_list, property_list)

def validate_slim_list(self, slim_list):
"""Check if any slim term in the given list is invalid.
@@ -237,17 +244,22 @@ def validate_slim_list(self, slim_list):
if invalid_slim_list:
raise InvalidSlimName(invalid_slim_list, self.slim_list)

def get_seed_list(self):
return self._seed_list

def get_anndata(self):
return self._anndata
def create_cell_type_dict(self):
"""
Create a dictionary from enriched_df for mapping cell type ontology term IDs to their labels.

def get_context_list(self):
return self._context_list
Returns:
A dictionary where keys are cell type ontology term IDs (e.g., "CL:000001") and values are
corresponding cell type labels (e.g., "Neuron").

def create_cell_type_dict(self):
# TODO Add empty dataframe exception
Raises:
MissingEnrichmentProcess: If the enrichment process has not been performed, and the
`enriched_df` is empty.
"""
if self.enricher.enriched_df.empty:
enrichment_methods = [i for i in dir(AnndataEnricher) if "_enrichment" in i]
enrichment_methods.sort()
raise MissingEnrichmentProcess(enrichment_methods)
return (
pd.concat(
[
@@ -266,16 +278,23 @@ def create_cell_type_dict(self):

def check_subclass_relationships(self, cell_type_list: List[str]) -> List[Tuple[str, str]]:
"""
Check for subclass relationships between cell type ontology terms.
Check for subclass relationships between cell type ontology terms using enriched_df.

Args:
cell_type_list: A list of cell type ontology term IDs to be used
for cell type annotation.

Returns:
A list of cell type pairs that have a subClassOf relationship between them.

Raises:
MissingEnrichmentProcess: If the enrichment process has not been performed, and the
`enriched_df` is empty.
"""
# TODO Add empty dataframe exception
if self.enricher.enriched_df.empty:
enrichment_methods = [i for i in dir(AnndataEnricher) if "_enrichment" in i]
enrichment_methods.sort()
raise MissingEnrichmentProcess(enrichment_methods)
subclass_relation = []
for s, o in itertools.combinations(cell_type_list, 2):
if not self.enricher.enriched_df[
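create_cell_type_dict and check_subclass_relationships now raise MissingEnrichmentProcess instead of failing on an empty enriched_df. A hedged sketch of the intended flow; the dataset path and the enricher's constructor defaults are assumptions beyond what this diff shows:

import anndata
from pandasaurus_cxg.anndata_enricher import AnndataEnricher
from pandasaurus_cxg.utils.exceptions import MissingEnrichmentProcess

adata = anndata.read_h5ad("human_kidney.h5ad")  # hypothetical file
enricher = AnndataEnricher(adata)

try:
    enricher.create_cell_type_dict()  # no enrichment has been run yet
except MissingEnrichmentProcess as e:
    print(e)  # should point to the available *_enrichment methods

enricher.simple_enrichment()  # expected to populate the underlying Query's enriched_df
cell_type_dict = enricher.create_cell_type_dict()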
2 changes: 1 addition & 1 deletion pandasaurus_cxg/enrichment_analysis.py
@@ -148,4 +148,4 @@ def enriched_co_annotation_report(self, disease: Optional[str] = None):
pd.DataFrame: The co-annotation report.

"""
return self.analyzer_manager.enriched_co_annotation_report(disease, True)
return self.analyzer_manager.enriched_co_annotation_report(disease)
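For context, the extra True argument is dropped because the analyzer-level method takes only disease and already enables enrichment itself. A minimal sketch of the call chain; the wrapper's class name is not shown in this diff and is left out:

# enrichment_analysis.py wrapper method (sketch)
def enriched_co_annotation_report(self, disease=None):
    # AnndataAnalyzer.enriched_co_annotation_report accepts only `disease` and
    # internally calls self.co_annotation_report(disease, True), so no enrich
    # flag is forwarded from here.
    return self.analyzer_manager.enriched_co_annotation_report(disease)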
74 changes: 11 additions & 63 deletions pandasaurus_cxg/graph_generator/graph_generator.py
@@ -50,11 +50,16 @@ def __init__(
enrichment_analyzer: A wrapper object for AnndataEnricher and AnndataAnalyzer.
keys (Optional[List[str]]): List of column names to select from the DataFrame to
generate the report. Defaults to None.
Please refrain from using this parameter until further notice.

"""
# TODO need to think about how to handle the requirement of enrichment and co_annotation_analysis methods
self.ea = enrichment_analyzer
# TODO need to handle invalid keys. We also need to discuss keeping the keys param. DO NOT USE
self.df = (
self.ea.analyzer_manager.report_df[keys] if keys else self.ea.analyzer_manager.report_df
enrichment_analyzer.analyzer_manager.report_df[keys]
if keys
else enrichment_analyzer.analyzer_manager.report_df
)
if self.ea.enricher_manager.enricher.enriched_df.empty:
# TODO or we can just call simple_enrichment method
@@ -212,7 +217,7 @@ def save_rdf_graph(
graph.serialize(f"{file_name}.{file_extension}", format=_format)
else:
valid_formats = [valid_format.value for valid_format in RDFFormat]
raise InvalidGraphFormat(RDFFormat, valid_formats)
raise InvalidGraphFormat(_format, valid_formats)

def visualize_rdf_graph(
self,
@@ -305,12 +310,6 @@ def visualize_rdf_graph(
]
nx_graph.remove_nodes_from(nodes_to_remove)

# Identify and remove nodes without any edge
# cell cluster type generate a node independent of the whole graph. this fix it
if len(nx_graph.nodes()) != 1:
nodes_to_remove = [node for node, degree in dict(nx_graph.degree()).items() if degree == 0]
nx_graph.remove_nodes_from(nodes_to_remove)

# Apply transitive reduction to remove redundancy
transitive_reduction_graph = nx.transitive_reduction(nx_graph)
transitive_reduction_graph.add_nodes_from(nx_graph.nodes(data=True))
@@ -324,10 +323,6 @@
node_colors.append(
colour_mapping.get(transitive_reduction_graph.nodes[node]["type"], "red")
)
# node_colors = [
# colour_mapping[transitive_reduction_graph.nodes[node]["type"]]
# for node in transitive_reduction_graph.nodes
# ]

pos = find_and_rotate_center_layout(transitive_reduction_graph)
plt.figure(figsize=(10, 10))
@@ -362,61 +357,12 @@ def visualize_rdf_graph(
)
plt.show()

def transitive_reduction(self, predicate_list: List[str], file_path: str, _format: str = "xml"):
# TODO We do not need this anymore since it is moved to pandasaurus
graph = Graph().parse(file_path, format="ttl") if file_path else self.graph
invalid_predicates = []
for predicate in predicate_list:
if predicate and not graph.query(f"ASK {{ ?s <{predicate}> ?o }}"):
invalid_predicates.append(predicate)
continue

predicate_uri = URIRef(predicate) if predicate else None
subgraph = add_outgoing_edges_to_subgraph(graph, predicate_uri)

nx_graph = nx.DiGraph()
for s, p, o in subgraph:
if isinstance(o, URIRef) and p != RDF.type:
add_edge(nx_graph, s, predicate, o)

# Apply transitive reduction to remove redundancy
transitive_reduction_graph = nx.transitive_reduction(nx_graph)
transitive_reduction_graph.add_edges_from(
(u, v, nx_graph.edges[u, v]) for u, v in transitive_reduction_graph.edges
)
# Remove redundant triples using nx graph
edge_diff = list(set(nx_graph.edges) - set(transitive_reduction_graph.edges))
for edge in edge_diff:
if graph.query(f"ASK {{ <{edge[0]}> <{predicate}> <{edge[1]}> }}"):
graph.remove((URIRef(edge[0]), URIRef(predicate), URIRef(edge[1])))
logger.info(f"Transitive reduction has been applied on {predicate}.")

self.save_rdf_graph(graph, f"{file_path.split('.')[0]}_non_redundant", _format)
logger.info(f"{file_path.split('.')[0]}_non_redundant has been saved.")

if invalid_predicates:
error_msg = (
f"The predicate '{invalid_predicates[0]}' does not exist in the graph"
if len(invalid_predicates) == 1
else f"The predicates {' ,'.join(invalid_predicates)} do not exist in the graph"
)
logger.error(error_msg)

def add_label_to_terms(self, graph_: Graph = None):
if not self.label_priority:
raise ValueError(
"The priority order for adding labels is missing. Please use set_label_adding_priority method."
)
graph = graph_ if graph_ else self.graph
# TODO have a better way to handle priority assignment and have an auto default assignment
# priority = {
# "subclass.l3": 1,
# "subclass.l2": 2,
# "subclass.full": 3,
# "subclass.l1": 4,
# "cell_type": 5,
# "class": 6,
# }
priority = self.label_priority
unique_subjects_query = (
"SELECT DISTINCT ?subject WHERE { ?subject ?predicate ?object FILTER (isIRI(?subject))}"
@@ -443,13 +389,14 @@ def set_label_adding_priority(self, label_priority: Union[List[str], Dict[str, i
Set the priority order for adding labels.

Args:
label_priority (Optional[Union[List[str], Dict[str, int]]]): Either a list of strings,
label_priority (Union[List[str], Dict[str, int]]): Either a list of strings,
or a dictionary with string keys and int values, representing the priority
order for adding labels.

"""
label_priority.append("cell_type") if "cell_type" not in label_priority else None
if isinstance(label_priority, list):
# TODO Do we need to append the 'cell_type'?
label_priority.append("cell_type") if "cell_type" not in label_priority else None
self.label_priority = {
label: len(label_priority) - i for i, label in enumerate(label_priority)
}
Expand All @@ -459,6 +406,7 @@ def set_label_adding_priority(self, label_priority: Union[List[str], Dict[str, i
isinstance(key, str) and isinstance(value, int)
for key, value in label_priority.items()
):
# TODO Do we need to append the 'cell_type'?
self.label_priority = label_priority
else:
raise ValueError("Invalid types in priority dictionary")
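A hedged usage sketch for the label-priority API touched above; the GraphGenerator class name, its import path, and how the instance is constructed are assumptions from context rather than facts shown in this diff:

from pandasaurus_cxg.graph_generator.graph_generator import GraphGenerator

def set_priorities(gg: GraphGenerator):
    # List form: "cell_type" is appended if missing, and earlier entries receive
    # larger values, e.g. ["subclass.l3", "subclass.l2"] ->
    # {"subclass.l3": 3, "subclass.l2": 2, "cell_type": 1}
    gg.set_label_adding_priority(["subclass.l3", "subclass.l2"])

    # Dict form: explicit str -> int priorities, validated before being accepted;
    # mixed key/value types raise ValueError("Invalid types in priority dictionary").
    gg.set_label_adding_priority({"subclass.full": 2, "cell_type": 1})

    # add_label_to_terms requires one of the calls above; otherwise it raises a
    # ValueError pointing to set_label_adding_priority.
    gg.add_label_to_terms()

    # Separately, save_rdf_graph now raises InvalidGraphFormat carrying the offending
    # format string (e.g. _format="bogus") instead of the RDFFormat enum class.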
4 changes: 2 additions & 2 deletions pandasaurus_cxg/graph_generator/graph_generator_utils.py
@@ -81,7 +81,7 @@ def find_and_rotate_center_layout(graph):
rotated_pos = {node: (2 * x_center - x, 2 * y_center - y) for node, (x, y) in pos.items()}
return rotated_pos


def generate_subgraph(graph, predicate_uri, stack, bottom_up):
subgraph = Graph()
visited = set()
@@ -113,7 +113,7 @@ def generate_subgraph(graph, predicate_uri, stack, bottom_up):
stack.append(_s)
return subgraph


def select_node_with_property(graph: Graph, _property: str, value: str):
ns = Namespace({k: v for k, v in graph.namespaces()}.get("ns"))
if _property == "label":