Skip to content

Commit

Permalink
Tests/spatiotemporalproblem (#464)
Browse files Browse the repository at this point in the history
* add more tests for spatiotemporalProblem

* move some functions from TemporalProblem to TemporalMixin

* add tests LineageProblem

* fix tests
  • Loading branch information
MUCDK authored Feb 5, 2023
1 parent fea4385 commit f152537
Show file tree
Hide file tree
Showing 8 changed files with 367 additions and 123 deletions.
21 changes: 20 additions & 1 deletion src/moscot/problems/spatio_temporal/_spatio_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,16 @@ class SpatioTemporalProblem(
AlignmentProblem[Numeric_t, BirthDeathProblem],
SpatialAlignmentMixin[Numeric_t, BirthDeathProblem],
):
"""Spatio-Temporal problem."""
"""
Class for analyzing time series spatial single cell data.
The `SpatioTemporalProblem` allows to model and analyze spatio-temporal single cell data
by matching cells belonging to two different time points via OT.
Parameters
----------
%(adata)s
"""

def __init__(self, adata: AnnData, **kwargs: Any):
super().__init__(adata, **kwargs)
Expand All @@ -39,6 +48,7 @@ def prepare(
] = "sq_euclidean",
a: Optional[str] = None,
b: Optional[str] = None,
marginal_kwargs: Mapping[str, Any] = MappingProxyType({}),
**kwargs: Any,
) -> "SpatioTemporalProblem":
"""
Expand Down Expand Up @@ -73,6 +83,14 @@ def prepare(
# spatial key set in AlignmentProblem
# handle_joint_attr and handle_cost in AlignmentProblem
self.temporal_key = time_key
# TODO(michalk8): needs to be modified, move into BirthDeathMixin?
marginal_kwargs = dict(marginal_kwargs)
marginal_kwargs["proliferation_key"] = self.proliferation_key
marginal_kwargs["apoptosis_key"] = self.apoptosis_key
if a is None:
a = self.proliferation_key is not None or self.apoptosis_key is not None
if b is None:
b = self.proliferation_key is not None or self.apoptosis_key is not None

return super().prepare(
spatial_key=spatial_key,
Expand All @@ -83,6 +101,7 @@ def prepare(
cost=cost,
a=a,
b=b,
marginal_kwargs=marginal_kwargs,
**kwargs,
)

Expand Down
113 changes: 4 additions & 109 deletions src/moscot/problems/time/_lineage.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,11 @@
from types import MappingProxyType
from typing import Any, Type, Tuple, Union, Literal, Mapping, Optional, TYPE_CHECKING

import pandas as pd

import numpy as np
from typing import Any, Type, Tuple, Union, Literal, Mapping, Optional

from anndata import AnnData

from moscot._types import Numeric_t, ScaleCost_t, ProblemStage_t, QuadInitializer_t, SinkhornInitializer_t
from moscot._docs._docs import d
from moscot.problems._utils import handle_cost, handle_joint_attr
from moscot.solvers._output import BaseSolverOutput
from moscot._constants._constants import Policy
from moscot.problems.time._mixins import TemporalMixin
from moscot.problems.base._birth_death import BirthDeathMixin, BirthDeathProblem
Expand All @@ -22,10 +17,10 @@ class TemporalProblem(
TemporalMixin[Numeric_t, BirthDeathProblem], BirthDeathMixin, CompoundProblem[Numeric_t, BirthDeathProblem]
):
"""
Class for analysing time series single cell data based on :cite:`schiebinger:19`.
Class for analyzing time series single cell data based on :cite:`schiebinger:19`.
The `TemporalProblem` allows to model and analyse time series single cell data by matching
cells from previous time points to later time points via optimal transport.
The `TemporalProblem` allows to model and analyze time series single cell data by matching
cells from previous time points to later time points via OT.
Based on the assumption that the considered cell modality is similar in consecutive time points
probabilistic couplings are computed between different time points.
This allows to understand cell trajectories by inferring ancestors and descendants of single cells.
Expand Down Expand Up @@ -184,106 +179,6 @@ def solve(
**kwargs,
) # type:ignore[return-value]

@property
def prior_growth_rates(self) -> Optional[pd.DataFrame]:
"""Return the prior estimate of growth rates of the cells in the source distribution."""
# TODO(michalk8): FIXME
cols = ["prior_growth_rates"]
df_list = [
pd.DataFrame(problem.prior_growth_rates, index=problem.adata.obs.index, columns=cols)
for problem in self.problems.values()
]
tup = list(self)[-1]
df_list.append(
pd.DataFrame(
np.full(
shape=(len(self.problems[tup].adata_tgt.obs), 1),
fill_value=np.nan,
),
index=self.problems[tup].adata_tgt.obs.index,
columns=cols,
)
)
return pd.concat(df_list, verify_integrity=True)

