Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor bug fixes and improvements manual testing notes #51

Merged
merged 8 commits into from
Sep 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pandasaurus_cxg/anndata_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,9 @@ def __init__(self, anndata: AnnData, author_cell_type_list: Optional[List[str]]
raise ValueError(
"AnndataAnalyzer initialization error:\n\n"
"The 'obs_meta' field is missing in anndata.uns!\n"
"If this field is absent, you can provide a list of field names from the AnnData file "
"using the free_text_fields parameter.\n"
f"Available free text fields are: {', '.join(available_free_text_fields)}"
"If this field is absent, you can provide a list of field names from the "
"AnnData file using the author_cell_type_list parameter.\n"
f"Available author cell type fields are: {', '.join(available_free_text_fields)}"
)
self.report_df = pd.DataFrame()

Expand Down
10 changes: 8 additions & 2 deletions pandasaurus_cxg/anndata_enricher.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@ def __init__(
anndata: The AnnData object.
cell_type_field: The cell type information in the anndata object.
Defaults to "cell_type_ontology_term_id".
context_field: The context information in the anndata object.
context_field: Ontology ID of the context information in the anndata object.
Defaults to "tissue_ontology_term_id".
context_field_label: Label of the context information in the anndata object.
Defaults to "tissue".
ontology_list_for_slims: The ontology list for generating the slim list.
The slim list is used in minimal_slim_enrichment and full_slim_enrichment.
Defaults to "Cell Ontology"
Expand Down Expand Up @@ -62,6 +64,7 @@ def from_file_path(
file_path: str,
cell_type_field: Optional[str] = "cell_type_ontology_term_id",
context_field: Optional[str] = "tissue_ontology_term_id",
context_field_label: Optional[str] = "tissue",
ontology_list_for_slims: Optional[List[str]] = None,
):
"""Initialize the AnndataEnricher instance with file path.
Expand All @@ -71,8 +74,10 @@ def from_file_path(
file_path: The path to the file containing the anndata object.
cell_type_field: The cell type information in the anndata object.
Defaults to "cell_type_ontology_term_id".
context_field: The context information in the anndata object.
context_field: Ontology ID of the context information in the anndata object.
Defaults to "tissue_ontology_term_id".
context_field_label: Label of the context information in the anndata object.
Defaults to "tissue".
ontology_list_for_slims: The ontology list for generating the slim list.
The slim list is used in minimal_slim_enrichment and full_slim_enrichment.
Defaults to "Cell Ontology"
Expand All @@ -83,6 +88,7 @@ def from_file_path(
AnndataLoader.load_from_file(file_path),
cell_type_field,
context_field,
context_field_label,
ontology_list_for_slims,
)

Expand Down
2 changes: 1 addition & 1 deletion pandasaurus_cxg/enrichment_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


class AnndataEnrichmentAnalyzer:
def __init__(self, file_path: str, author_cell_type_list: Optional[str] = None):
def __init__(self, file_path: str, author_cell_type_list: Optional[List[str]] = None):
"""
Initializes the AnndataEnrichmentAnalyzer, a wrapper for AnndataEnricher and AnndataAnalyzer.

Expand Down
34 changes: 30 additions & 4 deletions pandasaurus_cxg/graph_generator/graph_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
add_edge,
add_node,
add_outgoing_edges_to_subgraph,
colour_mapping,
find_and_rotate_center_layout,
generate_subgraph,
select_node_with_property,
Expand Down Expand Up @@ -292,7 +293,17 @@ def visualize_rdf_graph(
if isinstance(o, URIRef) and p != RDF.type:
add_edge(nx_graph, s, p, o)
elif p == RDFS.label:
add_node(nx_graph, s, o)
add_node(nx_graph, s, {"label": str(o)})
elif p == RDF.type:
add_node(nx_graph, s, {"type": str(o)})

# Identify and remove nodes without any edge
# cell cluster type generate a node independent of the whole graph. this fix it
if len(nx_graph.nodes()) != 1:
nodes_to_remove = [
node for node, degree in dict(nx_graph.degree()).items() if degree == 0
]
nx_graph.remove_nodes_from(nodes_to_remove)

# Identify and remove nodes without any edge
# cell cluster type generate a node independent of the whole graph. this fix it
Expand All @@ -307,6 +318,17 @@ def visualize_rdf_graph(
(u, v, nx_graph.edges[u, v]) for u, v in transitive_reduction_graph.edges
)

node_colors = []
# Get node colors based on node types
for node in transitive_reduction_graph.nodes:
node_colors.append(
colour_mapping.get(transitive_reduction_graph.nodes[node]["type"], "red")
)
# node_colors = [
# colour_mapping[transitive_reduction_graph.nodes[node]["type"]]
# for node in transitive_reduction_graph.nodes
# ]

pos = find_and_rotate_center_layout(transitive_reduction_graph)
plt.figure(figsize=(10, 10))

Expand All @@ -321,17 +343,20 @@ def visualize_rdf_graph(
with_labels=True,
labels=node_labels,
node_size=2000,
node_color="skyblue",
node_color=node_colors,
font_size=8,
font_weight="bold",
)
# Draw edge labels on the graph
edge_labels = nx.get_edge_attributes(transitive_reduction_graph, "label")
edge_labels_formatted = {edge: label for edge, label in edge_labels.items()}
edge_labels = {
edge: "\n".join(textwrap.wrap(label, width=10)) for edge, label in edge_labels.items()
}
# edge_labels_formatted = {edge: label for edge, label in edge_labels.items()}
nx.draw_networkx_edge_labels(
transitive_reduction_graph,
pos,
edge_labels=edge_labels_formatted,
edge_labels=edge_labels,
font_size=8,
font_color="red",
)
Expand Down Expand Up @@ -423,6 +448,7 @@ def set_label_adding_priority(self, label_priority: Union[List[str], Dict[str, i
order for adding labels.

"""
label_priority.append("cell_type") if "cell_type" not in label_priority else None
if isinstance(label_priority, list):
self.label_priority = {
label: len(label_priority) - i for i, label in enumerate(label_priority)
Expand Down
50 changes: 29 additions & 21 deletions pandasaurus_cxg/graph_generator/graph_generator_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,31 @@
import networkx as nx
from rdflib import BNode, RDF, RDFS, Graph, Literal, Namespace, OWL, URIRef
from rdflib import OWL, RDF, RDFS, BNode, Graph, Literal, Namespace, URIRef

from pandasaurus_cxg.graph_generator.graph_predicates import (
CLUSTER,
CONSIST_OF,
SUBCLUSTER_OF,
)

colour_mapping = {
"http://www.w3.org/2002/07/owl#Class": "deepskyblue",
"http://purl.obolibrary.org/obo/PCL_0010001": "cyan",
}


def add_edge(nx_graph: nx.Graph, subject, predicate, obj):
edge_data = {
"label": str(predicate).split("#")[-1]
if "#" in predicate
else str(predicate).split("/")[-1]
"label": (
CONSIST_OF["label"]
if str(predicate) == CONSIST_OF["iri"]
else SUBCLUSTER_OF["label"]
if str(predicate) == SUBCLUSTER_OF["iri"]
else CLUSTER["label"]
if str(predicate) == CLUSTER["iri"]
else str(predicate).split("#")[-1]
if predicate and "#" in predicate
else str(predicate).split("/")[-1]
)
}
nx_graph.add_edge(
str(subject),
Expand All @@ -15,8 +34,9 @@ def add_edge(nx_graph: nx.Graph, subject, predicate, obj):
)


def add_node(nx_graph: nx.Graph, subject, obj):
nx_graph.add_node(str(subject), label=str(obj))
def add_node(nx_graph: nx.Graph, subject, annotation):
# nx_graph.add_node(str(subject), annotation=str(obj))
nx_graph.add_node(str(subject), **annotation)


def add_outgoing_edges_to_subgraph(graph, predicate_uri=None):
Expand Down Expand Up @@ -61,6 +81,7 @@ def find_and_rotate_center_layout(graph):
rotated_pos = {node: (2 * x_center - x, 2 * y_center - y) for node, (x, y) in pos.items()}
return rotated_pos


def generate_subgraph(graph, predicate_uri, stack, bottom_up):
subgraph = Graph()
visited = set()
Expand All @@ -70,7 +91,7 @@ def generate_subgraph(graph, predicate_uri, stack, bottom_up):
visited.add(node)
for s, p, o in graph.triples((node, predicate_uri, None)):
# Add all outgoing edges of the current node
if isinstance(o, Literal):
if isinstance(o, Literal) or p == RDF.type and not isinstance(o, BNode):
subgraph.add((s, p, o))
if bottom_up:
triples = graph.triples((node, predicate_uri, None))
Expand All @@ -90,22 +111,9 @@ def generate_subgraph(graph, predicate_uri, stack, bottom_up):
stack.append(_o)
else:
stack.append(_s)
# for s, p, next_node in graph.triples((node, predicate_uri, None)):
# if not isinstance(next_node, BNode):
# stack.append(next_node)
# else:
# _p = next(graph.objects(next_node, OWL.onProperty))
# _o = next(graph.objects(next_node, OWL.someValuesFrom))
# subgraph.add(
# (
# node,
# _p,
# _o,
# )
# )
# stack.append(_o)
return subgraph


def select_node_with_property(graph: Graph, _property: str, value: str):
ns = Namespace({k: v for k, v in graph.namespaces()}.get("ns"))
if _property == "label":
Expand Down
2 changes: 1 addition & 1 deletion pandasaurus_cxg/graph_generator/graph_predicates.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
CONSIST_OF = {"iri": "http://purl.obolibrary.org/obo/RO_0002473", "label": "composed_primarily_of"}
CONSIST_OF = {"iri": "http://purl.obolibrary.org/obo/RO_0002473", "label": "composed primarily of"}

SUBCLUSTER_OF = {"iri": "http://purl.obolibrary.org/obo/RO_0015003", "label": "subcluster of"}

Expand Down