Skip to content

Commit

Permalink
Merge branch 'master' into doublet-detection
Browse files Browse the repository at this point in the history
  • Loading branch information
colganwi authored Mar 21, 2024
2 parents 1a7c5ac + 41dbff8 commit 9228c61
Show file tree
Hide file tree
Showing 11 changed files with 847 additions and 95 deletions.
41 changes: 29 additions & 12 deletions cassiopeia/data/CassiopeiaTree.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union

import ete3
import heapq
import networkx as nx
import numpy as np
import pandas as pd
Expand Down Expand Up @@ -1595,7 +1596,7 @@ def remove_leaves_and_prune_lineages(
Removes the specified leaves and all ancestors of those leaves that are
no longer the ancestor of any of the remaining leaves. In the context
of a phylogeny, this prunes the lineage of all nodes no longer relevant
to observed samples. Additionally, maintains consistency with the
to observed samples. Additionally, maintains consistency with the
updated tree by removing the node from all leaf data.
Args:
Expand All @@ -1614,19 +1615,35 @@ def remove_leaves_and_prune_lineages(
for n in nodes:
if not self.is_leaf(n):
raise CassiopeiaTreeError("A specified node is not a leaf.")

# Keep track of nodes to check and their depths
nodes_to_check = set()
nodes_depth_queue = []

# Remove leaves from the tree
for n in nodes:
if len(self.nodes) == 1:
self.__remove_node(n)
else:
curr_parent = self.parent(n)
self.__remove_node(n)
while len(self.children(curr_parent)) < 1 and not self.is_root(
curr_parent
):
next_parent = self.parent(curr_parent)
self.__remove_node(curr_parent)
curr_parent = next_parent
parent = next(self.__network.predecessors(n))
if parent not in nodes_to_check:
parent_time = self.__network.nodes[parent]["time"]
heapq.heappush(nodes_depth_queue, (parent_time, parent))
nodes_to_check.add(parent)
self.__network.remove_node(n)

# Check nodes with a children removed from bottom to top
while len(nodes_to_check) > 0:
_, n = heapq.heappop(nodes_depth_queue)
nodes_to_check.remove(n)
if n == self.root:
continue
children = list(self.__network.successors(n))
# Remove nodes with no children
if len(children) == 0:
parent = next(self.__network.predecessors(n))
parent_time = self.__network.nodes[parent]["time"]
self.__network.remove_node(n)
if parent not in nodes_to_check:
heapq.heappush(nodes_depth_queue, (parent_time, parent))
nodes_to_check.add(parent)

# Remove all removed nodes from data fields
# This function will also clear the cache
Expand Down
45 changes: 31 additions & 14 deletions cassiopeia/data/utilities.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
General utilities for the datasets encountered in Cassiopeia.
"""

import collections
from joblib import delayed
import multiprocessing
Expand Down Expand Up @@ -61,17 +62,35 @@ def get_lca_characters(
all_states = [
vec[i] for vec in vecs if vec[i] != missing_state_indicator
]
chars = set.intersection(
*map(
set,
[
state if is_ambiguous_state(state) else [state]
for state in all_states
],

# this check is specifically if all_states consists of a single
# ambiguous state.
if len(list(set(all_states))) == 1:
state = all_states[0]
# lca_vec[i] = state
if is_ambiguous_state(state) and len(state) == 1:
lca_vec[i] = state[0]
else:
lca_vec[i] = all_states[0]
else:
all_ambiguous = np.all(
[is_ambiguous_state(s) for s in all_states]
)
)
if len(chars) == 1:
lca_vec[i] = list(chars)[0]
chars = set.intersection(
*map(
set,
[
state if is_ambiguous_state(state) else [state]
for state in all_states
],
)
)
if len(chars) == 1:
lca_vec[i] = list(chars)[0]
if all_ambiguous:
# if we only have ambiguous states, we set the LCA state
# to be the intersection.
lca_vec[i] = tuple(chars)
return lca_vec


Expand Down Expand Up @@ -109,7 +128,7 @@ def ete3_to_networkx(tree: ete3.Tree) -> nx.DiGraph:
if n.is_root():
continue

g.add_edge(n.up.name, n.name)
g.add_edge(n.up.name, n.name, length=n.dist)

return g

Expand Down Expand Up @@ -217,9 +236,7 @@ def compute_dissimilarity_map(
]

# load character matrix into shared memory
shm = shared_memory.SharedMemory(
create=True, size=cm.nbytes
)
shm = shared_memory.SharedMemory(create=True, size=cm.nbytes)
shared_cm = np.ndarray(cm.shape, dtype=cm.dtype, buffer=shm.buf)
shared_cm[:] = cm[:]

Expand Down
Loading

0 comments on commit 9228c61

Please sign in to comment.