@property
def posterior_growth_rates(self) -> Optional[pd.DataFrame]:
"""Return the posterior estimate of growth rates of the cells in the source distribution."""
# TODO(michalk8): FIXME
cols = ["posterior_growth_rates"]
df_list = [
pd.DataFrame(problem.posterior_growth_rates, index=problem.adata.obs.index, columns=cols)
for problem in self.problems.values()
]
tup = list(self)[-1]
df_list.append(
pd.DataFrame(
np.full(
shape=(len(self.problems[tup].adata_tgt.obs), 1),
fill_value=np.nan,
),
index=self.problems[tup].adata_tgt.obs.index,
columns=cols,
)
)
return pd.concat(df_list, verify_integrity=True)

# TODO(michalk8): refactor me
@property
def cell_costs_source(self) -> Optional[pd.DataFrame]:
"""Return the cost of a cell obtained by the potentials of the optimal transport solution."""
sol = list(self.problems.values())[0].solution
if TYPE_CHECKING:
assert isinstance(sol, BaseSolverOutput)
if sol.potentials is None:
return None
df_list = [
pd.DataFrame(
np.array(np.abs(problem.solution.potentials[0])), # type: ignore[union-attr,index]
index=problem.adata_src.obs_names,
columns=["cell_cost_source"],
)
for problem in self.problems.values()
]
tup = list(self)[-1]
df_list.append(
pd.DataFrame(
np.full(shape=(len(self.problems[tup].adata_tgt.obs), 1), fill_value=np.nan),
index=self.problems[tup].adata_tgt.obs_names,
columns=["cell_cost_source"],
)
)
return pd.concat(df_list, verify_integrity=True)

@property
def cell_costs_target(self) -> Optional[pd.DataFrame]:
"""Return the cost of a cell (see online methods) obtained by the potentials of the OT solution."""
sol = list(self.problems.values())[0].solution
if TYPE_CHECKING:
assert isinstance(sol, BaseSolverOutput)
if sol.potentials is None:
return None

tup = list(self)[0]
df_list = [
pd.DataFrame(
np.full(shape=(len(self.problems[tup].adata_src), 1), fill_value=np.nan),
index=self.problems[tup].adata_src.obs_names,
columns=["cell_cost_target"],
)
]
df_list.extend(
[
pd.DataFrame(
np.array(np.abs(problem.solution.potentials[1])), # type: ignore[union-attr,index]
index=problem.adata_tgt.obs_names,
columns=["cell_cost_target"],
)
for problem in self.problems.values()
]
)
return pd.concat(df_list, verify_integrity=True)

@property
def _base_problem_type(self) -> Type[B]: # type: ignore[override]
return BirthDeathProblem # type: ignore[return-value]
Expand Down
112 changes: 109 additions & 3 deletions src/moscot/problems/time/_mixins.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Tuple, Union, Literal, Iterable, Optional, Protocol, TYPE_CHECKING
from typing import Any, Dict, List, Tuple, Union, Literal, Iterable, Iterator, Optional, Protocol, TYPE_CHECKING
from pathlib import Path
import itertools

Expand All @@ -10,18 +10,21 @@
from anndata import AnnData

from moscot._types import ArrayLike, Numeric_t, Str_Dict_t
from moscot.solvers._output import BaseSolverOutput
from moscot._docs._docs_mixins import d_mixins
from moscot._constants._constants import Key, PlottingKeys, PlottingDefaults
from moscot.problems.base._mixins import AnalysisMixin, AnalysisMixinProtocol
from moscot.solvers._tagged_array import Tag
from moscot.problems.base._birth_death import BirthDeathProblem
from moscot.problems.base._compound_problem import B, K, ApplyOutput_t


class TemporalMixinProtocol(AnalysisMixinProtocol[K, B], Protocol[K, B]):
# TODO(@MUCDK, @michalk8): check for ignore[misc] in line below, might become redundant
class TemporalMixinProtocol(AnalysisMixinProtocol[K, B], Protocol[K, B]): # type:ignore
"""Protocol class."""

adata: AnnData
problems: Dict[Tuple[K, K], B]
problems: Dict[Tuple[K, K], BirthDeathProblem]
temporal_key: Optional[str]
_temporal_key: Optional[str]

