Skip to content

Commit

Permalink
Expose marginal kwargs for moscot.temporal and check for numeric ty…
Browse files Browse the repository at this point in the history
…pe of `temporal_key` (#449)

* make marginal_kwargs explicit in temporal problems

* introduce check for numeric dtype in temporal mixin

* add alternative way for marginal prior

* adapt tolerances in tests

* correct docs

* fix bug

* Fix math rendering

* fix test


Co-authored-by: Michal Klein <46717574+michalk8@users.noreply.github.com>
  • Loading branch information
2 people authored and lucaeyring committed Mar 15, 2023
1 parent 9806854 commit 50dbdca
Show file tree
Hide file tree
Showing 15 changed files with 161 additions and 50 deletions.
1 change: 0 additions & 1 deletion docs/source/api/developer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ OTT Backend

moscot.backends.ott.SinkhornSolver
moscot.backends.ott.GWSolver
moscot.backends.ott.FGWSolver
moscot.backends.ott.OTTOutput
moscot.backends.ott.OTTCost

Expand Down
1 change: 0 additions & 1 deletion docs/source/api/user.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,3 @@ Generic Problems

SinkhornProblem
GWProblem
FGWProblem
1 change: 1 addition & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
"sphinx.ext.autodoc",
"sphinx.ext.napoleon",
"sphinx.ext.viewcode",
"sphinx.ext.mathjax",
"sphinx_autodoc_typehints",
"sphinx.ext.intersphinx",
"sphinx.ext.autosummary",
Expand Down
5 changes: 4 additions & 1 deletion docs/source/extensions/typed_returns.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ def process_return(lines: Iterable[str]) -> Iterator[str]:
m = re.fullmatch(r"(?P<param>\w+)\s+:\s+(?P<type>[\w.]+)", line)
if m:
# Once this is in scanpydoc, we can use the fancy hover stuff
yield f'**{m["param"]}** : :class:`~{m["type"]}`'
if m["param"]:
yield f'**{m["param"]}** : :class:`~{m["type"]}`'
else:
yield f':class:`~{m["type"]}`'
else:
yield line

Expand Down
40 changes: 33 additions & 7 deletions src/moscot/_docs/_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@
data
- If `data` is a :class:`str` this should correspond to a column in :attr:`anndata.AnnData.obs`.
The transport map is applied to the subset corresponding to the source distribution
(if `forward` is `True`) or target distribution (if `forward` is `False`) of that column.
(if `forward` is `True`) or target distribution (if `forward` is :obj:`False`) of that column.
- If `data` is a :class:npt.ArrayLike the transport map is applied to `data`.
- If `data` is a :class:`dict` then the keys should correspond to the tuple defining a single optimal
transport map and the value should be one of the two cases described above.
Expand All @@ -94,18 +94,20 @@
subset
Subset of :attr:`anndata.AnnData.obs` ``['{key}']`` values of which the policy is to be applied to.
"""
_marginal_kwargs = """\
_marginal_kwargs = r"""
marginal_kwargs
keyword arguments for :meth:`moscot.problems.BirthDeathProblem._estimate_marginals`, i.e.
for modeling the birth-death process. The keyword arguments
are either used for :func:`moscot.problems.time._utils.beta`, i.e. one of:
Keyword arguments for :meth:`~moscot.problems.BirthDeathProblem._estimate_marginals`. If ``'scaling'``
is in ``marginal_kwargs``, the left marginals are computed as
:math:`\exp(\frac{(\textit{proliferation} - \textit{apoptosis}) \cdot (t_2 - t_1)}{\textit{scaling}})`.
Otherwise, the left marginals are computed using a birth-death process. The keyword arguments
are either used for :func:`~moscot.problems.time._utils.beta`, i.e. one of:
- beta_max: float
- beta_min: float
- beta_center: float
- beta_width: float
or for :func:`moscot.problems.time._utils.beta`, i.e. one of:
or for :func:`~moscot.problems.time._utils.delta`, i.e. one of:
- delta_max: float
- delta_min: float
Expand All @@ -127,13 +129,35 @@
_a = """\
a
Specifies the left marginals. If of type :class:`str` the left marginals are taken from
:attr:`anndata.AnnData.obs` ``['{a}']``. If `a` is `None` uniform marginals are used.
:attr:`anndata.AnnData.obs` ``['{a}']``. If ``a`` is `None` uniform marginals are used.
"""
_b = """\
b
Specifies the right marginals. If of type :class:`str` the right marginals are taken from
:attr:`anndata.AnnData.obs` ``['{b}']``. If `b` is `None` uniform marginals are used.
"""
_a_temporal = r"""
a
Specifies the left marginals. If
- ``a`` is :class:`str` - the left marginals are taken from :attr:`anndata.AnnData.obs`,
- if :meth:`~moscot.problems.base._birth_death.BirthDeathMixin.score_genes_for_marginals` was run and
if ``a`` is `None`, marginals are computed based on a birth-death process as suggested in
:cite:`schiebinger:19`,
- if :meth:`~moscot.problems.base._birth_death.BirthDeathMixin.score_genes_for_marginals` was run and
if ``a`` is `None`, and additionally ``'scaling'`` is provided in ``marginal_kwargs``,
the marginals are computed as
:math:`\exp(\frac{(\textit{proliferation} - \textit{apoptosis}) \cdot (t_2 - t_1)}{\textit{scaling}})`
rather than using a birth-death process,
- otherwise or if ``a`` is :obj:`False`, uniform marginals are used.
"""
_b_temporal = """\
b
Specifies the right marginals. If
- ``b`` is :class:`str` - the left marginals are taken from :attr:`anndata.AnnData.obs`,
- if :meth:`~moscot.problems.base._birth_death.BirthDeathMixin.score_genes_for_marginals` was run
uniform (mean of left marginals) right marginals are used,
- otherwise or if ``b`` is :obj:`False`, uniform marginals are used.
"""
_time_key = """\
time_key
Time point key in :attr:`anndata.AnnData.obs`.
Expand Down Expand Up @@ -400,6 +424,8 @@
converged=_converged,
a=_a,
b=_b,
a_temporal=_a_temporal,
b_temporal=_b_temporal,
time_key=_time_key,
spatial_key=_spatial_key,
batch_key=_batch_key,
Expand Down
5 changes: 1 addition & 4 deletions src/moscot/_docs/_docs_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,7 @@
for the corresponding plotting functions are stored.
See TODO Notebook for how :mod:`moscot.plotting` works.
"""
_return_cell_transition = """\
retun_cell_transition
Transition matrix of cells or groups of cells.
"""
_return_cell_transition = "Transition matrix of cells or groups of cells."
_notes_cell_transition = """\
To visualise the results, see :func:`moscot.pl.cell_transition`.
"""
Expand Down
24 changes: 16 additions & 8 deletions src/moscot/problems/base/_birth_death.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, Union, Literal, Callable, Optional, Protocol, Sequence, TYPE_CHECKING
from types import MappingProxyType
from typing import Any, Union, Literal, Mapping, Callable, Optional, Protocol, Sequence, TYPE_CHECKING

import numpy as np

Expand Down Expand Up @@ -68,7 +69,7 @@ def score_genes_for_marginals(
to proliferation and/or apoptosis must be passed.
Alternatively, proliferation and apoptosis genes for humans and mice are saved in :mod:`moscot`.
The gene scores will be used in :meth:`moscot.problems.TemporalProblem.prepare` to estimate the initial
The gene scores will be used in :meth:`~moscot.problems.CompoundBaseProblem.prepare` to estimate the initial
growth rates as suggested in :cite:`schiebinger:19`
Parameters
Expand Down Expand Up @@ -175,9 +176,10 @@ def _estimate_marginals(
source: bool,
proliferation_key: Optional[str] = None,
apoptosis_key: Optional[str] = None,
**kwargs: Any,
marginal_kwargs: Mapping[str, Any] = MappingProxyType({}),
**_: Any,
) -> ArrayLike:
def estimate(key: Optional[str], *, fn: Callable[..., ArrayLike]) -> ArrayLike:
def estimate(key: Optional[str], *, fn: Callable[..., ArrayLike], **kwargs: Any) -> ArrayLike:
if key is None:
return np.zeros(adata.n_obs, dtype=float)
try:
Expand All @@ -189,16 +191,22 @@ def estimate(key: Optional[str], *, fn: Callable[..., ArrayLike]) -> ArrayLike:
raise ValueError("Either `proliferation_key` or `apoptosis_key` must be specified.")
self.proliferation_key = proliferation_key
self.apoptosis_key = apoptosis_key
if "scaling" in marginal_kwargs:
beta_fn = delta_fn = lambda x, *args, **kwargs: x
scaling = marginal_kwargs["scaling"]
else:
beta_fn, delta_fn = beta, delta
scaling = 1
birth = estimate(proliferation_key, fn=beta_fn, **marginal_kwargs)
death = estimate(apoptosis_key, fn=delta_fn, **marginal_kwargs)

prior_growth = np.exp((birth - death) * self.delta / scaling)

birth = estimate(proliferation_key, fn=beta)
death = estimate(apoptosis_key, fn=delta)
prior_growth = np.exp((birth - death) * self.delta)
scaling = np.sum(prior_growth)
normalized_growth = prior_growth / scaling
if source:
self._scaling = scaling
self._prior_growth = prior_growth

return normalized_growth if source else np.full(self.adata_tgt.n_obs, fill_value=np.mean(normalized_growth))

# TODO(michalk8): temporary fix to satisfy the mixin, consider removing the mixin
Expand Down
8 changes: 4 additions & 4 deletions src/moscot/problems/base/_compound_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ def _(
return res if return_all else current_mass

@d.get_sections(base="BaseCompoundProblem_push", sections=["Parameters", "Raises"])
@d.dedent
@d.dedent # TODO(@MUCDK) document private _apply
def push(self, *args: Any, **kwargs: Any) -> ApplyOutput_t[K]:
"""
Push mass from `start` to `end`.
Expand All @@ -371,7 +371,7 @@ def push(self, *args: Any, **kwargs: Any) -> ApplyOutput_t[K]:
%(scale_by_marginals)s
kwargs
keyword arguments for :meth:`moscot.problems.CompoundProblem._apply`.
keyword arguments for policy-specific `_apply` method of :class:`moscot.problems.CompoundProblem`.
Returns
-------
Expand All @@ -382,7 +382,7 @@ def push(self, *args: Any, **kwargs: Any) -> ApplyOutput_t[K]:
return self._apply(*args, forward=True, **kwargs)

@d.get_sections(base="BaseCompoundProblem_pull", sections=["Parameters", "Raises"])
@d.dedent
@d.dedent # TODO(@MUCDK) document private functions
def pull(self, *args: Any, **kwargs: Any) -> ApplyOutput_t[K]:
"""
Pull mass from `end` to `start`.
Expand All @@ -402,7 +402,7 @@ def pull(self, *args: Any, **kwargs: Any) -> ApplyOutput_t[K]:
%(scale_by_marginals)s
kwargs
Keyword arguments for :meth:`moscot.problems.CompoundProblem._apply`.
keyword arguments for policy-specific `_apply` method of :class:`moscot.problems.CompoundProblem`.
Returns
-------
Expand Down
15 changes: 10 additions & 5 deletions src/moscot/problems/time/_lineage.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def prepare(
cost: Literal["sq_euclidean", "cosine", "bures", "unbalanced_bures"] = "sq_euclidean",
a: Optional[str] = None,
b: Optional[str] = None,
marginal_kwargs: Mapping[str, Any] = MappingProxyType({}),
**kwargs: Any,
) -> "TemporalProblem":
"""
Expand All @@ -61,8 +62,9 @@ def prepare(
%(joint_attr)s
%(policy)s
%(cost_lin)s
%(a)s
%(b)s
%(a_temporal)s
%(b_temporal)s
%(marginal_kwargs)s
%(kwargs_prepare)s
Expand All @@ -84,7 +86,7 @@ def prepare(
xy, x, y = handle_cost(xy=xy, x=kwargs.pop("x", None), y=kwargs.pop("y", None), cost=cost)

# TODO(michalk8): needs to be modified
marginal_kwargs = dict(kwargs.pop("marginal_kwargs", {}))
marginal_kwargs = dict(marginal_kwargs)
marginal_kwargs["proliferation_key"] = self.proliferation_key
marginal_kwargs["apoptosis_key"] = self.apoptosis_key
if a is None:
Expand Down Expand Up @@ -316,6 +318,7 @@ def prepare(
] = "sq_euclidean",
a: Optional[str] = None,
b: Optional[str] = None,
marginal_kwargs: Mapping[str, Any] = MappingProxyType({}),
**kwargs: Any,
) -> "LineageProblem":
"""
Expand All @@ -331,8 +334,9 @@ def prepare(
%(joint_attr)s
%(policy)s
%(cost)s
%(a)s
%(b)s
%(a_temporal)s
%(b_temporal)s
%(marginal_kwargs)s
%(kwargs_prepare)s
Returns
Expand Down Expand Up @@ -364,6 +368,7 @@ def prepare(
cost=cost,
a=a,
b=b,
marginal_kwargs=marginal_kwargs,
**kwargs,
)

Expand Down
15 changes: 12 additions & 3 deletions src/moscot/problems/time/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from pathlib import Path
import itertools

from pandas.api.types import infer_dtype, is_numeric_dtype
import pandas as pd

import numpy as np
Expand Down Expand Up @@ -459,6 +460,7 @@ def _get_data(
target_data,
)

@d_mixins.dedent
def compute_interpolated_distance(
self: TemporalMixinProtocol[K, B],
source: K,
Expand Down Expand Up @@ -502,7 +504,7 @@ def compute_interpolated_distance(
%(use_posterior_marginals)s
%(seed_sampling)s
%(backend)s
%(kwargs_divergence)
%(kwargs_divergence)s
Returns
-------
Expand Down Expand Up @@ -534,6 +536,7 @@ def compute_interpolated_distance(
point_cloud_1=intermediate_data, point_cloud_2=interpolation, backend=backend, **kwargs
)

@d_mixins.dedent
def compute_random_distance(
self: TemporalMixinProtocol[K, B],
source: K,
Expand All @@ -544,7 +547,7 @@ def compute_random_distance(
account_for_unbalancedness: bool = False,
posterior_marginals: bool = True,
seed: Optional[int] = None,
backend: Literal["ott"] = "ott",
backend: Literal["ott"] = "ott", # TODO: not used
**kwargs: Any,
) -> Numeric_t:
"""
Expand All @@ -566,7 +569,7 @@ def compute_random_distance(
%(use_posterior_marginals)s
%(seed_interpolation)s
%(backend)s
%(kwargs_divergence)
%(kwargs_divergence)s
Returns
-------
Expand All @@ -592,6 +595,7 @@ def compute_random_distance(
)
return self._compute_wasserstein_distance(intermediate_data, random_interpolation, **kwargs)

@d_mixins.dedent
def compute_time_point_distances(
self: TemporalMixinProtocol[K, B],
source: K,
Expand Down Expand Up @@ -770,6 +774,11 @@ def temporal_key(self) -> Optional[str]:
def temporal_key(self: TemporalMixinProtocol[K, B], key: Optional[str]) -> None:
if key is not None and key not in self.adata.obs:
raise KeyError(f"Unable to find temporal key in `adata.obs[{key!r}]`.")
if key is not None and not is_numeric_dtype(self.adata.obs[key]):
raise TypeError(
"Temporal key has to be of numeric type."
f"Found `adata.obs[{key!r}]` to be of type `{infer_dtype(self.adata.obs[key])}`."
)
self._temporal_key = key


Expand Down
18 changes: 9 additions & 9 deletions src/moscot/problems/time/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,30 +11,30 @@ def logistic(x: ArrayLike, L: float, k: float, center: float = 0) -> ArrayLike:
return L / (1 + np.exp(-k * (x - center)))


def gen_logistic(p: ArrayLike, beta_max: float, beta_min: float, center: float, width: float) -> ArrayLike:
def gen_logistic(p: ArrayLike, sup: float, inf: float, center: float, width: float) -> ArrayLike:
"""Shifted logistic function."""
return beta_min + logistic(p, L=beta_max - beta_min, k=4 / width, center=center)
return inf + logistic(p, L=sup - inf, k=4 / width, center=center)


def beta(
p: ArrayLike,
beta_max: float = 1.7,
beta_min: float = 0.3,
center: float = 0.25,
width: float = 0.5,
beta_center: float = 0.25,
beta_width: float = 0.5,
**_: Any,
) -> ArrayLike:
"""Birth process."""
return gen_logistic(p, beta_max, beta_min, center, width)
return gen_logistic(p, beta_max, beta_min, beta_center, beta_width)


def delta(
a: ArrayLike,
delta_max: float = 1.7,
delta_min: float = 0.3,
center: float = 0.1,
width: float = 0.2,
**kwargs: Any,
delta_center: float = 0.1,
delta_width: float = 0.2,
**_: Any,
) -> ArrayLike:
"""Death process."""
return gen_logistic(a, delta_max, delta_min, center, width)
return gen_logistic(a, delta_max, delta_min, delta_center, delta_width)
Loading

0 comments on commit 50dbdca

Please sign in to comment.