Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose marginal kwargs for moscot.temporal and check for numeric type of temporal_key #449

Merged
merged 14 commits into from
Feb 2, 2023
20 changes: 20 additions & 0 deletions src/moscot/_docs/_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,24 @@
Specifies the right marginals. If of type :class:`str` the right marginals are taken from
:attr:`anndata.AnnData.obs` ``['{b}']``. If `b` is `None` uniform marginals are used.
"""
_a_temporal = """\
a
Specifies the left marginals. If of type :class:`str` the left marginals are taken from
:attr:`anndata.AnnData.obs` ``['{a}']``. If
MUCDK marked this conversation as resolved.
Show resolved Hide resolved
MUCDK marked this conversation as resolved.
Show resolved Hide resolved
:meth:`moscot.problems.base._birth_death.BirthDeathMixin.score_genes_for_marginals` was run and
MUCDK marked this conversation as resolved.
Show resolved Hide resolved
if `a` is `None`, marginals are computed based on a birth-death process as suggested in
:cite:`schiebinger:19`. Otherwise, uniform marginals are used. If `a` is `False`, uniform
MUCDK marked this conversation as resolved.
Show resolved Hide resolved
marginals are used.
"""
_b_temporal = """\
b
Specifies the right marginals. If of type :class:`str` the right marginals are taken from
:attr:`anndata.AnnData.obs` ``['{b}']``. If
:meth:`moscot.problems.base._birth_death.BirthDeathMixin.score_genes_for_marginals` was run and
if `b` is `None`, marginals are computed based on a birth-death process as suggested in
:cite:`schiebinger:19`. Otherwise, uniform marginals are used. If `b` is `False`, uniform
marginals are used.
"""
_time_key = """\
time_key
Time point key in :attr:`anndata.AnnData.obs`.
Expand Down Expand Up @@ -400,6 +418,8 @@
converged=_converged,
a=_a,
b=_b,
a_temporal=_a_temporal,
b_temporal=_b_temporal,
time_key=_time_key,
spatial_key=_spatial_key,
batch_key=_batch_key,
Expand Down
12 changes: 7 additions & 5 deletions src/moscot/problems/base/_birth_death.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, Union, Literal, Callable, Optional, Protocol, Sequence, TYPE_CHECKING
from types import MappingProxyType
from typing import Any, Union, Literal, Mapping, Callable, Optional, Protocol, Sequence, TYPE_CHECKING

import numpy as np

Expand Down Expand Up @@ -175,9 +176,10 @@ def _estimate_marginals(
source: bool,
proliferation_key: Optional[str] = None,
apoptosis_key: Optional[str] = None,
**kwargs: Any,
marginal_kwargs: Mapping[str, Any] = MappingProxyType({}),
**_: Any,
) -> ArrayLike:
def estimate(key: Optional[str], *, fn: Callable[..., ArrayLike]) -> ArrayLike:
def estimate(key: Optional[str], *, fn: Callable[..., ArrayLike], **kwargs: Any) -> ArrayLike:
if key is None:
return np.zeros(adata.n_obs, dtype=float)
try:
Expand All @@ -190,8 +192,8 @@ def estimate(key: Optional[str], *, fn: Callable[..., ArrayLike]) -> ArrayLike:
self.proliferation_key = proliferation_key
self.apoptosis_key = apoptosis_key