Expand Down Expand Up @@ -134,6 +137,9 @@ def _get_interp_param(
) -> Numeric_t:
...

def __iter__(self) -> Iterator[Tuple[K, K]]:
...


class TemporalMixin(AnalysisMixin[K, B]):
"""Analysis Mixin for all problems involving a temporal dimension."""
Expand Down Expand Up @@ -408,6 +414,106 @@ def pull(
if return_data:
return result

@property
def prior_growth_rates(self: TemporalMixinProtocol[K, B]) -> Optional[pd.DataFrame]:
"""Return the prior estimate of growth rates of the cells in the source distribution."""
# TODO(michalk8): FIXME
cols = ["prior_growth_rates"]
df_list = [
pd.DataFrame(problem.prior_growth_rates, index=problem.adata.obs.index, columns=cols)
for problem in self.problems.values()
]
tup = list(self)[-1]
df_list.append(
pd.DataFrame(
np.full(
shape=(len(self.problems[tup].adata_tgt.obs), 1),
fill_value=np.nan,
),
index=self.problems[tup].adata_tgt.obs.index,
columns=cols,
)
)
return pd.concat(df_list, verify_integrity=True)

@property
def posterior_growth_rates(self: TemporalMixinProtocol[K, B]) -> Optional[pd.DataFrame]:
"""Return the posterior estimate of growth rates of the cells in the source distribution."""
# TODO(michalk8): FIXME
cols = ["posterior_growth_rates"]
df_list = [
pd.DataFrame(problem.posterior_growth_rates, index=problem.adata.obs.index, columns=cols)
for problem in self.problems.values()
]
tup = list(self)[-1]
df_list.append(
pd.DataFrame(
np.full(
shape=(len(self.problems[tup].adata_tgt.obs), 1),
fill_value=np.nan,
),
index=self.problems[tup].adata_tgt.obs.index,
columns=cols,
)
)
return pd.concat(df_list, verify_integrity=True)

# TODO(michalk8): refactor me
@property
def cell_costs_source(self: TemporalMixinProtocol[K, B]) -> Optional[pd.DataFrame]:
"""Return the cost of a cell obtained by the potentials of the optimal transport solution."""
sol = list(self.problems.values())[0].solution
if TYPE_CHECKING:
assert isinstance(sol, BaseSolverOutput)
if sol.potentials is None:
return None
df_list = [
pd.DataFrame(
np.array(np.abs(problem.solution.potentials[0])), # type: ignore[union-attr,index]
index=problem.adata_src.obs_names,
columns=["cell_cost_source"],
)
for problem in self.problems.values()
]
tup = list(self)[-1]
df_list.append(
pd.DataFrame(
np.full(shape=(len(self.problems[tup].adata_tgt.obs), 1), fill_value=np.nan),
index=self.problems[tup].adata_tgt.obs_names,
columns=["cell_cost_source"],
)
)
return pd.concat(df_list, verify_integrity=True)

@property
def cell_costs_target(self: TemporalMixinProtocol[K, B]) -> Optional[pd.DataFrame]:
"""Return the cost of a cell (see online methods) obtained by the potentials of the OT solution."""
sol = list(self.problems.values())[0].solution
if TYPE_CHECKING:
assert isinstance(sol, BaseSolverOutput)
if sol.potentials is None:
return None

tup = list(self)[0]
df_list = [
pd.DataFrame(
np.full(shape=(len(self.problems[tup].adata_src), 1), fill_value=np.nan),
index=self.problems[tup].adata_src.obs_names,
columns=["cell_cost_target"],
)
]
df_list.extend(
[
pd.DataFrame(
np.array(np.abs(problem.solution.potentials[1])), # type: ignore[union-attr,index]
index=problem.adata_tgt.obs_names,
columns=["cell_cost_target"],
)
for problem in self.problems.values()
]
)
return pd.concat(df_list, verify_integrity=True)

# TODO(michalk8): refactor me
def _get_data(
self: TemporalMixinProtocol[K, B],
Expand Down
2 changes: 1 addition & 1 deletion tests/problems/base/test_compound_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def test_sc_pipeline(self, adata_time: AnnData):
key="time",
policy="sequential",
)
problem = problem.solve()
problem = problem.solve(max_iterations=2)

assert len(problem) == len(expected_keys)
assert isinstance(problem.solutions, dict)
Expand Down
Loading

0 comments on commit f152537

Please sign in to comment.