From f152537fe3b3e899cb9da5305654ffe1a1d87fe1 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Sun, 5 Feb 2023 19:03:19 +0100 Subject: [PATCH] Tests/spatiotemporalproblem (#464) * add more tests for spatiotemporalProblem * move some functions from TemporalProblem to TemporalMixin * add tests LineageProblem * fix tests --- .../spatio_temporal/_spatio_temporal.py | 21 ++- src/moscot/problems/time/_lineage.py | 113 +----------- src/moscot/problems/time/_mixins.py | 112 +++++++++++- tests/problems/base/test_compound_problem.py | 2 +- .../test_spatio_temporal_problem.py | 71 ++++++++ tests/problems/time/test_lineage_problem.py | 163 +++++++++++++++++- tests/problems/time/test_mixins.py | 2 +- tests/problems/time/test_temporal_problem.py | 6 +- 8 files changed, 367 insertions(+), 123 deletions(-) diff --git a/src/moscot/problems/spatio_temporal/_spatio_temporal.py b/src/moscot/problems/spatio_temporal/_spatio_temporal.py index ee1904991..e39890cac 100644 --- a/src/moscot/problems/spatio_temporal/_spatio_temporal.py +++ b/src/moscot/problems/spatio_temporal/_spatio_temporal.py @@ -21,7 +21,16 @@ class SpatioTemporalProblem( AlignmentProblem[Numeric_t, BirthDeathProblem], SpatialAlignmentMixin[Numeric_t, BirthDeathProblem], ): - """Spatio-Temporal problem.""" + """ + Class for analyzing time series spatial single cell data. + + The `SpatioTemporalProblem` allows to model and analyze spatio-temporal single cell data + by matching cells belonging to two different time points via OT. + + Parameters + ---------- + %(adata)s + """ def __init__(self, adata: AnnData, **kwargs: Any): super().__init__(adata, **kwargs) @@ -39,6 +48,7 @@ def prepare( ] = "sq_euclidean", a: Optional[str] = None, b: Optional[str] = None, + marginal_kwargs: Mapping[str, Any] = MappingProxyType({}), **kwargs: Any, ) -> "SpatioTemporalProblem": """ @@ -73,6 +83,14 @@ def prepare( # spatial key set in AlignmentProblem # handle_joint_attr and handle_cost in AlignmentProblem self.temporal_key = time_key + # TODO(michalk8): needs to be modified, move into BirthDeathMixin? + marginal_kwargs = dict(marginal_kwargs) + marginal_kwargs["proliferation_key"] = self.proliferation_key + marginal_kwargs["apoptosis_key"] = self.apoptosis_key + if a is None: + a = self.proliferation_key is not None or self.apoptosis_key is not None + if b is None: + b = self.proliferation_key is not None or self.apoptosis_key is not None return super().prepare( spatial_key=spatial_key, @@ -83,6 +101,7 @@ def prepare( cost=cost, a=a, b=b, + marginal_kwargs=marginal_kwargs, **kwargs, ) diff --git a/src/moscot/problems/time/_lineage.py b/src/moscot/problems/time/_lineage.py index c0736eabb..a472ad0a9 100644 --- a/src/moscot/problems/time/_lineage.py +++ b/src/moscot/problems/time/_lineage.py @@ -1,16 +1,11 @@ from types import MappingProxyType -from typing import Any, Type, Tuple, Union, Literal, Mapping, Optional, TYPE_CHECKING - -import pandas as pd - -import numpy as np +from typing import Any, Type, Tuple, Union, Literal, Mapping, Optional from anndata import AnnData from moscot._types import Numeric_t, ScaleCost_t, ProblemStage_t, QuadInitializer_t, SinkhornInitializer_t from moscot._docs._docs import d from moscot.problems._utils import handle_cost, handle_joint_attr -from moscot.solvers._output import BaseSolverOutput from moscot._constants._constants import Policy from moscot.problems.time._mixins import TemporalMixin from moscot.problems.base._birth_death import BirthDeathMixin, BirthDeathProblem @@ -22,10 +17,10 @@ class TemporalProblem( TemporalMixin[Numeric_t, BirthDeathProblem], BirthDeathMixin, CompoundProblem[Numeric_t, BirthDeathProblem] ): """ - Class for analysing time series single cell data based on :cite:`schiebinger:19`. + Class for analyzing time series single cell data based on :cite:`schiebinger:19`. - The `TemporalProblem` allows to model and analyse time series single cell data by matching - cells from previous time points to later time points via optimal transport. + The `TemporalProblem` allows to model and analyze time series single cell data by matching + cells from previous time points to later time points via OT. Based on the assumption that the considered cell modality is similar in consecutive time points probabilistic couplings are computed between different time points. This allows to understand cell trajectories by inferring ancestors and descendants of single cells. @@ -184,106 +179,6 @@ def solve( **kwargs, ) # type:ignore[return-value] - @property - def prior_growth_rates(self) -> Optional[pd.DataFrame]: - """Return the prior estimate of growth rates of the cells in the source distribution.""" - # TODO(michalk8): FIXME - cols = ["prior_growth_rates"] - df_list = [ - pd.DataFrame(problem.prior_growth_rates, index=problem.adata.obs.index, columns=cols) - for problem in self.problems.values() - ] - tup = list(self)[-1] - df_list.append( - pd.DataFrame( - np.full( - shape=(len(self.problems[tup].adata_tgt.obs), 1), - fill_value=np.nan, - ), - index=self.problems[tup].adata_tgt.obs.index, - columns=cols, - ) - ) - return pd.concat(df_list, verify_integrity=True) - - @property - def posterior_growth_rates(self) -> Optional[pd.DataFrame]: - """Return the posterior estimate of growth rates of the cells in the source distribution.""" - # TODO(michalk8): FIXME - cols = ["posterior_growth_rates"] - df_list = [ - pd.DataFrame(problem.posterior_growth_rates, index=problem.adata.obs.index, columns=cols) - for problem in self.problems.values() - ] - tup = list(self)[-1] - df_list.append( - pd.DataFrame( - np.full( - shape=(len(self.problems[tup].adata_tgt.obs), 1), - fill_value=np.nan, - ), - index=self.problems[tup].adata_tgt.obs.index, - columns=cols, - ) - ) - return pd.concat(df_list, verify_integrity=True) - - # TODO(michalk8): refactor me - @property - def cell_costs_source(self) -> Optional[pd.DataFrame]: - """Return the cost of a cell obtained by the potentials of the optimal transport solution.""" - sol = list(self.problems.values())[0].solution - if TYPE_CHECKING: - assert isinstance(sol, BaseSolverOutput) - if sol.potentials is None: - return None - df_list = [ - pd.DataFrame( - np.array(np.abs(problem.solution.potentials[0])), # type: ignore[union-attr,index] - index=problem.adata_src.obs_names, - columns=["cell_cost_source"], - ) - for problem in self.problems.values() - ] - tup = list(self)[-1] - df_list.append( - pd.DataFrame( - np.full(shape=(len(self.problems[tup].adata_tgt.obs), 1), fill_value=np.nan), - index=self.problems[tup].adata_tgt.obs_names, - columns=["cell_cost_source"], - ) - ) - return pd.concat(df_list, verify_integrity=True) - - @property - def cell_costs_target(self) -> Optional[pd.DataFrame]: - """Return the cost of a cell (see online methods) obtained by the potentials of the OT solution.""" - sol = list(self.problems.values())[0].solution - if TYPE_CHECKING: - assert isinstance(sol, BaseSolverOutput) - if sol.potentials is None: - return None - - tup = list(self)[0] - df_list = [ - pd.DataFrame( - np.full(shape=(len(self.problems[tup].adata_src), 1), fill_value=np.nan), - index=self.problems[tup].adata_src.obs_names, - columns=["cell_cost_target"], - ) - ] - df_list.extend( - [ - pd.DataFrame( - np.array(np.abs(problem.solution.potentials[1])), # type: ignore[union-attr,index] - index=problem.adata_tgt.obs_names, - columns=["cell_cost_target"], - ) - for problem in self.problems.values() - ] - ) - return pd.concat(df_list, verify_integrity=True) - @property def _base_problem_type(self) -> Type[B]: # type: ignore[override] return BirthDeathProblem # type: ignore[return-value] diff --git a/src/moscot/problems/time/_mixins.py b/src/moscot/problems/time/_mixins.py index 063d32b7c..e3e244ff8 100644 --- a/src/moscot/problems/time/_mixins.py +++ b/src/moscot/problems/time/_mixins.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Tuple, Union, Literal, Iterable, Optional, Protocol, TYPE_CHECKING +from typing import Any, Dict, List, Tuple, Union, Literal, Iterable, Iterator, Optional, Protocol, TYPE_CHECKING from pathlib import Path import itertools @@ -10,18 +10,21 @@ from anndata import AnnData from moscot._types import ArrayLike, Numeric_t, Str_Dict_t +from moscot.solvers._output import BaseSolverOutput from moscot._docs._docs_mixins import d_mixins from moscot._constants._constants import Key, PlottingKeys, PlottingDefaults from moscot.problems.base._mixins import AnalysisMixin, AnalysisMixinProtocol from moscot.solvers._tagged_array import Tag +from moscot.problems.base._birth_death import BirthDeathProblem from moscot.problems.base._compound_problem import B, K, ApplyOutput_t -class TemporalMixinProtocol(AnalysisMixinProtocol[K, B], Protocol[K, B]): +# TODO(@MUCDK, @michalk8): check for ignore[misc] in line below, might become redundant +class TemporalMixinProtocol(AnalysisMixinProtocol[K, B], Protocol[K, B]): # type:ignore """Protocol class.""" adata: AnnData - problems: Dict[Tuple[K, K], B] + problems: Dict[Tuple[K, K], BirthDeathProblem] temporal_key: Optional[str] _temporal_key: Optional[str] @@ -134,6 +137,9 @@ def _get_interp_param( ) -> Numeric_t: ... + def __iter__(self) -> Iterator[Tuple[K, K]]: + ... + class TemporalMixin(AnalysisMixin[K, B]): """Analysis Mixin for all problems involving a temporal dimension.""" @@ -408,6 +414,106 @@ def pull( if return_data: return result + @property + def prior_growth_rates(self: TemporalMixinProtocol[K, B]) -> Optional[pd.DataFrame]: + """Return the prior estimate of growth rates of the cells in the source distribution.""" + # TODO(michalk8): FIXME + cols = ["prior_growth_rates"] + df_list = [ + pd.DataFrame(problem.prior_growth_rates, index=problem.adata.obs.index, columns=cols) + for problem in self.problems.values() + ] + tup = list(self)[-1] + df_list.append( + pd.DataFrame( + np.full( + shape=(len(self.problems[tup].adata_tgt.obs), 1), + fill_value=np.nan, + ), + index=self.problems[tup].adata_tgt.obs.index, + columns=cols, + ) + ) + return pd.concat(df_list, verify_integrity=True) + + @property + def posterior_growth_rates(self: TemporalMixinProtocol[K, B]) -> Optional[pd.DataFrame]: + """Return the posterior estimate of growth rates of the cells in the source distribution.""" + # TODO(michalk8): FIXME + cols = ["posterior_growth_rates"] + df_list = [ + pd.DataFrame(problem.posterior_growth_rates, index=problem.adata.obs.index, columns=cols) + for problem in self.problems.values() + ] + tup = list(self)[-1] + df_list.append( + pd.DataFrame( + np.full( + shape=(len(self.problems[tup].adata_tgt.obs), 1), + fill_value=np.nan, + ), + index=self.problems[tup].adata_tgt.obs.index, + columns=cols, + ) + ) + return pd.concat(df_list, verify_integrity=True) + + # TODO(michalk8): refactor me + @property + def cell_costs_source(self: TemporalMixinProtocol[K, B]) -> Optional[pd.DataFrame]: + """Return the cost of a cell obtained by the potentials of the optimal transport solution.""" + sol = list(self.problems.values())[0].solution + if TYPE_CHECKING: + assert isinstance(sol, BaseSolverOutput) + if sol.potentials is None: + return None + df_list = [ + pd.DataFrame( + np.array(np.abs(problem.solution.potentials[0])), # type: ignore[union-attr,index] + index=problem.adata_src.obs_names, + columns=["cell_cost_source"], + ) + for problem in self.problems.values() + ] + tup = list(self)[-1] + df_list.append( + pd.DataFrame( + np.full(shape=(len(self.problems[tup].adata_tgt.obs), 1), fill_value=np.nan), + index=self.problems[tup].adata_tgt.obs_names, + columns=["cell_cost_source"], + ) + ) + return pd.concat(df_list, verify_integrity=True) + + @property + def cell_costs_target(self: TemporalMixinProtocol[K, B]) -> Optional[pd.DataFrame]: + """Return the cost of a cell (see online methods) obtained by the potentials of the OT solution.""" + sol = list(self.problems.values())[0].solution + if TYPE_CHECKING: + assert isinstance(sol, BaseSolverOutput) + if sol.potentials is None: + return None + + tup = list(self)[0] + df_list = [ + pd.DataFrame( + np.full(shape=(len(self.problems[tup].adata_src), 1), fill_value=np.nan), + index=self.problems[tup].adata_src.obs_names, + columns=["cell_cost_target"], + ) + ] + df_list.extend( + [ + pd.DataFrame( + np.array(np.abs(problem.solution.potentials[1])), # type: ignore[union-attr,index] + index=problem.adata_tgt.obs_names, + columns=["cell_cost_target"], + ) + for problem in self.problems.values() + ] + ) + return pd.concat(df_list, verify_integrity=True) + # TODO(michalk8): refactor me def _get_data( self: TemporalMixinProtocol[K, B], diff --git a/tests/problems/base/test_compound_problem.py b/tests/problems/base/test_compound_problem.py index b11eb7d3a..48e18a50e 100644 --- a/tests/problems/base/test_compound_problem.py +++ b/tests/problems/base/test_compound_problem.py @@ -55,7 +55,7 @@ def test_sc_pipeline(self, adata_time: AnnData): key="time", policy="sequential", ) - problem = problem.solve() + problem = problem.solve(max_iterations=2) assert len(problem) == len(expected_keys) assert isinstance(problem.solutions, dict) diff --git a/tests/problems/spatio_temporal/test_spatio_temporal_problem.py b/tests/problems/spatio_temporal/test_spatio_temporal_problem.py index a638057dd..c1c22ffc7 100644 --- a/tests/problems/spatio_temporal/test_spatio_temporal_problem.py +++ b/tests/problems/spatio_temporal/test_spatio_temporal_problem.py @@ -1,11 +1,13 @@ from typing import Any, List, Mapping +import pandas as pd import pytest import numpy as np from anndata import AnnData +from tests._utils import ATOL, RTOL from moscot.solvers._output import BaseSolverOutput from tests.problems.conftest import ( fgw_args_1, @@ -110,6 +112,75 @@ def test_score_genes(self, adata_spatio_temporal: AnnData, gene_set_list: List[L else: assert problem.apoptosis_key is None + @pytest.mark.fast() + def test_proliferation_key_pipeline(self, adata_spatio_temporal: AnnData): + problem = SpatioTemporalProblem(adata_spatio_temporal) + assert problem.proliferation_key is None + + problem.score_genes_for_marginals(gene_set_proliferation="human", gene_set_apoptosis="human") + assert problem.proliferation_key == "proliferation" + + adata_spatio_temporal.obs["new_proliferation"] = np.ones(adata_spatio_temporal.n_obs) + problem.proliferation_key = "new_proliferation" + assert problem.proliferation_key == "new_proliferation" + + @pytest.mark.fast() + def test_apoptosis_key_pipeline(self, adata_spatio_temporal: AnnData): + problem = SpatioTemporalProblem(adata_spatio_temporal) + assert problem.apoptosis_key is None + + problem.score_genes_for_marginals(gene_set_proliferation="human", gene_set_apoptosis="human") + assert problem.apoptosis_key == "apoptosis" + + adata_spatio_temporal.obs["new_apoptosis"] = np.ones(adata_spatio_temporal.n_obs) + problem.apoptosis_key = "new_apoptosis" + assert problem.apoptosis_key == "new_apoptosis" + + @pytest.mark.fast() + @pytest.mark.parametrize("scaling", [0.1, 1, 4]) + def test_proliferation_key_c_pipeline(self, adata_spatio_temporal: AnnData, scaling: float): + keys = np.sort(np.unique(adata_spatio_temporal.obs["time"].values)) + adata_spatio_temporal = adata_spatio_temporal[adata_spatio_temporal.obs["time"].isin([keys[0], keys[1]])] + delta = keys[1] - keys[0] + problem = SpatioTemporalProblem(adata_spatio_temporal) + assert problem.proliferation_key is None + + problem.score_genes_for_marginals(gene_set_proliferation="human", gene_set_apoptosis="human") + assert problem.proliferation_key == "proliferation" + + problem = problem.prepare(time_key="time", marginal_kwargs={"scaling": scaling}) + prolif = adata_spatio_temporal[adata_spatio_temporal.obs["time"] == keys[0]].obs["proliferation"] + apopt = adata_spatio_temporal[adata_spatio_temporal.obs["time"] == keys[0]].obs["apoptosis"] + expected_marginals = np.exp((prolif - apopt) * delta / scaling) + print("problem[keys[0], keys[1]]._prior_growth", problem[keys[0], keys[1]]._prior_growth) + print("expected_marginals", expected_marginals) + np.testing.assert_allclose(problem[keys[0], keys[1]]._prior_growth, expected_marginals, rtol=RTOL, atol=ATOL) + + def test_growth_rates_pipeline(self, adata_spatio_temporal: AnnData): + problem = SpatioTemporalProblem(adata=adata_spatio_temporal) + problem = problem.score_genes_for_marginals(gene_set_proliferation="mouse", gene_set_apoptosis="mouse") + problem = problem.prepare("time", a=True, b=True) + problem = problem.solve(max_iterations=2) + + growth_rates = problem.posterior_growth_rates + assert isinstance(growth_rates, pd.DataFrame) + assert len(growth_rates.columns) == 1 + assert set(growth_rates.index) == set(adata_spatio_temporal.obs.index) + assert set(growth_rates[growth_rates["posterior_growth_rates"].isnull()].index) == set( + adata_spatio_temporal[adata_spatio_temporal.obs["time"] == 2].obs.index + ) + assert set(growth_rates[~growth_rates["posterior_growth_rates"].isnull()].index) == set( + adata_spatio_temporal[adata_spatio_temporal.obs["time"].isin([0, 1])].obs.index + ) + + def test_cell_costs_pipeline(self, adata_spatio_temporal: AnnData): + problem = SpatioTemporalProblem(adata=adata_spatio_temporal) + problem = problem.prepare("time") + problem = problem.solve(max_iterations=2) + + assert problem.cell_costs_source is None + assert problem.cell_costs_target is None + @pytest.mark.parametrize("args_to_check", [fgw_args_1, fgw_args_2]) def test_pass_arguments(self, adata_spatio_temporal: AnnData, args_to_check: Mapping[str, Any]): problem = SpatioTemporalProblem(adata=adata_spatio_temporal) diff --git a/tests/problems/time/test_lineage_problem.py b/tests/problems/time/test_lineage_problem.py index 6cbc57fac..e0aa7ad8b 100644 --- a/tests/problems/time/test_lineage_problem.py +++ b/tests/problems/time/test_lineage_problem.py @@ -1,10 +1,14 @@ -from typing import Any, Mapping +from typing import Any, List, Mapping import pytest +import numpy as np + from anndata import AnnData +from tests._utils import ATOL, RTOL from moscot.problems.time import LineageProblem +from moscot.solvers._output import BaseSolverOutput from tests.problems.conftest import ( fgw_args_1, fgw_args_2, @@ -19,6 +23,155 @@ class TestLineageProblem: + @pytest.mark.fast() + def test_prepare(self, adata_time_barcodes: AnnData): + expected_keys = [(0, 1), (1, 2)] + problem = LineageProblem(adata=adata_time_barcodes) + assert len(problem) == 0 + assert problem.problems == {} + assert problem.solutions == {} + problem = problem.prepare( + time_key="time", + policy="sequential", + lineage_attr={"attr": "obsm", "key": "barcodes", "tag": "cost_matrix", "cost": "barcode_distance"}, + ) + + assert isinstance(problem.problems, dict) + assert len(problem.problems) == len(expected_keys) + + for key in problem: + assert key in expected_keys + assert isinstance(problem[key], BirthDeathProblem) + + def test_solve_balanced(self, adata_time_barcodes: AnnData): + eps = 0.5 + expected_keys = [(0, 1)] + adata_time_barcodes = adata_time_barcodes[adata_time_barcodes.obs["time"].isin((0, 1))] + problem = LineageProblem(adata=adata_time_barcodes) + problem = problem.prepare( + time_key="time", + policy="sequential", + lineage_attr={"attr": "obsm", "key": "barcodes", "tag": "cost_matrix", "cost": "barcode_distance"}, + ) + problem = problem.solve(epsilon=eps) + + for key, subsol in problem.solutions.items(): + assert isinstance(subsol, BaseSolverOutput) + assert key in expected_keys + + def test_solve_unbalanced(self, adata_time_barcodes: AnnData): + taus = [9e-1, 1e-2] + adata_time_barcodes = adata_time_barcodes[adata_time_barcodes.obs["time"].isin((0, 1))] + + problem1 = LineageProblem(adata=adata_time_barcodes) + problem1 = problem1.prepare( + time_key="time", + policy="sequential", + lineage_attr={"attr": "obsm", "key": "barcodes", "tag": "cost_matrix", "cost": "barcode_distance"}, + ) + problem2 = LineageProblem(adata=adata_time_barcodes) + problem2 = problem2.prepare( + time_key="time", + policy="sequential", + lineage_attr={"attr": "obsm", "key": "barcodes", "tag": "cost_matrix", "cost": "barcode_distance"}, + ) + + assert problem1[0, 1].a is not None + assert problem1[0, 1].b is not None + assert problem2[0, 1].a is not None + assert problem2[0, 1].b is not None + + problem1 = problem1.solve(epsilon=1, tau_a=taus[0], tau_b=taus[0], max_iterations=100) + problem2 = problem2.solve(epsilon=1, tau_a=taus[1], tau_b=taus[1], max_iterations=100) + + assert problem1[0, 1].solution.a is not None + assert problem1[0, 1].solution.b is not None + assert problem2[0, 1].solution.a is not None + assert problem2[0, 1].solution.b is not None + + div1 = np.linalg.norm(problem1[0, 1].a - problem1[0, 1].solution.a) + div2 = np.linalg.norm(problem1[0, 1].b - problem1[0, 1].solution.b) + assert div1 < div2 + + @pytest.mark.fast() + @pytest.mark.parametrize( + "gene_set_list", + [ + [None, None], + ["human", "human"], + ["mouse", "mouse"], + [["ANLN", "ANP32E", "ATAD2"], ["ADD1", "AIFM3", "ANKH"]], + ], + ) + def test_score_genes(self, adata_time_barcodes: AnnData, gene_set_list: List[List[str]]): + gene_set_proliferation = gene_set_list[0] + gene_set_apoptosis = gene_set_list[1] + problem = LineageProblem(adata_time_barcodes) + problem.score_genes_for_marginals( + gene_set_proliferation=gene_set_proliferation, gene_set_apoptosis=gene_set_apoptosis + ) + + if gene_set_apoptosis is not None: + assert problem.proliferation_key == "proliferation" + assert adata_time_barcodes.obs["proliferation"] is not None + assert np.sum(np.isnan(adata_time_barcodes.obs["proliferation"])) == 0 + else: + assert problem.proliferation_key is None + + if gene_set_apoptosis is not None: + assert problem.apoptosis_key == "apoptosis" + assert adata_time_barcodes.obs["apoptosis"] is not None + assert np.sum(np.isnan(adata_time_barcodes.obs["apoptosis"])) == 0 + else: + assert problem.apoptosis_key is None + + @pytest.mark.fast() + def test_proliferation_key_pipeline(self, adata_time_barcodes: AnnData): + problem = LineageProblem(adata_time_barcodes) + assert problem.proliferation_key is None + + problem.score_genes_for_marginals(gene_set_proliferation="human", gene_set_apoptosis="human") + assert problem.proliferation_key == "proliferation" + + adata_time_barcodes.obs["new_proliferation"] = np.ones(adata_time_barcodes.n_obs) + problem.proliferation_key = "new_proliferation" + assert problem.proliferation_key == "new_proliferation" + + @pytest.mark.fast() + def test_apoptosis_key_pipeline(self, adata_time_barcodes: AnnData): + problem = LineageProblem(adata_time_barcodes) + assert problem.apoptosis_key is None + + problem.score_genes_for_marginals(gene_set_proliferation="human", gene_set_apoptosis="human") + assert problem.apoptosis_key == "apoptosis" + + adata_time_barcodes.obs["new_apoptosis"] = np.ones(adata_time_barcodes.n_obs) + problem.apoptosis_key = "new_apoptosis" + assert problem.apoptosis_key == "new_apoptosis" + + @pytest.mark.fast() + @pytest.mark.parametrize("scaling", [0.1, 1, 4]) + def test_proliferation_key_c_pipeline(self, adata_time_barcodes: AnnData, scaling: float): + keys = np.sort(np.unique(adata_time_barcodes.obs["time"].values)) + adata_time_barcodes = adata_time_barcodes[adata_time_barcodes.obs["time"].isin([keys[0], keys[1]])] + delta = keys[1] - keys[0] + problem = LineageProblem(adata_time_barcodes) + assert problem.proliferation_key is None + + problem.score_genes_for_marginals(gene_set_proliferation="human", gene_set_apoptosis="human") + assert problem.proliferation_key == "proliferation" + + problem = problem.prepare( + time_key="time", + lineage_attr={"attr": "obsm", "key": "barcodes", "tag": "cost_matrix", "cost": "barcode_distance"}, + policy="sequential", + marginal_kwargs={"scaling": scaling}, + ) + prolif = adata_time_barcodes[adata_time_barcodes.obs["time"] == keys[0]].obs["proliferation"] + apopt = adata_time_barcodes[adata_time_barcodes.obs["time"] == keys[0]].obs["apoptosis"] + expected_marginals = np.exp((prolif - apopt) * delta / scaling) + np.testing.assert_allclose(problem[keys[0], keys[1]]._prior_growth, expected_marginals, rtol=RTOL, atol=ATOL) + @pytest.mark.fast() def test_barcodes_pipeline(self, adata_time_barcodes: AnnData): expected_keys = [(0, 1), (1, 2)] @@ -28,7 +181,7 @@ def test_barcodes_pipeline(self, adata_time_barcodes: AnnData): lineage_attr={"attr": "obsm", "key": "barcodes", "tag": "cost_matrix", "cost": "barcode_distance"}, policy="sequential", ) - problem = problem.solve() + problem = problem.solve(max_iterations=2) for key in problem: assert key in expected_keys @@ -38,7 +191,7 @@ def test_custom_cost_pipeline(self, adata_time_custom_cost_xy: AnnData): expected_keys = [(0, 1), (1, 2)] problem = LineageProblem(adata=adata_time_custom_cost_xy) problem = problem.prepare(time_key="time") - problem = problem.solve() + problem = problem.solve(max_iterations=2) for key in problem: assert key in expected_keys @@ -50,7 +203,7 @@ def test_trees_pipeline(self, adata_time_trees: AnnData): problem = problem.prepare( time_key="time", lineage_attr={"attr": "uns", "key": "trees", "tag": "cost_matrix", "cost": "leaf_distance"} ) - problem = problem.solve(max_iterations=10) + problem = problem.solve(max_iterations=2) for key in problem: assert key in expected_keys @@ -59,7 +212,7 @@ def test_trees_pipeline(self, adata_time_trees: AnnData): def test_cell_costs_pipeline(self, adata_time_custom_cost_xy: AnnData): problem = LineageProblem(adata=adata_time_custom_cost_xy) problem = problem.prepare("time") - problem = problem.solve(max_iterations=10) + problem = problem.solve(max_iterations=1) assert problem.cell_costs_source is None assert problem.cell_costs_target is None diff --git a/tests/problems/time/test_mixins.py b/tests/problems/time/test_mixins.py index b6b76a711..c7fd0695c 100644 --- a/tests/problems/time/test_mixins.py +++ b/tests/problems/time/test_mixins.py @@ -304,7 +304,7 @@ def test_get_interp_param_pipeline(self, adata_time: AnnData, time_points: Tuple interpolation_parameter = None if len(time_points) == 3 else 0.5 problem = TemporalProblem(adata_time) problem.prepare("time") - problem.solve() + problem.solve(max_iterations=2) if intermediate <= start or end <= intermediate: with np.testing.assert_raises(ValueError): diff --git a/tests/problems/time/test_temporal_problem.py b/tests/problems/time/test_temporal_problem.py index 4d0aa22b5..eab8738a6 100644 --- a/tests/problems/time/test_temporal_problem.py +++ b/tests/problems/time/test_temporal_problem.py @@ -157,7 +157,7 @@ def test_proliferation_key_c_pipeline(self, adata_time: AnnData, scaling: float) def test_cell_costs_source_pipeline(self, adata_time: AnnData): problem = TemporalProblem(adata=adata_time) problem = problem.prepare("time") - problem = problem.solve() + problem = problem.solve(max_iterations=2) cell_costs_source = problem.cell_costs_source @@ -175,7 +175,7 @@ def test_cell_costs_source_pipeline(self, adata_time: AnnData): def test_cell_costs_target_pipeline(self, adata_time: AnnData): problem = TemporalProblem(adata=adata_time) problem = problem.prepare("time") - problem = problem.solve() + problem = problem.solve(max_iterations=2) cell_costs_target = problem.cell_costs_target @@ -194,7 +194,7 @@ def test_growth_rates_pipeline(self, adata_time: AnnData): problem = TemporalProblem(adata=adata_time) problem = problem.score_genes_for_marginals(gene_set_proliferation="mouse", gene_set_apoptosis="mouse") problem = problem.prepare("time", a=True, b=True) - problem = problem.solve() + problem = problem.solve(max_iterations=2) growth_rates = problem.posterior_growth_rates assert isinstance(growth_rates, pd.DataFrame)