Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adapt callbacks and rename tag cost to cost_matrix #426

Merged
merged 14 commits into from
Dec 9, 2022
4 changes: 2 additions & 2 deletions src/moscot/_constants/_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def wrapper(*args: Any, **kwargs: Any) -> "ErrorFormatterABC":
return wrapper


class ABCEnumMeta(EnumMeta, ABCMeta): # noqa: B024
class ABCEnumMeta(EnumMeta, ABCMeta):
def __call__(cls, *args: Any, **kwargs: Any) -> Any:
if getattr(cls, "__error_format__", None) is None:
raise TypeError(f"Can't instantiate class `{cls.__name__}` without `__error_format__` class attribute.")
Expand All @@ -45,7 +45,7 @@ def __new__(cls, clsname: str, superclasses: Tuple[type], attributedict: Dict[st
return res


class ErrorFormatterABC(ABC): # noqa: B024
class ErrorFormatterABC(ABC):
"""Mixin class that formats invalid value when constructing an enum."""

__error_format__ = "Invalid option `{0}` for `{1}`. Valid options are: `{2}`."
Expand Down
38 changes: 30 additions & 8 deletions src/moscot/_docs/_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,31 @@
reference
`reference` in :class:`moscot.problems._subset_policy.StarPolicy`.
"""
_callback = """\
callback
Custom callback applied to each distribution as pre-processing step. Examples are given in TODO Link Notebook.
_xy_callback = """\
xy_callback
Custom callback applied to the linear term as pre-processing step. Examples are given in TODO Link Notebook.
"""
_callback_kwargs = """\
callback_kwargs
Keyword arguments for `callback`.
_xy_callback_kwargs = """\
xy_callback_kwargs
Keyword arguments for `xy_callback`.
"""
_x_callback = """\
x_callback
Custom callback applied to the source distribution of the quadratic term as pre-processing step.
Examples are given in TODO Link Notebook.
"""
_x_callback_kwargs = """\
x_callback_kwargs
Keyword arguments for `x_callback`.
"""
_y_callback = """\
y_callback
Custom callback applied to the target distribution of the quadratic term as pre-processing step.
Examples are given in TODO Link Notebook.
"""
_y_callback_kwargs = """\
y_callback_kwargs
Keyword arguments for `y_callback`.
"""
_epsilon = """\
epsilon
Expand Down Expand Up @@ -362,8 +380,12 @@
source=_source,
target=_target,
reference=_reference,
callback=_callback,
callback_kwargs=_callback_kwargs,
xy_callback=_xy_callback,
xy_callback_kwargs=_xy_callback_kwargs,
x_callback=_x_callback,
x_callback_kwargs=_x_callback_kwargs,
y_callback=_y_callback,
y_callback_kwargs=_y_callback_kwargs,
epsilon=_epsilon,
alpha=_alpha,
tau_a=_tau_a,
Expand Down
34 changes: 26 additions & 8 deletions src/moscot/problems/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,12 @@ def wrap_solve(


def handle_joint_attr(
joint_attr: Optional[Union[str, Mapping[str, Any]]], kwargs: Any
) -> Tuple[Optional[Mapping[str, Any]], Dict[str, Any]]:
joint_attr: Optional[Union[str, Mapping[str, Any]]], kwargs: Dict[str, Any]
) -> Tuple[Optional[Dict[str, Any]], Dict[str, Any]]:
if joint_attr is None:
if "callback" not in kwargs:
kwargs["callback"] = "local-pca"
else:
kwargs["callback"] = kwargs["callback"]
kwargs["callback_kwargs"] = {**kwargs.get("callback_kwargs", {}), **{"return_linear": True}}
if "xy_callback" not in kwargs:
kwargs["xy_callback"] = "local-pca"
kwargs.setdefault("xy_callback_kwargs", {})
return None, kwargs
if isinstance(joint_attr, str):
xy = {
Expand All @@ -84,7 +82,27 @@ def handle_joint_attr(
"y_key": joint_attr,
}
return xy, kwargs
if isinstance(joint_attr, Mapping):
if isinstance(joint_attr, Mapping): # input mapping does not distinguish between x and y as it's a shared space
joint_attr = dict(joint_attr)
MUCDK marked this conversation as resolved.
Show resolved Hide resolved
if "attr" in joint_attr and joint_attr["attr"] == "X": # we have a point cloud
return {"x_attr": "X", "y_attr": "X"}, kwargs
if "attr" in joint_attr and joint_attr["attr"] == "obsm": # we have a point cloud
if "key" not in joint_attr:
MUCDK marked this conversation as resolved.
Show resolved Hide resolved
raise KeyError("`key` must be provided when `attr` is `obsm`.")
xy = {
"x_attr": "obsm",
"x_key": joint_attr["key"],
"y_attr": "obsm",
"y_key": joint_attr["key"],
}
return xy, kwargs
if joint_attr.get("tag", None) == "cost_matrix": # if this is True we have custom cost matrix or moscot cost
if len(joint_attr) == 2 or kwargs.get("attr", None) == "obsp": # in this case we have a custom cost matrix
joint_attr.setdefault("cost", "custom")
joint_attr.setdefault("attr", "obsp")
kwargs["xy_callback"] = "cost-matrix"
kwargs.setdefault("xy_callback_kwargs", {"key": joint_attr["key"]})
kwargs.setdefault("xy_callback_kwargs", {})
return joint_attr, kwargs
raise TypeError(f"Expected `joint_attr` to be either `str` or `dict`, found `{type(joint_attr)}`.")

Expand Down
31 changes: 18 additions & 13 deletions src/moscot/problems/base/_base_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,10 +418,10 @@ def pull(

@staticmethod
def _local_pca_callback(
term: Literal["x", "y", "xy"],
adata: AnnData,
adata_y: AnnData,
adata_y: Optional[AnnData] = None,
layer: Optional[str] = None,
return_linear: bool = True,
n_comps: int = 30,
scale: bool = False,
**kwargs: Any,
Expand All @@ -434,13 +434,19 @@ def concat(x: ArrayLike, y: ArrayLike) -> ArrayLike:
return np.vstack([x, y])

if layer is None:
x, y, msg = adata.X, adata_y.X, "adata.X"
x, y, msg = adata.X, adata_y.X if adata_y is not None else None, "adata.X"
else:
x, y, msg = adata.layers[layer], adata_y.layers[layer], f"adata.layers[{layer!r}]"
x, y, msg = (
adata.layers[layer],
adata_y.layers[layer] if adata_y is not None else None,
f"adata.layers[{layer!r}]",
)

scaler = StandardScaler() if scale else None

if return_linear:
if term == "xy":
if y is None:
raise ValueError("When `term` is `xy` `adata_y` cannot be `None`.")
n = x.shape[0]
data = concat(x, y)
if data.shape[1] <= n_comps:
Expand All @@ -452,14 +458,13 @@ def concat(x: ArrayLike, y: ArrayLike) -> ArrayLike:
if scaler is not None:
data = scaler.fit_transform(data)
return {"xy": TaggedArray(data[:n], data[n:], tag=Tag.POINT_CLOUD)}

logger.info(f"Computing pca with `n_comps={n_comps}` for `x` and `y` using `{msg}`")
x = sc.pp.pca(x, n_comps=n_comps, **kwargs)
y = sc.pp.pca(y, n_comps=n_comps, **kwargs)
if scaler is not None:
x = scaler.fit_transform(x)
y = scaler.fit_transform(y)
return {"x": TaggedArray(x, tag=Tag.POINT_CLOUD), "y": TaggedArray(y, tag=Tag.POINT_CLOUD)}
if term in ("x", "y"): # if we don't have a shared space, then adata_y is always None
logger.info(f"Computing pca with `n_comps={n_comps}` for `{term}` using `{msg}`")
x = sc.pp.pca(x, n_comps=n_comps, **kwargs)
MUCDK marked this conversation as resolved.
Show resolved Hide resolved
if scaler is not None:
x = scaler.fit_transform(x)
return {term: TaggedArray(x, tag=Tag.POINT_CLOUD)}
raise ValueError(f"Expected `term` to be one of `x`, `y`, or `xy`, found `{term!r}`.")

def _create_marginals(
self, adata: AnnData, *, source: bool, data: Optional[Union[bool, str, ArrayLike]] = None, **kwargs: Any
Expand Down
109 changes: 77 additions & 32 deletions src/moscot/problems/base/_compound_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@

K = TypeVar("K", bound=Hashable)
B = TypeVar("B", bound=OTProblem)
Callback_t = Callable[[AnnData, AnnData], Mapping[Literal["xy", "x", "y"], TaggedArray]]
Callback_t = Callable[
[Literal["xy", "x", "y"], AnnData, Optional[AnnData]], Mapping[Literal["xy", "x", "y"], TaggedArray]
]
ApplyOutput_t = Union[ArrayLike, Dict[K, ArrayLike]]
# TODO(michalk8): future behavior
# ApplyOutput_t = Union[ArrayLike, Dict[Tuple[K, K], ArrayLike]]
Expand Down Expand Up @@ -97,11 +99,12 @@ def _valid_policies(self) -> Tuple[str, ...]:
# TODO(michalk8): refactor me
def _callback_handler(
self,
src: K,
tgt: K,
term: Literal["x", "y", "xy"],
key_1: K,
key_2: K,
problem: B,
*,
callback: Union[Literal["local-pca"], Callback_t],
callback: Optional[Union[Literal["local-pca"], Callback_t]] = None,
**kwargs: Any,
) -> Mapping[Literal["xy", "x", "y"], TaggedArray]:
def verify_data(data: Mapping[Literal["xy", "x", "y"], TaggedArray]) -> None:
Expand All @@ -112,20 +115,26 @@ def verify_data(data: Mapping[Literal["xy", "x", "y"], TaggedArray]) -> None:
if not isinstance(val, TaggedArray):
raise TypeError(f"Expected value for `{key}` to be a `TaggedArray`, found `{type(val)}`.")

if callback is None:
return {}
if callback == "local-pca":
callback = problem._local_pca_callback

if not callable(callback):
raise TypeError("Callback is not a function.")
data = callback(problem.adata_src, problem.adata_tgt, **kwargs)
data = callback(term, problem.adata_src, problem.adata_tgt, **kwargs)
verify_data(data)
return data

# TODO(michalk8): refactor me
def _create_problems(
self,
callback: Optional[Union[Literal["local-pca"], Callback_t]] = None,
callback_kwargs: Mapping[str, Any] = MappingProxyType({}),
xy_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None,
x_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None,
y_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None,
xy_callback_kwargs: Mapping[str, Any] = MappingProxyType({}),
x_callback_kwargs: Mapping[str, Any] = MappingProxyType({}),
y_callback_kwargs: Mapping[str, Any] = MappingProxyType({}),
**kwargs: Any,
) -> Dict[Tuple[K, K], B]:
from moscot.problems.base._birth_death import BirthDeathProblem
Expand All @@ -143,11 +152,20 @@ def _create_problems(
tgt_name = tgt

problem = self._create_problem(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
if callback is not None:
data = self._callback_handler(src, tgt, problem, callback=callback, **callback_kwargs)
kws = {**kwargs, **data} # type: ignore[arg-type]
else:
kws = kwargs

xy_data = self._callback_handler(
term="xy", key_1=src, key_2=tgt, problem=problem, callback=xy_callback, **xy_callback_kwargs
)

x_data = self._callback_handler(
term="x", key_1=src, key_2=tgt, problem=problem, callback=x_callback, **x_callback_kwargs
)

y_data = self._callback_handler(
term="y", key_1=src, key_2=tgt, problem=problem, callback=y_callback, **y_callback_kwargs
)

kws = {**kwargs, **xy_data, **x_data, **y_data} # type: ignore[arg-type]

if isinstance(problem, BirthDeathProblem):
kws["proliferation_key"] = self.proliferation_key # type: ignore[attr-defined]
Expand All @@ -164,8 +182,12 @@ def prepare(
policy: Policy_t = "sequential",
subset: Optional[Sequence[Tuple[K, K]]] = None,
reference: Optional[Any] = None,
callback: Optional[Union[Literal["local-pca"], Callback_t]] = None,
callback_kwargs: Mapping[str, Any] = MappingProxyType({}),
xy_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None,
x_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None,
y_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None,
xy_callback_kwargs: Mapping[str, Any] = MappingProxyType({}),
x_callback_kwargs: Mapping[str, Any] = MappingProxyType({}),
y_callback_kwargs: Mapping[str, Any] = MappingProxyType({}),
**kwargs: Any,
) -> "BaseCompoundProblem[K,B]":
"""
Expand All @@ -177,8 +199,12 @@ def prepare(
%(policy)s
%(subset)s
%(reference)s
%(callback)s
%(callback_kwargs)s
%(xy_callback)s
%(x_callback)s
%(y_callback)s
%(xy_callback_kwargs)s
%(x_callback_kwargs)s
%(y_callback_kwargs)s
%(a)s
%(b)s

Expand All @@ -201,7 +227,15 @@ def prepare(
# TODO(michalk8): manager must be currently instantiated first, since `_create_problems` accesses the policy
# when refactoring the callback, consider changing this
self._problem_manager = ProblemManager(self, policy=policy)
problems = self._create_problems(callback=callback, callback_kwargs=callback_kwargs, **kwargs)
problems = self._create_problems(
xy_callback=xy_callback,
x_callback=x_callback,
y_callback=y_callback,
xy_callback_kwargs=xy_callback_kwargs,
x_callback_kwargs=x_callback_kwargs,
y_callback_kwargs=y_callback_kwargs,
**kwargs,
)
self._problem_manager.add_problems(problems)

for p in self.problems.values():
Expand Down Expand Up @@ -584,21 +618,24 @@ def _create_policy(

def _callback_handler(
self,
src: K,
tgt: K,
term: Literal["xy", "x", "y"],
key_1: K,
key_2: K,
problem: B,
*,
callback: Union[Literal["local-pca", "cost-matrix"], Callback_t],
callback: Optional[Union[Literal["local-pca", "cost-matrix"], Callback_t]] = None,
**kwargs: Any,
) -> Mapping[Literal["xy", "x", "y"], TaggedArray]:
# TODO(michalk8): better name?
if callback == "cost-matrix":
return self._cost_matrix_callback(src, tgt, **kwargs)

return super()._callback_handler(src, tgt, problem, callback=callback, **kwargs)
if callback == "cost-matrix":
return self._cost_matrix_callback(term=term, key_1=key_1, key_2=key_2, **kwargs)
return super()._callback_handler(
term=term, key_1=key_1, key_2=key_2, problem=problem, callback=callback, **kwargs
)

def _cost_matrix_callback(
self, src: K, tgt: K, *, key: str, return_linear: bool = True, **_: Any
self, term: Literal["xy", "x", "y"], *, key: str, key_1: K, key_2: Optional[K] = None, **_: Any
) -> Mapping[Literal["xy", "x", "y"], TaggedArray]:
if TYPE_CHECKING:
assert isinstance(self._policy, SubsetPolicy)
Expand All @@ -608,16 +645,24 @@ def _cost_matrix_callback(
except KeyError:
raise KeyError(f"Unable to fetch data from `adata.obsp[{key!r}]`.") from None

src_mask = self._policy.create_mask(src, allow_empty=False)
tgt_mask = self._policy.create_mask(tgt, allow_empty=False)
mask = self._policy.create_mask(key_1, allow_empty=False)

if return_linear:
linear_cost_matrix = data[src_mask, :][:, tgt_mask]
if term == "xy":
if key_2 is None:
raise ValueError("If `term` is `xy`, `key_2` cannot be `None`.")
mask_2 = self._policy.create_mask(key_2, allow_empty=False)

linear_cost_matrix = data[mask, :][:, mask_2]
if issparse(linear_cost_matrix):
logger.warning("Linear cost matrix being densified.")
linear_cost_matrix = linear_cost_matrix.A
MUCDK marked this conversation as resolved.
Show resolved Hide resolved
return {"xy": TaggedArray(linear_cost_matrix, tag=Tag.COST_MATRIX)}

return {
"x": TaggedArray(data[src_mask, :][:, src_mask], tag=Tag.COST_MATRIX),
"y": TaggedArray(data[tgt_mask, :][:, tgt_mask], tag=Tag.COST_MATRIX),
}
if term in ("x", "y"):
quad_cost_matrix = data[mask, :][:, mask]
if issparse(quad_cost_matrix):
logger.warning("Quadratic cost matrix being densified.")
quad_cost_matrix = quad_cost_matrix.A
return {term: TaggedArray(quad_cost_matrix, tag=Tag.COST_MATRIX)}

raise ValueError(f"Expected `term` to be one of `x`, `y`, or `xy`, found `{term!r}`.")
4 changes: 2 additions & 2 deletions src/moscot/problems/generic/_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ def prepare(
"""
self.batch_key = key
if joint_attr is None:
kwargs["callback"] = "local-pca"
kwargs["callback_kwargs"] = {**kwargs.get("callback_kwargs", {}), **{"return_linear": True}}
kwargs["xy_callback"] = "local-pca"
kwargs.setdefault("xy_callback_kwargs", {})
elif isinstance(joint_attr, str):
kwargs["xy"] = {
"x_attr": "obsm",
Expand Down
2 changes: 1 addition & 1 deletion src/moscot/problems/time/_lineage.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ def prepare(
lineage_attr.setdefault("attr", "obsp")
lineage_attr.setdefault("key", "cost_matrices")
lineage_attr.setdefault("cost", "custom")
lineage_attr.setdefault("tag", "cost")
lineage_attr.setdefault("tag", "cost_matrix")

x = y = lineage_attr

Expand Down
Loading