diff --git a/src/moscot/problems/_utils.py b/src/moscot/problems/_utils.py index 64ce1dc24..8417e1c9a 100644 --- a/src/moscot/problems/_utils.py +++ b/src/moscot/problems/_utils.py @@ -82,11 +82,11 @@ def handle_joint_attr( "y_key": joint_attr, } return xy, kwargs - if isinstance(joint_attr, Mapping): + if isinstance(joint_attr, Mapping): # input mapping does not distinguish between x and y as it's a shared space joint_attr = dict(joint_attr) - if "attr" in joint_attr and joint_attr["attr"] == "X": + if "attr" in joint_attr and joint_attr["attr"] == "X": # we have a point cloud return {"x_attr": "X", "y_attr": "X"}, kwargs - if "attr" in joint_attr and joint_attr["attr"] == "obsm": + if "attr" in joint_attr and joint_attr["attr"] == "obsm": # we have a point cloud if "key" not in joint_attr: raise KeyError("`key` must be provided when `attr` is `obsm`.") xy = { @@ -96,8 +96,8 @@ def handle_joint_attr( "y_key": joint_attr["key"], } return xy, kwargs - if joint_attr.get("tag", None) == "cost_matrix": - if len(joint_attr) == 2 or kwargs.get("attr", None) == "obsp": + if joint_attr.get("tag", None) == "cost_matrix": # if this is True we have custom cost matrix or moscot cost + if len(joint_attr) == 2 or kwargs.get("attr", None) == "obsp": # in this case we have a custom cost matrix joint_attr.setdefault("cost", "custom") joint_attr.setdefault("attr", "obsp") kwargs["xy_callback"] = "cost-matrix" diff --git a/src/moscot/problems/base/_base_problem.py b/src/moscot/problems/base/_base_problem.py index e58a682b3..cb77dcd22 100644 --- a/src/moscot/problems/base/_base_problem.py +++ b/src/moscot/problems/base/_base_problem.py @@ -458,7 +458,7 @@ def concat(x: ArrayLike, y: ArrayLike) -> ArrayLike: if scaler is not None: data = scaler.fit_transform(data) return {"xy": TaggedArray(data[:n], data[n:], tag=Tag.POINT_CLOUD)} - if term in ("x", "y"): + if term in ("x", "y"): # if we don't have a shared space, then adata_y is always None logger.info(f"Computing pca with `n_comps={n_comps}` for `{term}` using `{msg}`") x = sc.pp.pca(x, n_comps=n_comps, **kwargs) if scaler is not None: diff --git a/src/moscot/problems/base/_compound_problem.py b/src/moscot/problems/base/_compound_problem.py index d1ec3f5d4..8b8ca61ca 100644 --- a/src/moscot/problems/base/_compound_problem.py +++ b/src/moscot/problems/base/_compound_problem.py @@ -104,7 +104,7 @@ def _callback_handler( key_2: K, problem: B, *, - callback: Union[Literal["local-pca"], Callback_t], + callback: Optional[Union[Literal["local-pca"], Callback_t]] = None, **kwargs: Any, ) -> Mapping[Literal["xy", "x", "y"], TaggedArray]: def verify_data(data: Mapping[Literal["xy", "x", "y"], TaggedArray]) -> None: @@ -115,6 +115,8 @@ def verify_data(data: Mapping[Literal["xy", "x", "y"], TaggedArray]) -> None: if not isinstance(val, TaggedArray): raise TypeError(f"Expected value for `{key}` to be a `TaggedArray`, found `{type(val)}`.") + if callback is None: + return {} if callback == "local-pca": callback = problem._local_pca_callback @@ -150,24 +152,19 @@ def _create_problems( tgt_name = tgt problem = self._create_problem(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask) - if xy_callback is not None: - xy_data = self._callback_handler( - term="xy", key_1=src, key_2=tgt, problem=problem, callback=xy_callback, **xy_callback_kwargs - ) - else: - xy_data = {} - if x_callback is not None: - x_data = self._callback_handler( - term="x", key_1=src, key_2=tgt, problem=problem, callback=x_callback, **x_callback_kwargs - ) - else: - x_data = {} - if y_callback is not None: - y_data = self._callback_handler( - term="y", key_1=src, key_2=tgt, problem=problem, callback=y_callback, **y_callback_kwargs - ) - else: - y_data = {} + + xy_data = self._callback_handler( + term="xy", key_1=src, key_2=tgt, problem=problem, callback=xy_callback, **xy_callback_kwargs + ) + + x_data = self._callback_handler( + term="x", key_1=src, key_2=tgt, problem=problem, callback=x_callback, **x_callback_kwargs + ) + + y_data = self._callback_handler( + term="y", key_1=src, key_2=tgt, problem=problem, callback=y_callback, **y_callback_kwargs + ) + kws = {**kwargs, **xy_data, **x_data, **y_data} # type: ignore[arg-type] if isinstance(problem, BirthDeathProblem): @@ -626,7 +623,7 @@ def _callback_handler( key_2: K, problem: B, *, - callback: Union[Literal["local-pca", "cost-matrix"], Callback_t], + callback: Optional[Union[Literal["local-pca", "cost-matrix"], Callback_t]] = None, **kwargs: Any, ) -> Mapping[Literal["xy", "x", "y"], TaggedArray]: # TODO(michalk8): better name? @@ -657,10 +654,15 @@ def _cost_matrix_callback( linear_cost_matrix = data[mask, :][:, mask_2] if issparse(linear_cost_matrix): + logger.warning("Linear cost matrix being densified.") linear_cost_matrix = linear_cost_matrix.A return {"xy": TaggedArray(linear_cost_matrix, tag=Tag.COST_MATRIX)} if term in ("x", "y"): - return {term: TaggedArray(data[mask, :][:, mask], tag=Tag.COST_MATRIX)} + quad_cost_matrix = data[mask, :][:, mask_2] + if issparse(quad_cost_matrix): + logger.warning("Quadratic cost matrix being densified.") + quad_cost_matrix = quad_cost_matrix.A + return {term: TaggedArray(quad_cost_matrix, tag=Tag.COST_MATRIX)} raise ValueError(f"Expected `term` to be one of `x`, `y`, or `xy`, found `{term!r}`.")