diff --git a/docs/source/api/developer.rst b/docs/source/api/developer.rst index 2d14753a3..8b013ec55 100644 --- a/docs/source/api/developer.rst +++ b/docs/source/api/developer.rst @@ -10,7 +10,6 @@ OTT Backend moscot.backends.ott.SinkhornSolver moscot.backends.ott.GWSolver - moscot.backends.ott.FGWSolver moscot.backends.ott.OTTOutput moscot.backends.ott.OTTCost diff --git a/docs/source/api/user.rst b/docs/source/api/user.rst index c7e674cf4..8924defcc 100644 --- a/docs/source/api/user.rst +++ b/docs/source/api/user.rst @@ -28,4 +28,3 @@ Generic Problems SinkhornProblem GWProblem - FGWProblem diff --git a/docs/source/conf.py b/docs/source/conf.py index c3b329f5e..085c221be 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -49,6 +49,7 @@ "sphinx.ext.autodoc", "sphinx.ext.napoleon", "sphinx.ext.viewcode", + "sphinx.ext.mathjax", "sphinx_autodoc_typehints", "sphinx.ext.intersphinx", "sphinx.ext.autosummary", diff --git a/docs/source/extensions/typed_returns.py b/docs/source/extensions/typed_returns.py index a18e9933f..d5126910f 100644 --- a/docs/source/extensions/typed_returns.py +++ b/docs/source/extensions/typed_returns.py @@ -10,7 +10,10 @@ def process_return(lines: Iterable[str]) -> Iterator[str]: m = re.fullmatch(r"(?P\w+)\s+:\s+(?P[\w.]+)", line) if m: # Once this is in scanpydoc, we can use the fancy hover stuff - yield f'**{m["param"]}** : :class:`~{m["type"]}`' + if m["param"]: + yield f'**{m["param"]}** : :class:`~{m["type"]}`' + else: + yield f':class:`~{m["type"]}`' else: yield line diff --git a/src/moscot/_docs/_docs.py b/src/moscot/_docs/_docs.py index 29727d604..a4ffa25e5 100644 --- a/src/moscot/_docs/_docs.py +++ b/src/moscot/_docs/_docs.py @@ -85,7 +85,7 @@ data - If `data` is a :class:`str` this should correspond to a column in :attr:`anndata.AnnData.obs`. The transport map is applied to the subset corresponding to the source distribution - (if `forward` is `True`) or target distribution (if `forward` is `False`) of that column. + (if `forward` is `True`) or target distribution (if `forward` is :obj:`False`) of that column. - If `data` is a :class:npt.ArrayLike the transport map is applied to `data`. - If `data` is a :class:`dict` then the keys should correspond to the tuple defining a single optimal transport map and the value should be one of the two cases described above. @@ -94,18 +94,20 @@ subset Subset of :attr:`anndata.AnnData.obs` ``['{key}']`` values of which the policy is to be applied to. """ -_marginal_kwargs = """\ +_marginal_kwargs = r""" marginal_kwargs - keyword arguments for :meth:`moscot.problems.BirthDeathProblem._estimate_marginals`, i.e. - for modeling the birth-death process. The keyword arguments - are either used for :func:`moscot.problems.time._utils.beta`, i.e. one of: + Keyword arguments for :meth:`~moscot.problems.BirthDeathProblem._estimate_marginals`. If ``'scaling'`` + is in ``marginal_kwargs``, the left marginals are computed as + :math:`\exp(\frac{(\textit{proliferation} - \textit{apoptosis}) \cdot (t_2 - t_1)}{\textit{scaling}})`. + Otherwise, the left marginals are computed using a birth-death process. The keyword arguments + are either used for :func:`~moscot.problems.time._utils.beta`, i.e. one of: - beta_max: float - beta_min: float - beta_center: float - beta_width: float - or for :func:`moscot.problems.time._utils.beta`, i.e. one of: + or for :func:`~moscot.problems.time._utils.delta`, i.e. one of: - delta_max: float - delta_min: float @@ -127,13 +129,35 @@ _a = """\ a Specifies the left marginals. If of type :class:`str` the left marginals are taken from - :attr:`anndata.AnnData.obs` ``['{a}']``. If `a` is `None` uniform marginals are used. + :attr:`anndata.AnnData.obs` ``['{a}']``. If ``a`` is `None` uniform marginals are used. """ _b = """\ b Specifies the right marginals. If of type :class:`str` the right marginals are taken from :attr:`anndata.AnnData.obs` ``['{b}']``. If `b` is `None` uniform marginals are used. """ +_a_temporal = r""" +a + Specifies the left marginals. If + - ``a`` is :class:`str` - the left marginals are taken from :attr:`anndata.AnnData.obs`, + - if :meth:`~moscot.problems.base._birth_death.BirthDeathMixin.score_genes_for_marginals` was run and + if ``a`` is `None`, marginals are computed based on a birth-death process as suggested in + :cite:`schiebinger:19`, + - if :meth:`~moscot.problems.base._birth_death.BirthDeathMixin.score_genes_for_marginals` was run and + if ``a`` is `None`, and additionally ``'scaling'`` is provided in ``marginal_kwargs``, + the marginals are computed as + :math:`\exp(\frac{(\textit{proliferation} - \textit{apoptosis}) \cdot (t_2 - t_1)}{\textit{scaling}})` + rather than using a birth-death process, + - otherwise or if ``a`` is :obj:`False`, uniform marginals are used. +""" +_b_temporal = """\ +b + Specifies the right marginals. If + - ``b`` is :class:`str` - the left marginals are taken from :attr:`anndata.AnnData.obs`, + - if :meth:`~moscot.problems.base._birth_death.BirthDeathMixin.score_genes_for_marginals` was run + uniform (mean of left marginals) right marginals are used, + - otherwise or if ``b`` is :obj:`False`, uniform marginals are used. +""" _time_key = """\ time_key Time point key in :attr:`anndata.AnnData.obs`. @@ -400,6 +424,8 @@ converged=_converged, a=_a, b=_b, + a_temporal=_a_temporal, + b_temporal=_b_temporal, time_key=_time_key, spatial_key=_spatial_key, batch_key=_batch_key, diff --git a/src/moscot/_docs/_docs_mixins.py b/src/moscot/_docs/_docs_mixins.py index 3af941609..89e5b1f14 100644 --- a/src/moscot/_docs/_docs_mixins.py +++ b/src/moscot/_docs/_docs_mixins.py @@ -55,10 +55,7 @@ for the corresponding plotting functions are stored. See TODO Notebook for how :mod:`moscot.plotting` works. """ -_return_cell_transition = """\ -retun_cell_transition - Transition matrix of cells or groups of cells. -""" +_return_cell_transition = "Transition matrix of cells or groups of cells." _notes_cell_transition = """\ To visualise the results, see :func:`moscot.pl.cell_transition`. """ diff --git a/src/moscot/problems/base/_birth_death.py b/src/moscot/problems/base/_birth_death.py index fbfb8db05..96c00a64b 100644 --- a/src/moscot/problems/base/_birth_death.py +++ b/src/moscot/problems/base/_birth_death.py @@ -1,4 +1,5 @@ -from typing import Any, Union, Literal, Callable, Optional, Protocol, Sequence, TYPE_CHECKING +from types import MappingProxyType +from typing import Any, Union, Literal, Mapping, Callable, Optional, Protocol, Sequence, TYPE_CHECKING import numpy as np @@ -68,7 +69,7 @@ def score_genes_for_marginals( to proliferation and/or apoptosis must be passed. Alternatively, proliferation and apoptosis genes for humans and mice are saved in :mod:`moscot`. - The gene scores will be used in :meth:`moscot.problems.TemporalProblem.prepare` to estimate the initial + The gene scores will be used in :meth:`~moscot.problems.CompoundBaseProblem.prepare` to estimate the initial growth rates as suggested in :cite:`schiebinger:19` Parameters @@ -175,9 +176,10 @@ def _estimate_marginals( source: bool, proliferation_key: Optional[str] = None, apoptosis_key: Optional[str] = None, - **kwargs: Any, + marginal_kwargs: Mapping[str, Any] = MappingProxyType({}), + **_: Any, ) -> ArrayLike: - def estimate(key: Optional[str], *, fn: Callable[..., ArrayLike]) -> ArrayLike: + def estimate(key: Optional[str], *, fn: Callable[..., ArrayLike], **kwargs: Any) -> ArrayLike: if key is None: return np.zeros(adata.n_obs, dtype=float) try: @@ -189,16 +191,22 @@ def estimate(key: Optional[str], *, fn: Callable[..., ArrayLike]) -> ArrayLike: raise ValueError("Either `proliferation_key` or `apoptosis_key` must be specified.") self.proliferation_key = proliferation_key self.apoptosis_key = apoptosis_key + if "scaling" in marginal_kwargs: + beta_fn = delta_fn = lambda x, *args, **kwargs: x + scaling = marginal_kwargs["scaling"] + else: + beta_fn, delta_fn = beta, delta + scaling = 1 + birth = estimate(proliferation_key, fn=beta_fn, **marginal_kwargs) + death = estimate(apoptosis_key, fn=delta_fn, **marginal_kwargs) + + prior_growth = np.exp((birth - death) * self.delta / scaling) - birth = estimate(proliferation_key, fn=beta) - death = estimate(apoptosis_key, fn=delta) - prior_growth = np.exp((birth - death) * self.delta) scaling = np.sum(prior_growth) normalized_growth = prior_growth / scaling if source: self._scaling = scaling self._prior_growth = prior_growth - return normalized_growth if source else np.full(self.adata_tgt.n_obs, fill_value=np.mean(normalized_growth)) # TODO(michalk8): temporary fix to satisfy the mixin, consider removing the mixin diff --git a/src/moscot/problems/base/_compound_problem.py b/src/moscot/problems/base/_compound_problem.py index 5b11a53e6..3d24b1c62 100644 --- a/src/moscot/problems/base/_compound_problem.py +++ b/src/moscot/problems/base/_compound_problem.py @@ -351,7 +351,7 @@ def _( return res if return_all else current_mass @d.get_sections(base="BaseCompoundProblem_push", sections=["Parameters", "Raises"]) - @d.dedent + @d.dedent # TODO(@MUCDK) document private _apply def push(self, *args: Any, **kwargs: Any) -> ApplyOutput_t[K]: """ Push mass from `start` to `end`. @@ -371,7 +371,7 @@ def push(self, *args: Any, **kwargs: Any) -> ApplyOutput_t[K]: %(scale_by_marginals)s kwargs - keyword arguments for :meth:`moscot.problems.CompoundProblem._apply`. + keyword arguments for policy-specific `_apply` method of :class:`moscot.problems.CompoundProblem`. Returns ------- @@ -382,7 +382,7 @@ def push(self, *args: Any, **kwargs: Any) -> ApplyOutput_t[K]: return self._apply(*args, forward=True, **kwargs) @d.get_sections(base="BaseCompoundProblem_pull", sections=["Parameters", "Raises"]) - @d.dedent + @d.dedent # TODO(@MUCDK) document private functions def pull(self, *args: Any, **kwargs: Any) -> ApplyOutput_t[K]: """ Pull mass from `end` to `start`. @@ -402,7 +402,7 @@ def pull(self, *args: Any, **kwargs: Any) -> ApplyOutput_t[K]: %(scale_by_marginals)s kwargs - Keyword arguments for :meth:`moscot.problems.CompoundProblem._apply`. + keyword arguments for policy-specific `_apply` method of :class:`moscot.problems.CompoundProblem`. Returns ------- diff --git a/src/moscot/problems/time/_lineage.py b/src/moscot/problems/time/_lineage.py index a75fba7fe..c0736eabb 100644 --- a/src/moscot/problems/time/_lineage.py +++ b/src/moscot/problems/time/_lineage.py @@ -48,6 +48,7 @@ def prepare( cost: Literal["sq_euclidean", "cosine", "bures", "unbalanced_bures"] = "sq_euclidean", a: Optional[str] = None, b: Optional[str] = None, + marginal_kwargs: Mapping[str, Any] = MappingProxyType({}), **kwargs: Any, ) -> "TemporalProblem": """ @@ -61,8 +62,9 @@ def prepare( %(joint_attr)s %(policy)s %(cost_lin)s - %(a)s - %(b)s + %(a_temporal)s + %(b_temporal)s + %(marginal_kwargs)s %(kwargs_prepare)s @@ -84,7 +86,7 @@ def prepare( xy, x, y = handle_cost(xy=xy, x=kwargs.pop("x", None), y=kwargs.pop("y", None), cost=cost) # TODO(michalk8): needs to be modified - marginal_kwargs = dict(kwargs.pop("marginal_kwargs", {})) + marginal_kwargs = dict(marginal_kwargs) marginal_kwargs["proliferation_key"] = self.proliferation_key marginal_kwargs["apoptosis_key"] = self.apoptosis_key if a is None: @@ -316,6 +318,7 @@ def prepare( ] = "sq_euclidean", a: Optional[str] = None, b: Optional[str] = None, + marginal_kwargs: Mapping[str, Any] = MappingProxyType({}), **kwargs: Any, ) -> "LineageProblem": """ @@ -331,8 +334,9 @@ def prepare( %(joint_attr)s %(policy)s %(cost)s - %(a)s - %(b)s + %(a_temporal)s + %(b_temporal)s + %(marginal_kwargs)s %(kwargs_prepare)s Returns @@ -364,6 +368,7 @@ def prepare( cost=cost, a=a, b=b, + marginal_kwargs=marginal_kwargs, **kwargs, ) diff --git a/src/moscot/problems/time/_mixins.py b/src/moscot/problems/time/_mixins.py index 12a11e118..39f71d47e 100644 --- a/src/moscot/problems/time/_mixins.py +++ b/src/moscot/problems/time/_mixins.py @@ -2,6 +2,7 @@ from pathlib import Path import itertools +from pandas.api.types import infer_dtype, is_numeric_dtype import pandas as pd import numpy as np @@ -459,6 +460,7 @@ def _get_data( target_data, ) + @d_mixins.dedent def compute_interpolated_distance( self: TemporalMixinProtocol[K, B], source: K, @@ -502,7 +504,7 @@ def compute_interpolated_distance( %(use_posterior_marginals)s %(seed_sampling)s %(backend)s - %(kwargs_divergence) + %(kwargs_divergence)s Returns ------- @@ -534,6 +536,7 @@ def compute_interpolated_distance( point_cloud_1=intermediate_data, point_cloud_2=interpolation, backend=backend, **kwargs ) + @d_mixins.dedent def compute_random_distance( self: TemporalMixinProtocol[K, B], source: K, @@ -544,7 +547,7 @@ def compute_random_distance( account_for_unbalancedness: bool = False, posterior_marginals: bool = True, seed: Optional[int] = None, - backend: Literal["ott"] = "ott", + backend: Literal["ott"] = "ott", # TODO: not used **kwargs: Any, ) -> Numeric_t: """ @@ -566,7 +569,7 @@ def compute_random_distance( %(use_posterior_marginals)s %(seed_interpolation)s %(backend)s - %(kwargs_divergence) + %(kwargs_divergence)s Returns ------- @@ -592,6 +595,7 @@ def compute_random_distance( ) return self._compute_wasserstein_distance(intermediate_data, random_interpolation, **kwargs) + @d_mixins.dedent def compute_time_point_distances( self: TemporalMixinProtocol[K, B], source: K, @@ -770,6 +774,11 @@ def temporal_key(self) -> Optional[str]: def temporal_key(self: TemporalMixinProtocol[K, B], key: Optional[str]) -> None: if key is not None and key not in self.adata.obs: raise KeyError(f"Unable to find temporal key in `adata.obs[{key!r}]`.") + if key is not None and not is_numeric_dtype(self.adata.obs[key]): + raise TypeError( + "Temporal key has to be of numeric type." + f"Found `adata.obs[{key!r}]` to be of type `{infer_dtype(self.adata.obs[key])}`." + ) self._temporal_key = key diff --git a/src/moscot/problems/time/_utils.py b/src/moscot/problems/time/_utils.py index 8e2679da9..44553023e 100644 --- a/src/moscot/problems/time/_utils.py +++ b/src/moscot/problems/time/_utils.py @@ -11,30 +11,30 @@ def logistic(x: ArrayLike, L: float, k: float, center: float = 0) -> ArrayLike: return L / (1 + np.exp(-k * (x - center))) -def gen_logistic(p: ArrayLike, beta_max: float, beta_min: float, center: float, width: float) -> ArrayLike: +def gen_logistic(p: ArrayLike, sup: float, inf: float, center: float, width: float) -> ArrayLike: """Shifted logistic function.""" - return beta_min + logistic(p, L=beta_max - beta_min, k=4 / width, center=center) + return inf + logistic(p, L=sup - inf, k=4 / width, center=center) def beta( p: ArrayLike, beta_max: float = 1.7, beta_min: float = 0.3, - center: float = 0.25, - width: float = 0.5, + beta_center: float = 0.25, + beta_width: float = 0.5, **_: Any, ) -> ArrayLike: """Birth process.""" - return gen_logistic(p, beta_max, beta_min, center, width) + return gen_logistic(p, beta_max, beta_min, beta_center, beta_width) def delta( a: ArrayLike, delta_max: float = 1.7, delta_min: float = 0.3, - center: float = 0.1, - width: float = 0.2, - **kwargs: Any, + delta_center: float = 0.1, + delta_width: float = 0.2, + **_: Any, ) -> ArrayLike: """Death process.""" - return gen_logistic(a, delta_max, delta_min, center, width) + return gen_logistic(a, delta_max, delta_min, delta_center, delta_width) diff --git a/tests/problems/time/test_mixins.py b/tests/problems/time/test_mixins.py index 8418a9b52..b6b76a711 100644 --- a/tests/problems/time/test_mixins.py +++ b/tests/problems/time/test_mixins.py @@ -332,3 +332,16 @@ def test_cell_transition_regression_notparam( res = result.sort_index().sort_index(1) df_expected = adata_time_with_tmap.uns["cell_transition_gt"].sort_index().sort_index(1) np.testing.assert_almost_equal(res.values, df_expected.values, decimal=8) + + @pytest.mark.fast() + @pytest.mark.parametrize("temporal_key", ["celltype", "time", "missing"]) + def test_temporal_key_numeric(self, adata_time: AnnData, temporal_key: str): + problem = TemporalProblem(adata_time) + if temporal_key == "missing": + with pytest.raises(KeyError, match="Unable to find temporal key"): + problem = problem.prepare(temporal_key) + elif temporal_key == "celltype": + with pytest.raises(TypeError, match="Temporal key has to be of numeric type"): + problem = problem.prepare(temporal_key) + elif temporal_key == "time": + problem = problem.prepare(temporal_key) diff --git a/tests/problems/time/test_temporal_base_problem.py b/tests/problems/time/test_temporal_base_problem.py index 47270f8ad..38b876005 100644 --- a/tests/problems/time/test_temporal_base_problem.py +++ b/tests/problems/time/test_temporal_base_problem.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Any, List, Mapping, Optional import pytest @@ -95,3 +95,38 @@ def test_posterior_growth_rates(self, adata_time_marginal_estimations: AnnData): gr = prob.posterior_growth_rates assert isinstance(gr, np.ndarray) + + @pytest.mark.fast() + @pytest.mark.parametrize( + "marginal_kwargs", [{}, {"delta_width": 0.9}, {"delta_center": 0.9}, {"beta_width": 0.9}, {"beta_center": 0.9}] + ) + def test_marginal_kwargs(self, adata_time_marginal_estimations: AnnData, marginal_kwargs: Mapping[str, Any]): + t1, t2 = 0, 1 + adata_x = adata_time_marginal_estimations[adata_time_marginal_estimations.obs["time"] == t1] + adata_y = adata_time_marginal_estimations[adata_time_marginal_estimations.obs["time"] == t2] + + prob = BirthDeathProblem(adata_x, adata_y, src_key=t1, tgt_key=t2) + prob = prob.prepare( + x={"attr": "X"}, + y={"attr": "X"}, + a=True, + b=True, + proliferation_key="proliferation", + apoptosis_key="apoptosis", + ) + + gr1 = prob.prior_growth_rates + prob = prob.prepare( + x={"attr": "X"}, + y={"attr": "X"}, + a=True, + b=True, + proliferation_key="proliferation", + apoptosis_key="apoptosis", + marginal_kwargs=marginal_kwargs, + ) + gr2 = prob.prior_growth_rates + if len(marginal_kwargs) > 0: + assert not np.allclose(gr1, gr2) + else: + assert np.allclose(gr1, gr2) diff --git a/tests/problems/time/test_temporal_problem.py b/tests/problems/time/test_temporal_problem.py index 82c9c5ab5..4d0aa22b5 100644 --- a/tests/problems/time/test_temporal_problem.py +++ b/tests/problems/time/test_temporal_problem.py @@ -7,6 +7,7 @@ from anndata import AnnData +from tests._utils import ATOL, RTOL from moscot.problems.time import TemporalProblem from moscot.solvers._output import BaseSolverOutput from tests.problems.conftest import ( @@ -135,6 +136,24 @@ def test_apoptosis_key_pipeline(self, adata_time: AnnData): problem.apoptosis_key = "new_apoptosis" assert problem.apoptosis_key == "new_apoptosis" + @pytest.mark.fast() + @pytest.mark.parametrize("scaling", [0.1, 1, 4]) + def test_proliferation_key_c_pipeline(self, adata_time: AnnData, scaling: float): + keys = np.sort(np.unique(adata_time.obs["time"].values)) + adata_time = adata_time[adata_time.obs["time"].isin([keys[0], keys[1]])] + delta = keys[1] - keys[0] + problem = TemporalProblem(adata_time) + assert problem.proliferation_key is None + + problem.score_genes_for_marginals(gene_set_proliferation="human", gene_set_apoptosis="human") + assert problem.proliferation_key == "proliferation" + + problem = problem.prepare(time_key="time", marginal_kwargs={"scaling": scaling}) + prolif = adata_time[adata_time.obs["time"] == keys[0]].obs["proliferation"] + apopt = adata_time[adata_time.obs["time"] == keys[0]].obs["apoptosis"] + expected_marginals = np.exp((prolif - apopt) * delta / scaling) + np.testing.assert_allclose(problem[keys[0], keys[1]]._prior_growth, expected_marginals, rtol=RTOL, atol=ATOL) + def test_cell_costs_source_pipeline(self, adata_time: AnnData): problem = TemporalProblem(adata=adata_time) problem = problem.prepare("time") diff --git a/tox.ini b/tox.ini index fb8d61247..1040311e3 100644 --- a/tox.ini +++ b/tox.ini @@ -113,26 +113,23 @@ commands = [testenv:lint] description = Perform linting. -basepython = python3.9 deps = pre-commit>=2.14.0 skip_install = true commands = pre-commit run --all-files --show-diff-on-failure {posargs:} [testenv:clean-docs] description = Clean the documentation artifacts. -basepython = python3.8 deps = skip_install = true changedir = {toxinidir}/docs -whitelist_externals = make +allowlist_externals = make commands = make clean [testenv:docs] description = Build the documentation. -basepython = python3.9 skip_install = false extras = docs -whitelist_externals = sphinx-build +allowlist_externals = sphinx-build commands = - sphinx-build --color -b html {toxinidir}/docs/source {toxinidir}/docs/build/html + sphinx-build --color -b html {toxinidir}/docs/source {toxinidir}/docs/build/html {posargs} python -c 'import pathlib; print(f"Documentation is available under:", pathlib.Path(f"{toxinidir}") / "docs" / "build" / "html" / "index.html")'