birth = estimate(proliferation_key, fn=beta)
death = estimate(apoptosis_key, fn=delta)
birth = estimate(proliferation_key, fn=beta, **marginal_kwargs)
death = estimate(apoptosis_key, fn=delta, **marginal_kwargs)
prior_growth = np.exp((birth - death) * self.delta)
scaling = np.sum(prior_growth)
normalized_growth = prior_growth / scaling
Expand Down
15 changes: 10 additions & 5 deletions src/moscot/problems/time/_lineage.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def prepare(
cost: Literal["sq_euclidean", "cosine", "bures", "unbalanced_bures"] = "sq_euclidean",
a: Optional[str] = None,
b: Optional[str] = None,
marginal_kwargs: Mapping[str, Any] = MappingProxyType({}),
**kwargs: Any,
) -> "TemporalProblem":
"""
Expand All @@ -61,8 +62,9 @@ def prepare(
%(joint_attr)s
%(policy)s
%(cost_lin)s
%(a)s
%(b)s
%(a_temporal)s
%(b_temporal)s
%(marginal_kwargs)s
%(kwargs_prepare)s


Expand All @@ -84,7 +86,7 @@ def prepare(
xy, x, y = handle_cost(xy=xy, x=kwargs.pop("x", None), y=kwargs.pop("y", None), cost=cost)

# TODO(michalk8): needs to be modified
marginal_kwargs = dict(kwargs.pop("marginal_kwargs", {}))
marginal_kwargs = dict(marginal_kwargs)
marginal_kwargs["proliferation_key"] = self.proliferation_key
marginal_kwargs["apoptosis_key"] = self.apoptosis_key
if a is None:
Expand Down Expand Up @@ -316,6 +318,7 @@ def prepare(
] = "sq_euclidean",
a: Optional[str] = None,
b: Optional[str] = None,
marginal_kwargs: Mapping[str, Any] = MappingProxyType({}),
**kwargs: Any,
) -> "LineageProblem":
"""
Expand All @@ -331,8 +334,9 @@ def prepare(
%(joint_attr)s
%(policy)s
%(cost)s
%(a)s
%(b)s
%(a_temporal)s
%(b_temporal)s
%(marginal_kwargs)s
%(kwargs_prepare)s

Returns
Expand Down Expand Up @@ -364,6 +368,7 @@ def prepare(
cost=cost,
a=a,
b=b,
marginal_kwargs=marginal_kwargs,
**kwargs,
)

Expand Down
7 changes: 7 additions & 0 deletions src/moscot/problems/time/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,4 +768,11 @@ def temporal_key(self) -> Optional[str]:
def temporal_key(self: TemporalMixinProtocol[K, B], key: Optional[str]) -> None:
if key is not None and key not in self.adata.obs:
raise KeyError(f"Unable to find temporal key in `adata.obs[{key!r}]`.")
if key is not None and not (
pd.api.types.is_float_dtype(self.adata.obs[key]) or pd.api.types.is_integer_dtype(self.adata.obs[key])
MUCDK marked this conversation as resolved.
Show resolved Hide resolved
):
raise TypeError(
"`temporal_key` has to be of numeric type."
+ f"Found `adata.obs[{key!r}]` to be of type {self.adata.obs[key].dtype}."
MUCDK marked this conversation as resolved.
Show resolved Hide resolved
)
self._temporal_key = key
18 changes: 9 additions & 9 deletions src/moscot/problems/time/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,30 +11,30 @@ def logistic(x: ArrayLike, L: float, k: float, center: float = 0) -> ArrayLike:
return L / (1 + np.exp(-k * (x - center)))


def gen_logistic(p: ArrayLike, beta_max: float, beta_min: float, center: float, width: float) -> ArrayLike:
def gen_logistic(p: ArrayLike, sup: float, inf: float, center: float, width: float) -> ArrayLike:
"""Shifted logistic function."""
return beta_min + logistic(p, L=beta_max - beta_min, k=4 / width, center=center)
return inf + logistic(p, L=sup - inf, k=4 / width, center=center)


def beta(
p: ArrayLike,
beta_max: float = 1.7,
beta_min: float = 0.3,
center: float = 0.25,
width: float = 0.5,
beta_center: float = 0.25,
beta_width: float = 0.5,
**_: Any,
) -> ArrayLike:
"""Birth process."""
return gen_logistic(p, beta_max, beta_min, center, width)
return gen_logistic(p, beta_max, beta_min, beta_center, beta_width)


def delta(
a: ArrayLike,
delta_max: float = 1.7,
delta_min: float = 0.3,
center: float = 0.1,
width: float = 0.2,
**kwargs: Any,
delta_center: float = 0.1,
delta_width: float = 0.2,
**_: Any,
) -> ArrayLike:
"""Death process."""
return gen_logistic(a, delta_max, delta_min, center, width)
return gen_logistic(a, delta_max, delta_min, delta_center, delta_width)
13 changes: 13 additions & 0 deletions tests/problems/time/test_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,3 +332,16 @@ def test_cell_transition_regression_notparam(
res = result.sort_index().sort_index(1)
df_expected = adata_time_with_tmap.uns["cell_transition_gt"].sort_index().sort_index(1)
np.testing.assert_almost_equal(res.values, df_expected.values, decimal=8)

@pytest.mark.fast()
@pytest.mark.parametrize("temporal_key", ["celltype", "time", "missing"])
def test_temporal_key_numeric(self, adata_time: AnnData, temporal_key: str):
problem = TemporalProblem(adata_time)
if temporal_key == "missing":
with pytest.raises(KeyError, match="Unable to find temporal key"):
problem = problem.prepare(temporal_key)
elif temporal_key == "celltype":
with pytest.raises(TypeError, match="`temporal_key` has to be of numeric type"):
problem = problem.prepare(temporal_key)
elif temporal_key == "time":
problem = problem.prepare(temporal_key)
37 changes: 36 additions & 1 deletion tests/problems/time/test_temporal_base_problem.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import Any, List, Mapping, Optional

import pytest

Expand Down Expand Up @@ -95,3 +95,38 @@ def test_posterior_growth_rates(self, adata_time_marginal_estimations: AnnData):

gr = prob.posterior_growth_rates
assert isinstance(gr, np.ndarray)

@pytest.mark.fast()
@pytest.mark.parametrize(
"marginal_kwargs", [{}, {"delta_width": 0.9}, {"delta_center": 0.9}, {"beta_width": 0.9}, {"beta_center": 0.9}]
)
def test_marginal_kwargs(self, adata_time_marginal_estimations: AnnData, marginal_kwargs: Mapping[str, Any]):
t1, t2 = 0, 1
adata_x = adata_time_marginal_estimations[adata_time_marginal_estimations.obs["time"] == t1]
adata_y = adata_time_marginal_estimations[adata_time_marginal_estimations.obs["time"] == t2]

prob = BirthDeathProblem(adata_x, adata_y, src_key=t1, tgt_key=t2)
prob = prob.prepare(
x={"attr": "X"},
y={"attr": "X"},
a=True,
b=True,
proliferation_key="proliferation",
apoptosis_key="apoptosis",
)

gr1 = prob.prior_growth_rates
prob = prob.prepare(
x={"attr": "X"},
y={"attr": "X"},
a=True,
b=True,
proliferation_key="proliferation",
apoptosis_key="apoptosis",
marginal_kwargs=marginal_kwargs,
)
gr2 = prob.prior_growth_rates
if len(marginal_kwargs) > 0:
assert not np.allclose(gr1, gr2)
else:
assert np.allclose(gr1, gr2)