Skip to content

Commit

Permalink
adapt callbacks and rename tag cost to cost_matrix (#426)
Browse files Browse the repository at this point in the history
* rename tag cost to cost_matrix

* fix renaming

* [CI skip], adapt callback

* incorporate requested changes

* add test for quad custom callback

* adapt kwargs for callback

* fix handle_joint_attr

* incorporate requested changes
  • Loading branch information
MUCDK authored and lucaeyring committed Mar 15, 2023
1 parent c10c5b6 commit 74adb1e
Show file tree
Hide file tree
Showing 18 changed files with 244 additions and 120 deletions.
38 changes: 30 additions & 8 deletions src/moscot/_docs/_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,31 @@
reference
`reference` in :class:`moscot.problems._subset_policy.StarPolicy`.
"""
_callback = """\
callback
Custom callback applied to each distribution as pre-processing step. Examples are given in TODO Link Notebook.
_xy_callback = """\
xy_callback
Custom callback applied to the linear term as pre-processing step. Examples are given in TODO Link Notebook.
"""
_callback_kwargs = """\
callback_kwargs
Keyword arguments for `callback`.
_xy_callback_kwargs = """\
xy_callback_kwargs
Keyword arguments for `xy_callback`.
"""
_x_callback = """\
x_callback
Custom callback applied to the source distribution of the quadratic term as pre-processing step.
Examples are given in TODO Link Notebook.
"""
_x_callback_kwargs = """\
x_callback_kwargs
Keyword arguments for `x_callback`.
"""
_y_callback = """\
y_callback
Custom callback applied to the target distribution of the quadratic term as pre-processing step.
Examples are given in TODO Link Notebook.
"""
_y_callback_kwargs = """\
x_callback_kwargs
Keyword arguments for `y_callback`.
"""
_epsilon = """\
epsilon
Expand Down Expand Up @@ -362,8 +380,12 @@
source=_source,
target=_target,
reference=_reference,
callback=_callback,
callback_kwargs=_callback_kwargs,
xy_callback=_xy_callback,
xy_callback_kwargs=_xy_callback_kwargs,
x_callback=_x_callback,
x_callback_kwargs=_x_callback_kwargs,
y_callback=_y_callback,
y_callback_kwargs=_y_callback_kwargs,
epsilon=_epsilon,
alpha=_alpha,
tau_a=_tau_a,
Expand Down
34 changes: 26 additions & 8 deletions src/moscot/problems/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,12 @@ def wrap_solve(


def handle_joint_attr(
joint_attr: Optional[Union[str, Mapping[str, Any]]], kwargs: Any
) -> Tuple[Optional[Mapping[str, Any]], Dict[str, Any]]:
joint_attr: Optional[Union[str, Mapping[str, Any]]], kwargs: Dict[str, Any]
) -> Tuple[Optional[Dict[str, Any]], Dict[str, Any]]:
if joint_attr is None:
if "callback" not in kwargs:
kwargs["callback"] = "local-pca"
else:
kwargs["callback"] = kwargs["callback"]
kwargs["callback_kwargs"] = {**kwargs.get("callback_kwargs", {}), **{"return_linear": True}}
if "xy_callback" not in kwargs:
kwargs["xy_callback"] = "local-pca"
kwargs.setdefault("xy_callback_kwargs", {})
return None, kwargs
if isinstance(joint_attr, str):
xy = {
Expand All @@ -84,7 +82,27 @@ def handle_joint_attr(
"y_key": joint_attr,
}
return xy, kwargs
if isinstance(joint_attr, Mapping):
if isinstance(joint_attr, Mapping): # input mapping does not distinguish between x and y as it's a shared space
joint_attr = dict(joint_attr)
if "attr" in joint_attr and joint_attr["attr"] == "X": # we have a point cloud
return {"x_attr": "X", "y_attr": "X"}, kwargs
if "attr" in joint_attr and joint_attr["attr"] == "obsm": # we have a point cloud
if "key" not in joint_attr:
raise KeyError("`key` must be provided when `attr` is `obsm`.")
xy = {
"x_attr": "obsm",
"x_key": joint_attr["key"],
"y_attr": "obsm",
"y_key": joint_attr["key"],
}
return xy, kwargs
if joint_attr.get("tag", None) == "cost_matrix": # if this is True we have custom cost matrix or moscot cost
if len(joint_attr) == 2 or kwargs.get("attr", None) == "obsp": # in this case we have a custom cost matrix
joint_attr.setdefault("cost", "custom")
joint_attr.setdefault("attr", "obsp")
kwargs["xy_callback"] = "cost-matrix"
kwargs.setdefault("xy_callback_kwargs", {"key": joint_attr["key"]})
kwargs.setdefault("xy_callback_kwargs", {})
return joint_attr, kwargs
raise TypeError(f"Expected `joint_attr` to be either `str` or `dict`, found `{type(joint_attr)}`.")

Expand Down
31 changes: 18 additions & 13 deletions src/moscot/problems/base/_base_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,10 +420,10 @@ def pull(

@staticmethod
def _local_pca_callback(
term: Literal["x", "y", "xy"],
adata: AnnData,
adata_y: AnnData,
adata_y: Optional[AnnData] = None,
layer: Optional[str] = None,
return_linear: bool = True,
n_comps: int = 30,
scale: bool = False,
**kwargs: Any,
Expand All @@ -436,13 +436,19 @@ def concat(x: ArrayLike, y: ArrayLike) -> ArrayLike:
return np.vstack([x, y])

if layer is None:
x, y, msg = adata.X, adata_y.X, "adata.X"
x, y, msg = adata.X, adata_y.X if adata_y is not None else None, "adata.X"
else:
x, y, msg = adata.layers[layer], adata_y.layers[layer], f"adata.layers[{layer!r}]"
x, y, msg = (
adata.layers[layer],
adata_y.layers[layer] if adata_y is not None else None,
f"adata.layers[{layer!r}]",
)

scaler = StandardScaler() if scale else None

if return_linear:
if term == "xy":
if y is None:
raise ValueError("When `term` is `xy` `adata_y` cannot be `None`.")
n = x.shape[0]
data = concat(x, y)
if data.shape[1] <= n_comps:
Expand All @@ -454,14 +460,13 @@ def concat(x: ArrayLike, y: ArrayLike) -> ArrayLike:
if scaler is not None:
data = scaler.fit_transform(data)
return {"xy": TaggedArray(data[:n], data[n:], tag=Tag.POINT_CLOUD)}

logger.info(f"Computing pca with `n_comps={n_comps}` for `x` and `y` using `{msg}`")
x = sc.pp.pca(x, n_comps=n_comps, **kwargs)
y = sc.pp.pca(y, n_comps=n_comps, **kwargs)
if scaler is not None:
x = scaler.fit_transform(x)
y = scaler.fit_transform(y)
return {"x": TaggedArray(x, tag=Tag.POINT_CLOUD), "y": TaggedArray(y, tag=Tag.POINT_CLOUD)}
if term in ("x", "y"): # if we don't have a shared space, then adata_y is always None
logger.info(f"Computing pca with `n_comps={n_comps}` for `{term}` using `{msg}`")
x = sc.pp.pca(x, n_comps=n_comps, **kwargs)
if scaler is not None:
x = scaler.fit_transform(x)
return {term: TaggedArray(x, tag=Tag.POINT_CLOUD)}
raise ValueError(f"Expected `term` to be one of `x`, `y`, or `xy`, found `{term!r}`.")

def _create_marginals(
self, adata: AnnData, *, source: bool, data: Optional[Union[bool, str, ArrayLike]] = None, **kwargs: Any
Expand Down
109 changes: 77 additions & 32 deletions src/moscot/problems/base/_compound_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@

K = TypeVar("K", bound=Hashable)
B = TypeVar("B", bound=OTProblem)
Callback_t = Callable[[AnnData, AnnData], Mapping[Literal["xy", "x", "y"], TaggedArray]]
Callback_t = Callable[
[Literal["xy", "x", "y"], AnnData, Optional[AnnData]], Mapping[Literal["xy", "x", "y"], TaggedArray]
]
ApplyOutput_t = Union[ArrayLike, Dict[K, ArrayLike]]
# TODO(michalk8): future behavior
# ApplyOutput_t = Union[ArrayLike, Dict[Tuple[K, K], ArrayLike]]
Expand Down Expand Up @@ -97,11 +99,12 @@ def _valid_policies(self) -> Tuple[str, ...]:
# TODO(michalk8): refactor me
def _callback_handler(
self,
src: K,
tgt: K,
term: Literal["x", "y", "xy"],
key_1: K,
key_2: K,
problem: B,
*,
callback: Union[Literal["local-pca"], Callback_t],
callback: Optional[Union[Literal["local-pca"], Callback_t]] = None,
**kwargs: Any,
) -> Mapping[Literal["xy", "x", "y"], TaggedArray]:
def verify_data(data: Mapping[Literal["xy", "x", "y"], TaggedArray]) -> None:
Expand All @@ -112,20 +115,26 @@ def verify_data(data: Mapping[Literal["xy", "x", "y"], TaggedArray]) -> None:
if not isinstance(val, TaggedArray):
raise TypeError(f"Expected value for `{key}` to be a `TaggedArray`, found `{type(val)}`.")

if callback is None:
return {}
if callback == "local-pca":
callback = problem._local_pca_callback

if not callable(callback):
raise TypeError("Callback is not a function.")
data = callback(problem.adata_src, problem.adata_tgt, **kwargs)
data = callback(term, problem.adata_src, problem.adata_tgt, **kwargs)
verify_data(data)
return data

# TODO(michalk8): refactor me
def _create_problems(
self,
callback: Optional[Union[Literal["local-pca"], Callback_t]] = None,
callback_kwargs: Mapping[str, Any] = MappingProxyType({}),
xy_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None,
x_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None,
y_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None,
xy_callback_kwargs: Mapping[str, Any] = MappingProxyType({}),
x_callback_kwargs: Mapping[str, Any] = MappingProxyType({}),
y_callback_kwargs: Mapping[str, Any] = MappingProxyType({}),
**kwargs: Any,
) -> Dict[Tuple[K, K], B]:
from moscot.problems.base._birth_death import BirthDeathProblem
Expand All @@ -143,11 +152,20 @@ def _create_problems(
tgt_name = tgt

problem = self._create_problem(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
if callback is not None:
data = self._callback_handler(src, tgt, problem, callback=callback, **callback_kwargs)
kws = {**kwargs, **data} # type: ignore[arg-type]
else:
kws = kwargs

xy_data = self._callback_handler(
term="xy", key_1=src, key_2=tgt, problem=problem, callback=xy_callback, **xy_callback_kwargs
)

x_data = self._callback_handler(
term="x", key_1=src, key_2=tgt, problem=problem, callback=x_callback, **x_callback_kwargs
)

y_data = self._callback_handler(
term="y", key_1=src, key_2=tgt, problem=problem, callback=y_callback, **y_callback_kwargs
)

kws = {**kwargs, **xy_data, **x_data, **y_data} # type: ignore[arg-type]

if isinstance(problem, BirthDeathProblem):
kws["proliferation_key"] = self.proliferation_key # type: ignore[attr-defined]
Expand All @@ -164,8 +182,12 @@ def prepare(
policy: Policy_t = "sequential",
subset: Optional[Sequence[Tuple[K, K]]] = None,
reference: Optional[Any] = None,
callback: Optional[Union[Literal["local-pca"], Callback_t]] = None,
callback_kwargs: Mapping[str, Any] = MappingProxyType({}),
xy_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None,
x_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None,
y_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None,
xy_callback_kwargs: Mapping[str, Any] = MappingProxyType({}),
x_callback_kwargs: Mapping[str, Any] = MappingProxyType({}),
y_callback_kwargs: Mapping[str, Any] = MappingProxyType({}),
**kwargs: Any,
) -> "BaseCompoundProblem[K,B]":
"""
Expand All @@ -177,8 +199,12 @@ def prepare(
%(policy)s
%(subset)s
%(reference)s
%(callback)s
%(callback_kwargs)s
%(xy_callback)s
%(x_callback)s
%(y_callback)s
%(xy_callback_kwargs)s
%(x_callback_kwargs)s
%(y_callback_kwargs)s
%(a)s
%(b)s
Expand All @@ -201,7 +227,15 @@ def prepare(
# TODO(michalk8): manager must be currently instantiated first, since `_create_problems` accesses the policy
# when refactoring the callback, consider changing this
self._problem_manager = ProblemManager(self, policy=policy)
problems = self._create_problems(callback=callback, callback_kwargs=callback_kwargs, **kwargs)
problems = self._create_problems(
xy_callback=xy_callback,
x_callback=x_callback,
y_callback=y_callback,
xy_callback_kwargs=xy_callback_kwargs,
x_callback_kwargs=x_callback_kwargs,
y_callback_kwargs=y_callback_kwargs,
**kwargs,
)
self._problem_manager.add_problems(problems)

for p in self.problems.values():
Expand Down Expand Up @@ -584,21 +618,24 @@ def _create_policy(

def _callback_handler(
self,
src: K,
tgt: K,
term: Literal["xy", "x", "y"],
key_1: K,
key_2: K,
problem: B,
*,
callback: Union[Literal["local-pca", "cost-matrix"], Callback_t],
callback: Optional[Union[Literal["local-pca", "cost-matrix"], Callback_t]] = None,
**kwargs: Any,
) -> Mapping[Literal["xy", "x", "y"], TaggedArray]:
# TODO(michalk8): better name?
if callback == "cost-matrix":
return self._cost_matrix_callback(src, tgt, **kwargs)

return super()._callback_handler(src, tgt, problem, callback=callback, **kwargs)
if callback == "cost-matrix":
return self._cost_matrix_callback(term=term, key_1=key_1, key_2=key_2, **kwargs)
return super()._callback_handler(
term=term, key_1=key_1, key_2=key_2, problem=problem, callback=callback, **kwargs
)

def _cost_matrix_callback(
self, src: K, tgt: K, *, key: str, return_linear: bool = True, **_: Any
self, term: Literal["xy", "x", "y"], *, key: str, key_1: K, key_2: Optional[K] = None, **_: Any
) -> Mapping[Literal["xy", "x", "y"], TaggedArray]:
if TYPE_CHECKING:
assert isinstance(self._policy, SubsetPolicy)
Expand All @@ -608,16 +645,24 @@ def _cost_matrix_callback(
except KeyError:
raise KeyError(f"Unable to fetch data from `adata.obsp[{key!r}]`.") from None

src_mask = self._policy.create_mask(src, allow_empty=False)
tgt_mask = self._policy.create_mask(tgt, allow_empty=False)
mask = self._policy.create_mask(key_1, allow_empty=False)

if return_linear:
linear_cost_matrix = data[src_mask, :][:, tgt_mask]
if term == "xy":
if key_2 is None:
raise ValueError("If `term` is `xy`, `key_2` cannot be `None`.")
mask_2 = self._policy.create_mask(key_2, allow_empty=False)

linear_cost_matrix = data[mask, :][:, mask_2]
if issparse(linear_cost_matrix):
logger.warning("Linear cost matrix being densified.")
linear_cost_matrix = linear_cost_matrix.A
return {"xy": TaggedArray(linear_cost_matrix, tag=Tag.COST_MATRIX)}

return {
"x": TaggedArray(data[src_mask, :][:, src_mask], tag=Tag.COST_MATRIX),
"y": TaggedArray(data[tgt_mask, :][:, tgt_mask], tag=Tag.COST_MATRIX),
}
if term in ("x", "y"):
quad_cost_matrix = data[mask, :][:, mask_2]
if issparse(quad_cost_matrix):
logger.warning("Quadratic cost matrix being densified.")
quad_cost_matrix = quad_cost_matrix.A
return {term: TaggedArray(quad_cost_matrix, tag=Tag.COST_MATRIX)}

raise ValueError(f"Expected `term` to be one of `x`, `y`, or `xy`, found `{term!r}`.")
4 changes: 2 additions & 2 deletions src/moscot/problems/generic/_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ def prepare(
"""
self.batch_key = key
if joint_attr is None:
kwargs["callback"] = "local-pca"
kwargs["callback_kwargs"] = {**kwargs.get("callback_kwargs", {}), **{"return_linear": True}}
kwargs["xy_callback"] = "local-pca"
kwargs.setdefault("xy_callback_kwargs", {})
elif isinstance(joint_attr, str):
kwargs["xy"] = {
"x_attr": "obsm",
Expand Down
2 changes: 1 addition & 1 deletion src/moscot/problems/time/_lineage.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ def prepare(
lineage_attr.setdefault("attr", "obsp")
lineage_attr.setdefault("key", "cost_matrices")
lineage_attr.setdefault("cost", "custom")
lineage_attr.setdefault("tag", "cost")
lineage_attr.setdefault("tag", "cost_matrix")

x = y = lineage_attr

Expand Down
2 changes: 1 addition & 1 deletion src/moscot/solvers/_tagged_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def get_cost_function(cost: str, *, backend: Literal["ott"] = "ott", **kwargs: A
class Tag(ModeEnum):
"""Tag used to interpret array-like data in :class:`moscot.solvers.TaggedArray`."""

COST_MATRIX = "cost" #: Cost matrix.
COST_MATRIX = "cost_matrix" #: Cost matrix.
KERNEL = "kernel" #: Kernel matrix.
POINT_CLOUD = "point_cloud" #: Point cloud.

Expand Down
Loading

0 comments on commit 74adb1e

Please sign in to comment.