diff --git a/src/moscot/_docs/_docs.py b/src/moscot/_docs/_docs.py index b92a06c7f..29727d604 100644 --- a/src/moscot/_docs/_docs.py +++ b/src/moscot/_docs/_docs.py @@ -30,13 +30,31 @@ reference `reference` in :class:`moscot.problems._subset_policy.StarPolicy`. """ -_callback = """\ -callback - Custom callback applied to each distribution as pre-processing step. Examples are given in TODO Link Notebook. +_xy_callback = """\ +xy_callback + Custom callback applied to the linear term as pre-processing step. Examples are given in TODO Link Notebook. """ -_callback_kwargs = """\ -callback_kwargs - Keyword arguments for `callback`. +_xy_callback_kwargs = """\ +xy_callback_kwargs + Keyword arguments for `xy_callback`. +""" +_x_callback = """\ +x_callback + Custom callback applied to the source distribution of the quadratic term as pre-processing step. + Examples are given in TODO Link Notebook. +""" +_x_callback_kwargs = """\ +x_callback_kwargs + Keyword arguments for `x_callback`. +""" +_y_callback = """\ +y_callback + Custom callback applied to the target distribution of the quadratic term as pre-processing step. + Examples are given in TODO Link Notebook. +""" +_y_callback_kwargs = """\ +x_callback_kwargs + Keyword arguments for `y_callback`. """ _epsilon = """\ epsilon @@ -362,8 +380,12 @@ source=_source, target=_target, reference=_reference, - callback=_callback, - callback_kwargs=_callback_kwargs, + xy_callback=_xy_callback, + xy_callback_kwargs=_xy_callback_kwargs, + x_callback=_x_callback, + x_callback_kwargs=_x_callback_kwargs, + y_callback=_y_callback, + y_callback_kwargs=_y_callback_kwargs, epsilon=_epsilon, alpha=_alpha, tau_a=_tau_a, diff --git a/src/moscot/problems/_utils.py b/src/moscot/problems/_utils.py index 6b6ebaaed..8417e1c9a 100644 --- a/src/moscot/problems/_utils.py +++ b/src/moscot/problems/_utils.py @@ -67,14 +67,12 @@ def wrap_solve( def handle_joint_attr( - joint_attr: Optional[Union[str, Mapping[str, Any]]], kwargs: Any -) -> Tuple[Optional[Mapping[str, Any]], Dict[str, Any]]: + joint_attr: Optional[Union[str, Mapping[str, Any]]], kwargs: Dict[str, Any] +) -> Tuple[Optional[Dict[str, Any]], Dict[str, Any]]: if joint_attr is None: - if "callback" not in kwargs: - kwargs["callback"] = "local-pca" - else: - kwargs["callback"] = kwargs["callback"] - kwargs["callback_kwargs"] = {**kwargs.get("callback_kwargs", {}), **{"return_linear": True}} + if "xy_callback" not in kwargs: + kwargs["xy_callback"] = "local-pca" + kwargs.setdefault("xy_callback_kwargs", {}) return None, kwargs if isinstance(joint_attr, str): xy = { @@ -84,7 +82,27 @@ def handle_joint_attr( "y_key": joint_attr, } return xy, kwargs - if isinstance(joint_attr, Mapping): + if isinstance(joint_attr, Mapping): # input mapping does not distinguish between x and y as it's a shared space + joint_attr = dict(joint_attr) + if "attr" in joint_attr and joint_attr["attr"] == "X": # we have a point cloud + return {"x_attr": "X", "y_attr": "X"}, kwargs + if "attr" in joint_attr and joint_attr["attr"] == "obsm": # we have a point cloud + if "key" not in joint_attr: + raise KeyError("`key` must be provided when `attr` is `obsm`.") + xy = { + "x_attr": "obsm", + "x_key": joint_attr["key"], + "y_attr": "obsm", + "y_key": joint_attr["key"], + } + return xy, kwargs + if joint_attr.get("tag", None) == "cost_matrix": # if this is True we have custom cost matrix or moscot cost + if len(joint_attr) == 2 or kwargs.get("attr", None) == "obsp": # in this case we have a custom cost matrix + joint_attr.setdefault("cost", "custom") + joint_attr.setdefault("attr", "obsp") + kwargs["xy_callback"] = "cost-matrix" + kwargs.setdefault("xy_callback_kwargs", {"key": joint_attr["key"]}) + kwargs.setdefault("xy_callback_kwargs", {}) return joint_attr, kwargs raise TypeError(f"Expected `joint_attr` to be either `str` or `dict`, found `{type(joint_attr)}`.") diff --git a/src/moscot/problems/base/_base_problem.py b/src/moscot/problems/base/_base_problem.py index 8f0d5ccee..a46d58d3a 100644 --- a/src/moscot/problems/base/_base_problem.py +++ b/src/moscot/problems/base/_base_problem.py @@ -420,10 +420,10 @@ def pull( @staticmethod def _local_pca_callback( + term: Literal["x", "y", "xy"], adata: AnnData, - adata_y: AnnData, + adata_y: Optional[AnnData] = None, layer: Optional[str] = None, - return_linear: bool = True, n_comps: int = 30, scale: bool = False, **kwargs: Any, @@ -436,13 +436,19 @@ def concat(x: ArrayLike, y: ArrayLike) -> ArrayLike: return np.vstack([x, y]) if layer is None: - x, y, msg = adata.X, adata_y.X, "adata.X" + x, y, msg = adata.X, adata_y.X if adata_y is not None else None, "adata.X" else: - x, y, msg = adata.layers[layer], adata_y.layers[layer], f"adata.layers[{layer!r}]" + x, y, msg = ( + adata.layers[layer], + adata_y.layers[layer] if adata_y is not None else None, + f"adata.layers[{layer!r}]", + ) scaler = StandardScaler() if scale else None - if return_linear: + if term == "xy": + if y is None: + raise ValueError("When `term` is `xy` `adata_y` cannot be `None`.") n = x.shape[0] data = concat(x, y) if data.shape[1] <= n_comps: @@ -454,14 +460,13 @@ def concat(x: ArrayLike, y: ArrayLike) -> ArrayLike: if scaler is not None: data = scaler.fit_transform(data) return {"xy": TaggedArray(data[:n], data[n:], tag=Tag.POINT_CLOUD)} - - logger.info(f"Computing pca with `n_comps={n_comps}` for `x` and `y` using `{msg}`") - x = sc.pp.pca(x, n_comps=n_comps, **kwargs) - y = sc.pp.pca(y, n_comps=n_comps, **kwargs) - if scaler is not None: - x = scaler.fit_transform(x) - y = scaler.fit_transform(y) - return {"x": TaggedArray(x, tag=Tag.POINT_CLOUD), "y": TaggedArray(y, tag=Tag.POINT_CLOUD)} + if term in ("x", "y"): # if we don't have a shared space, then adata_y is always None + logger.info(f"Computing pca with `n_comps={n_comps}` for `{term}` using `{msg}`") + x = sc.pp.pca(x, n_comps=n_comps, **kwargs) + if scaler is not None: + x = scaler.fit_transform(x) + return {term: TaggedArray(x, tag=Tag.POINT_CLOUD)} + raise ValueError(f"Expected `term` to be one of `x`, `y`, or `xy`, found `{term!r}`.") def _create_marginals( self, adata: AnnData, *, source: bool, data: Optional[Union[bool, str, ArrayLike]] = None, **kwargs: Any diff --git a/src/moscot/problems/base/_compound_problem.py b/src/moscot/problems/base/_compound_problem.py index dc11d4ab1..8b8ca61ca 100644 --- a/src/moscot/problems/base/_compound_problem.py +++ b/src/moscot/problems/base/_compound_problem.py @@ -48,7 +48,9 @@ K = TypeVar("K", bound=Hashable) B = TypeVar("B", bound=OTProblem) -Callback_t = Callable[[AnnData, AnnData], Mapping[Literal["xy", "x", "y"], TaggedArray]] +Callback_t = Callable[ + [Literal["xy", "x", "y"], AnnData, Optional[AnnData]], Mapping[Literal["xy", "x", "y"], TaggedArray] +] ApplyOutput_t = Union[ArrayLike, Dict[K, ArrayLike]] # TODO(michalk8): future behavior # ApplyOutput_t = Union[ArrayLike, Dict[Tuple[K, K], ArrayLike]] @@ -97,11 +99,12 @@ def _valid_policies(self) -> Tuple[str, ...]: # TODO(michalk8): refactor me def _callback_handler( self, - src: K, - tgt: K, + term: Literal["x", "y", "xy"], + key_1: K, + key_2: K, problem: B, *, - callback: Union[Literal["local-pca"], Callback_t], + callback: Optional[Union[Literal["local-pca"], Callback_t]] = None, **kwargs: Any, ) -> Mapping[Literal["xy", "x", "y"], TaggedArray]: def verify_data(data: Mapping[Literal["xy", "x", "y"], TaggedArray]) -> None: @@ -112,20 +115,26 @@ def verify_data(data: Mapping[Literal["xy", "x", "y"], TaggedArray]) -> None: if not isinstance(val, TaggedArray): raise TypeError(f"Expected value for `{key}` to be a `TaggedArray`, found `{type(val)}`.") + if callback is None: + return {} if callback == "local-pca": callback = problem._local_pca_callback if not callable(callback): raise TypeError("Callback is not a function.") - data = callback(problem.adata_src, problem.adata_tgt, **kwargs) + data = callback(term, problem.adata_src, problem.adata_tgt, **kwargs) verify_data(data) return data # TODO(michalk8): refactor me def _create_problems( self, - callback: Optional[Union[Literal["local-pca"], Callback_t]] = None, - callback_kwargs: Mapping[str, Any] = MappingProxyType({}), + xy_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None, + x_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None, + y_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None, + xy_callback_kwargs: Mapping[str, Any] = MappingProxyType({}), + x_callback_kwargs: Mapping[str, Any] = MappingProxyType({}), + y_callback_kwargs: Mapping[str, Any] = MappingProxyType({}), **kwargs: Any, ) -> Dict[Tuple[K, K], B]: from moscot.problems.base._birth_death import BirthDeathProblem @@ -143,11 +152,20 @@ def _create_problems( tgt_name = tgt problem = self._create_problem(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask) - if callback is not None: - data = self._callback_handler(src, tgt, problem, callback=callback, **callback_kwargs) - kws = {**kwargs, **data} # type: ignore[arg-type] - else: - kws = kwargs + + xy_data = self._callback_handler( + term="xy", key_1=src, key_2=tgt, problem=problem, callback=xy_callback, **xy_callback_kwargs + ) + + x_data = self._callback_handler( + term="x", key_1=src, key_2=tgt, problem=problem, callback=x_callback, **x_callback_kwargs + ) + + y_data = self._callback_handler( + term="y", key_1=src, key_2=tgt, problem=problem, callback=y_callback, **y_callback_kwargs + ) + + kws = {**kwargs, **xy_data, **x_data, **y_data} # type: ignore[arg-type] if isinstance(problem, BirthDeathProblem): kws["proliferation_key"] = self.proliferation_key # type: ignore[attr-defined] @@ -164,8 +182,12 @@ def prepare( policy: Policy_t = "sequential", subset: Optional[Sequence[Tuple[K, K]]] = None, reference: Optional[Any] = None, - callback: Optional[Union[Literal["local-pca"], Callback_t]] = None, - callback_kwargs: Mapping[str, Any] = MappingProxyType({}), + xy_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None, + x_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None, + y_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None, + xy_callback_kwargs: Mapping[str, Any] = MappingProxyType({}), + x_callback_kwargs: Mapping[str, Any] = MappingProxyType({}), + y_callback_kwargs: Mapping[str, Any] = MappingProxyType({}), **kwargs: Any, ) -> "BaseCompoundProblem[K,B]": """ @@ -177,8 +199,12 @@ def prepare( %(policy)s %(subset)s %(reference)s - %(callback)s - %(callback_kwargs)s + %(xy_callback)s + %(x_callback)s + %(y_callback)s + %(xy_callback_kwargs)s + %(x_callback_kwargs)s + %(y_callback_kwargs)s %(a)s %(b)s @@ -201,7 +227,15 @@ def prepare( # TODO(michalk8): manager must be currently instantiated first, since `_create_problems` accesses the policy # when refactoring the callback, consider changing this self._problem_manager = ProblemManager(self, policy=policy) - problems = self._create_problems(callback=callback, callback_kwargs=callback_kwargs, **kwargs) + problems = self._create_problems( + xy_callback=xy_callback, + x_callback=x_callback, + y_callback=y_callback, + xy_callback_kwargs=xy_callback_kwargs, + x_callback_kwargs=x_callback_kwargs, + y_callback_kwargs=y_callback_kwargs, + **kwargs, + ) self._problem_manager.add_problems(problems) for p in self.problems.values(): @@ -584,21 +618,24 @@ def _create_policy( def _callback_handler( self, - src: K, - tgt: K, + term: Literal["xy", "x", "y"], + key_1: K, + key_2: K, problem: B, *, - callback: Union[Literal["local-pca", "cost-matrix"], Callback_t], + callback: Optional[Union[Literal["local-pca", "cost-matrix"], Callback_t]] = None, **kwargs: Any, ) -> Mapping[Literal["xy", "x", "y"], TaggedArray]: # TODO(michalk8): better name? - if callback == "cost-matrix": - return self._cost_matrix_callback(src, tgt, **kwargs) - return super()._callback_handler(src, tgt, problem, callback=callback, **kwargs) + if callback == "cost-matrix": + return self._cost_matrix_callback(term=term, key_1=key_1, key_2=key_2, **kwargs) + return super()._callback_handler( + term=term, key_1=key_1, key_2=key_2, problem=problem, callback=callback, **kwargs + ) def _cost_matrix_callback( - self, src: K, tgt: K, *, key: str, return_linear: bool = True, **_: Any + self, term: Literal["xy", "x", "y"], *, key: str, key_1: K, key_2: Optional[K] = None, **_: Any ) -> Mapping[Literal["xy", "x", "y"], TaggedArray]: if TYPE_CHECKING: assert isinstance(self._policy, SubsetPolicy) @@ -608,16 +645,24 @@ def _cost_matrix_callback( except KeyError: raise KeyError(f"Unable to fetch data from `adata.obsp[{key!r}]`.") from None - src_mask = self._policy.create_mask(src, allow_empty=False) - tgt_mask = self._policy.create_mask(tgt, allow_empty=False) + mask = self._policy.create_mask(key_1, allow_empty=False) - if return_linear: - linear_cost_matrix = data[src_mask, :][:, tgt_mask] + if term == "xy": + if key_2 is None: + raise ValueError("If `term` is `xy`, `key_2` cannot be `None`.") + mask_2 = self._policy.create_mask(key_2, allow_empty=False) + + linear_cost_matrix = data[mask, :][:, mask_2] if issparse(linear_cost_matrix): + logger.warning("Linear cost matrix being densified.") linear_cost_matrix = linear_cost_matrix.A return {"xy": TaggedArray(linear_cost_matrix, tag=Tag.COST_MATRIX)} - return { - "x": TaggedArray(data[src_mask, :][:, src_mask], tag=Tag.COST_MATRIX), - "y": TaggedArray(data[tgt_mask, :][:, tgt_mask], tag=Tag.COST_MATRIX), - } + if term in ("x", "y"): + quad_cost_matrix = data[mask, :][:, mask_2] + if issparse(quad_cost_matrix): + logger.warning("Quadratic cost matrix being densified.") + quad_cost_matrix = quad_cost_matrix.A + return {term: TaggedArray(quad_cost_matrix, tag=Tag.COST_MATRIX)} + + raise ValueError(f"Expected `term` to be one of `x`, `y`, or `xy`, found `{term!r}`.") diff --git a/src/moscot/problems/generic/_generic.py b/src/moscot/problems/generic/_generic.py index 3f0035d3d..7242b3eac 100644 --- a/src/moscot/problems/generic/_generic.py +++ b/src/moscot/problems/generic/_generic.py @@ -70,8 +70,8 @@ def prepare( """ self.batch_key = key if joint_attr is None: - kwargs["callback"] = "local-pca" - kwargs["callback_kwargs"] = {**kwargs.get("callback_kwargs", {}), **{"return_linear": True}} + kwargs["xy_callback"] = "local-pca" + kwargs.setdefault("xy_callback_kwargs", {}) elif isinstance(joint_attr, str): kwargs["xy"] = { "x_attr": "obsm", diff --git a/src/moscot/problems/time/_lineage.py b/src/moscot/problems/time/_lineage.py index 2a10cded5..c630605c8 100644 --- a/src/moscot/problems/time/_lineage.py +++ b/src/moscot/problems/time/_lineage.py @@ -350,7 +350,7 @@ def prepare( lineage_attr.setdefault("attr", "obsp") lineage_attr.setdefault("key", "cost_matrices") lineage_attr.setdefault("cost", "custom") - lineage_attr.setdefault("tag", "cost") + lineage_attr.setdefault("tag", "cost_matrix") x = y = lineage_attr diff --git a/src/moscot/solvers/_tagged_array.py b/src/moscot/solvers/_tagged_array.py index 5a3651ec4..1df0ca5fc 100644 --- a/src/moscot/solvers/_tagged_array.py +++ b/src/moscot/solvers/_tagged_array.py @@ -27,7 +27,7 @@ def get_cost_function(cost: str, *, backend: Literal["ott"] = "ott", **kwargs: A class Tag(ModeEnum): """Tag used to interpret array-like data in :class:`moscot.solvers.TaggedArray`.""" - COST_MATRIX = "cost" #: Cost matrix. + COST_MATRIX = "cost_matrix" #: Cost matrix. KERNEL = "kernel" #: Kernel matrix. POINT_CLOUD = "point_cloud" #: Point cloud. diff --git a/tests/analysis_mixins/test_base_analysis.py b/tests/analysis_mixins/test_base_analysis.py index e03d036dc..306c6fb14 100644 --- a/tests/analysis_mixins/test_base_analysis.py +++ b/tests/analysis_mixins/test_base_analysis.py @@ -25,7 +25,7 @@ def test_sample_from_tmap_pipeline( source_dim = len(gt_temporal_adata[gt_temporal_adata.obs["day"] == 10]) target_dim = len(gt_temporal_adata[gt_temporal_adata.obs["day"] == 10.5]) problem = CompoundProblemWithMixin(gt_temporal_adata) - problem = problem.prepare("day", subset=[(10, 10.5)], policy="sequential", callback="local-pca") + problem = problem.prepare("day", subset=[(10, 10.5)], policy="sequential", xy_callback="local-pca") problem[10, 10.5]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_10_105"]) if interpolation_parameter is not None and not 0 <= interpolation_parameter <= 1: @@ -71,7 +71,7 @@ def test_sample_from_tmap_pipeline( def test_interpolate_transport(self, gt_temporal_adata: AnnData, forward: bool, scale_by_marginals: bool): problem = CompoundProblemWithMixin(gt_temporal_adata) problem = problem.prepare( - "day", subset=[(10, 10.5), (10.5, 11), (10, 11)], policy="explicit", callback="local-pca" + "day", subset=[(10, 10.5), (10.5, 11), (10, 11)], policy="explicit", xy_callback="local-pca" ) problem[(10.0, 10.5)]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_10_105"]) problem[(10.5, 11.0)]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_105_11"]) @@ -92,7 +92,7 @@ def test_cell_transition_aggregation_cell_forward(self, gt_temporal_adata: AnnDa key_2 = config["key_2"] config["key_3"] problem = CompoundProblemWithMixin(gt_temporal_adata) - problem = problem.prepare("day", subset=[(10, 10.5)], policy="explicit", callback="local-pca") + problem = problem.prepare("day", subset=[(10, 10.5)], policy="explicit", xy_callback="local-pca") assert set(problem.problems.keys()) == {(key_1, key_2)} problem[key_1, key_2]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_10_105"]) @@ -133,7 +133,7 @@ def test_cell_transition_aggregation_cell_backward(self, gt_temporal_adata: AnnD key_1 = config["key_1"] key_2 = config["key_2"] problem = CompoundProblemWithMixin(gt_temporal_adata) - problem = problem.prepare("day", subset=[(10, 10.5)], policy="explicit", callback="local-pca") + problem = problem.prepare("day", subset=[(10, 10.5)], policy="explicit", xy_callback="local-pca") assert set(problem.problems.keys()) == {(key_1, key_2)} problem[key_1, key_2]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_10_105"]) diff --git a/tests/problems/base/test_compound_problem.py b/tests/problems/base/test_compound_problem.py index 5c5820975..b11eb7d3a 100644 --- a/tests/problems/base/test_compound_problem.py +++ b/tests/problems/base/test_compound_problem.py @@ -1,4 +1,4 @@ -from typing import Any, Tuple, Literal, Mapping +from typing import Any, Tuple, Literal, Mapping, Optional import os from pytest_mock import MockerFixture @@ -19,12 +19,28 @@ class TestCompoundProblem: @staticmethod - def callback( - adata: AnnData, adata_y: AnnData, sentinel: bool = False + def xy_callback( + term: Literal["x", "y", "xy"], adata: AnnData, adata_y: Optional[AnnData] = None, sentinel: bool = False ) -> Mapping[Literal["xy", "x", "y"], TaggedArray]: assert sentinel assert isinstance(adata_y, AnnData) - return {"xy": TaggedArray(euclidean_distances(adata.X, adata_y.X), tag=Tag.COST_MATRIX)} + return {term: TaggedArray(euclidean_distances(adata.X, adata_y.X), tag=Tag.COST_MATRIX)} + + @staticmethod + def x_callback( + term: Literal["x", "y", "xy"], adata: AnnData, adata_y: Optional[AnnData] = None, sentinel: bool = False + ) -> Mapping[Literal["xy", "x", "y"], TaggedArray]: + assert sentinel + assert isinstance(adata_y, AnnData) + return {term: TaggedArray(euclidean_distances(adata.X, adata_y.X), tag=Tag.COST_MATRIX)} + + @staticmethod + def y_callback( + term: Literal["x", "y", "xy"], adata: AnnData, adata_y: Optional[AnnData] = None, sentinel: bool = False + ) -> Mapping[Literal["xy", "x", "y"], TaggedArray]: + assert sentinel + assert isinstance(adata_y, AnnData) + return {term: TaggedArray(euclidean_distances(adata.X, adata_y.X), tag=Tag.COST_MATRIX)} def test_sc_pipeline(self, adata_time: AnnData): expected_keys = [(0, 1), (1, 2)] @@ -55,7 +71,7 @@ def test_sc_pipeline(self, adata_time: AnnData): @pytest.mark.fast() def test_default_callback(self, adata_time: AnnData, mocker: MockerFixture, scale: bool): subproblem = OTProblem(adata_time, adata_tgt=adata_time.copy()) - callback_kwargs = {"n_comps": 5, "scale": scale} + xy_callback_kwargs = {"n_comps": 5, "scale": scale} spy = mocker.spy(subproblem, "_local_pca_callback") problem = Problem(adata_time) @@ -65,32 +81,54 @@ def test_default_callback(self, adata_time: AnnData, mocker: MockerFixture, scal xy={"x_attr": "X", "y_attr": "X"}, key="time", policy="sequential", - callback="local-pca", - callback_kwargs=callback_kwargs, + xy_callback="local-pca", + xy_callback_kwargs=xy_callback_kwargs, ) assert isinstance(problem, CompoundProblem) assert isinstance(problem.problems, dict) - spy.assert_called_with(subproblem.adata_src, subproblem.adata_tgt, **callback_kwargs) + spy.assert_called_with("xy", subproblem.adata_src, subproblem.adata_tgt, **xy_callback_kwargs) @pytest.mark.fast() - def test_custom_callback(self, adata_time: AnnData, mocker: MockerFixture): + def test_custom_callback_lin(self, adata_time: AnnData, mocker: MockerFixture): expected_keys = [(0, 1), (1, 2)] - spy = mocker.spy(TestCompoundProblem, "callback") + spy = mocker.spy(TestCompoundProblem, "xy_callback") problem = Problem(adata=adata_time) _ = problem.prepare( - xy={"x_attr": "X", "y_attr": "X"}, + xy=None, x={"attr": "X"}, y={"attr": "X"}, key="time", policy="sequential", - callback=TestCompoundProblem.callback, - callback_kwargs={"sentinel": True}, + xy_callback=TestCompoundProblem.xy_callback, + xy_callback_kwargs={"sentinel": True}, ) assert spy.call_count == len(expected_keys) + @pytest.mark.fast() + def test_custom_callback_quad(self, adata_time: AnnData, mocker: MockerFixture): + expected_keys = [(0, 1), (1, 2)] + spy_x = mocker.spy(TestCompoundProblem, "x_callback") + spy_y = mocker.spy(TestCompoundProblem, "y_callback") + + problem = Problem(adata=adata_time) + _ = problem.prepare( + xy=None, + x={"attr": "X"}, + y={"attr": "X"}, + key="time", + policy="sequential", + x_callback=TestCompoundProblem.x_callback, + y_callback=TestCompoundProblem.y_callback, + x_callback_kwargs={"sentinel": True}, + y_callback_kwargs={"sentinel": True}, + ) + + assert spy_x.call_count == len(expected_keys) + assert spy_y.call_count == len(expected_keys) + def test_different_passings_linear(self, adata_with_cost_matrix: AnnData): epsilon = 5 xy = {"x_attr": "obsm", "x_key": "X_pca", "y_attr": "obsm", "y_key": "X_pca"} @@ -100,7 +138,7 @@ def test_different_passings_linear(self, adata_with_cost_matrix: AnnData): p1_tmap = p1[0, 1].solution.transport_matrix p2 = Problem(adata_with_cost_matrix) - p2 = p2.prepare(key="batch", xy={"attr": "uns", "key": 0, "cost": "custom", "tag": "cost"}) + p2 = p2.prepare(key="batch", xy={"attr": "uns", "key": 0, "cost": "custom", "tag": "cost_matrix"}) p2 = p2.solve(epsilon=epsilon) p2_tmap = p2[0, 1].solution.transport_matrix diff --git a/tests/problems/base/test_general_problem.py b/tests/problems/base/test_general_problem.py index a4eb4113f..6276d6e35 100644 --- a/tests/problems/base/test_general_problem.py +++ b/tests/problems/base/test_general_problem.py @@ -47,8 +47,8 @@ def test_passing_scale(self, adata_x: AnnData, scale_cost: str): np.testing.assert_allclose(gt.matrix, sol.transport_matrix, rtol=RTOL, atol=ATOL) - @pytest.mark.parametrize("tag", ["cost", "kernel"]) - def test_set_xy(self, adata_x: AnnData, adata_y: AnnData, tag: Literal["cost", "kernel"]): + @pytest.mark.parametrize("tag", ["cost_matrix", "kernel"]) + def test_set_xy(self, adata_x: AnnData, adata_y: AnnData, tag: Literal["cost_matrix", "kernel"]): rng = np.random.RandomState(42) prob = OTProblem(adata_x, adata_y) prob = prob.prepare( @@ -66,8 +66,8 @@ def test_set_xy(self, adata_x: AnnData, adata_y: AnnData, tag: Literal["cost", " prob = prob.solve(max_iterations=5) # TODO(@MUCDK) once fixed in OTT-JAX test for scale_cost np.testing.assert_equal(prob.xy.data_src, cost_matrix.to_numpy()) - @pytest.mark.parametrize("tag", ["cost", "kernel"]) - def test_set_x(self, adata_x: AnnData, adata_y: AnnData, tag: Literal["cost", "kernel"]): + @pytest.mark.parametrize("tag", ["cost_matrix", "kernel"]) + def test_set_x(self, adata_x: AnnData, adata_y: AnnData, tag: Literal["cost_matrix", "kernel"]): rng = np.random.RandomState(42) prob = OTProblem(adata_x, adata_y) prob = prob.prepare( @@ -85,8 +85,8 @@ def test_set_x(self, adata_x: AnnData, adata_y: AnnData, tag: Literal["cost", "k prob = prob.solve(max_iterations=5) # TODO(@MUCDK) once fixed in OTT-JAX test for scale_cost np.testing.assert_equal(prob.x.data_src, cost_matrix.to_numpy()) - @pytest.mark.parametrize("tag", ["cost", "kernel"]) - def test_set_y(self, adata_x: AnnData, adata_y: AnnData, tag: Literal["cost", "kernel"]): + @pytest.mark.parametrize("tag", ["cost_matrix", "kernel"]) + def test_set_y(self, adata_x: AnnData, adata_y: AnnData, tag: Literal["cost_matrix", "kernel"]): rng = np.random.RandomState(42) prob = OTProblem(adata_x, adata_y) prob = prob.prepare( @@ -115,7 +115,7 @@ def test_set_xy_change_problem_kind(self, adata_x: AnnData, adata_y: AnnData): cm = rng.uniform(1, 10, size=(adata_x.n_obs, adata_y.n_obs)) cost_matrix = pd.DataFrame(index=adata_x.obs_names, columns=adata_y.obs_names, data=cm) - prob.set_xy(cost_matrix, tag="cost") + prob.set_xy(cost_matrix, tag="cost_matrix") assert prob.problem_kind == ProblemKind.QUAD_FUSED @@ -129,11 +129,11 @@ def test_set_x_change_problem_kind(self, adata_x: AnnData, adata_y: AnnData): cm = rng.uniform(1, 10, size=(adata_x.n_obs, adata_x.n_obs)) cost_matrix = pd.DataFrame(index=adata_x.obs_names, columns=adata_x.obs_names, data=cm) - prob.set_x(cost_matrix, tag="cost") + prob.set_x(cost_matrix, tag="cost_matrix") cm = rng.uniform(1, 10, size=(adata_y.n_obs, adata_y.n_obs)) cost_matrix = pd.DataFrame(index=adata_y.obs_names, columns=adata_y.obs_names, data=cm) - prob.set_y(cost_matrix, tag="cost") + prob.set_y(cost_matrix, tag="cost_matrix") assert prob.problem_kind == ProblemKind.QUAD_FUSED diff --git a/tests/problems/generic/test_fgw_problem.py b/tests/problems/generic/test_fgw_problem.py index ba4b74afd..c8b000728 100644 --- a/tests/problems/generic/test_fgw_problem.py +++ b/tests/problems/generic/test_fgw_problem.py @@ -131,8 +131,8 @@ def test_prepare_costs(self, adata_time: AnnData, cost: Tuple[str, Any]): assert isinstance(problem[0, 1].x.cost, cost[1]) assert isinstance(problem[0, 1].y.cost, cost[1]) - @pytest.mark.parametrize("tag", ["cost", "kernel"]) - def test_set_xy(self, adata_time: AnnData, tag: Literal["cost", "kernel"]): + @pytest.mark.parametrize("tag", ["cost_matrix", "kernel"]) + def test_set_xy(self, adata_time: AnnData, tag: Literal["cost_matrix", "kernel"]): rng = np.random.RandomState(42) adata_time = adata_time[adata_time.obs["time"].isin((0, 1))].copy() problem = FGWProblem(adata=adata_time) @@ -159,8 +159,8 @@ def test_set_xy(self, adata_time: AnnData, tag: Literal["cost", "kernel"]): assert isinstance(problem[0, 1].xy.data_src, np.ndarray) assert problem[0, 1].xy.data_tgt is None - @pytest.mark.parametrize("tag", ["cost", "kernel"]) - def test_set_x(self, adata_time: AnnData, tag: Literal["cost", "kernel"]): + @pytest.mark.parametrize("tag", ["cost_matrix", "kernel"]) + def test_set_x(self, adata_time: AnnData, tag: Literal["cost_matrix", "kernel"]): rng = np.random.RandomState(42) adata_time = adata_time[adata_time.obs["time"].isin((0, 1))].copy() problem = FGWProblem(adata=adata_time) @@ -186,8 +186,8 @@ def test_set_x(self, adata_time: AnnData, tag: Literal["cost", "kernel"]): assert isinstance(problem[0, 1].x.data_src, np.ndarray) assert problem[0, 1].x.data_tgt is None - @pytest.mark.parametrize("tag", ["cost", "kernel"]) - def test_set_y(self, adata_time: AnnData, tag: Literal["cost", "kernel"]): + @pytest.mark.parametrize("tag", ["cost_matrix", "kernel"]) + def test_set_y(self, adata_time: AnnData, tag: Literal["cost_matrix", "kernel"]): rng = np.random.RandomState(42) adata_time = adata_time[adata_time.obs["time"].isin((0, 1))].copy() problem = FGWProblem(adata=adata_time) diff --git a/tests/problems/generic/test_gw_problem.py b/tests/problems/generic/test_gw_problem.py index 3d74ba780..f0b970427 100644 --- a/tests/problems/generic/test_gw_problem.py +++ b/tests/problems/generic/test_gw_problem.py @@ -116,8 +116,8 @@ def test_prepare_costs(self, adata_time: AnnData, cost: Tuple[str, Any]): assert isinstance(problem[0, 1].x.cost, cost[1]) assert isinstance(problem[0, 1].y.cost, cost[1]) - @pytest.mark.parametrize("tag", ["cost", "kernel"]) - def test_set_x(self, adata_time: AnnData, tag: Literal["cost", "kernel"]): + @pytest.mark.parametrize("tag", ["cost_matrix", "kernel"]) + def test_set_x(self, adata_time: AnnData, tag: Literal["cost_matrix", "kernel"]): rng = np.random.RandomState(42) adata_time = adata_time[adata_time.obs["time"].isin((0, 1))].copy() problem = GWProblem(adata=adata_time) @@ -142,8 +142,8 @@ def test_set_x(self, adata_time: AnnData, tag: Literal["cost", "kernel"]): assert isinstance(problem[0, 1].x.data_src, np.ndarray) assert problem[0, 1].x.data_tgt is None - @pytest.mark.parametrize("tag", ["cost", "kernel"]) - def test_set_y(self, adata_time: AnnData, tag: Literal["cost", "kernel"]): + @pytest.mark.parametrize("tag", ["cost_matrix", "kernel"]) + def test_set_y(self, adata_time: AnnData, tag: Literal["cost_matrix", "kernel"]): rng = np.random.RandomState(42) adata_time = adata_time[adata_time.obs["time"].isin((0, 1))].copy() problem = GWProblem(adata=adata_time) diff --git a/tests/problems/generic/test_sinkhorn_problem.py b/tests/problems/generic/test_sinkhorn_problem.py index 1328eb0f8..1e680ca25 100644 --- a/tests/problems/generic/test_sinkhorn_problem.py +++ b/tests/problems/generic/test_sinkhorn_problem.py @@ -55,8 +55,8 @@ def test_solve_balanced(self, adata_time: AnnData): # type: ignore[no-untyped-d assert isinstance(subsol, BaseSolverOutput) assert key in expected_keys - @pytest.mark.parametrize("tag", ["cost", "kernel"]) - def test_set_xy(self, adata_time: AnnData, tag: Literal["cost", "kernel"]): + @pytest.mark.parametrize("tag", ["cost_matrix", "kernel"]) + def test_set_xy(self, adata_time: AnnData, tag: Literal["cost_matrix", "kernel"]): rng = np.random.RandomState(42) adata_time = adata_time[adata_time.obs["time"].isin((0, 1))].copy() problem = SinkhornProblem(adata=adata_time) diff --git a/tests/problems/space/test_alignment_problem.py b/tests/problems/space/test_alignment_problem.py index a83bb4d3a..5ddb67539 100644 --- a/tests/problems/space/test_alignment_problem.py +++ b/tests/problems/space/test_alignment_problem.py @@ -25,9 +25,7 @@ class TestAlignmentProblem: @pytest.mark.fast() - @pytest.mark.parametrize( - "joint_attr", [{"x_attr": "X", "y_attr": "X"}] - ) # TODO(giovp): check that callback is correct + @pytest.mark.parametrize("joint_attr", [{"attr": "X"}]) # TODO(giovp): check that callback is correct def test_prepare_sequential(self, adata_space_rotate: AnnData, joint_attr: Optional[Mapping[str, Any]]): n_obs = adata_space_rotate.shape[0] // 3 # adata is made of 3 datasets n_var = adata_space_rotate.shape[1] @@ -112,7 +110,7 @@ def test_pass_arguments(self, adata_space_rotate: AnnData, args_to_check: Mappin adata_space_rotate = adata_space_rotate[adata_space_rotate.obs["batch"].isin(("0", "1"))] key = ("0", "1") problem = AlignmentProblem(adata=adata_space_rotate) - problem = problem.prepare(batch_key="batch", joint_attr={"x_attr": "X", "y_attr": "X"}) + problem = problem.prepare(batch_key="batch", joint_attr={"attr": "X"}) problem = problem.solve(**args_to_check) solver = problem[key].solver.solver diff --git a/tests/problems/space/test_mapping_problem.py b/tests/problems/space/test_mapping_problem.py index 8889275cd..b8fdfea55 100644 --- a/tests/problems/space/test_mapping_problem.py +++ b/tests/problems/space/test_mapping_problem.py @@ -27,9 +27,7 @@ class TestMappingProblem: @pytest.mark.fast() @pytest.mark.parametrize("sc_attr", [{"attr": "X"}, {"attr": "obsm", "key": "X_pca"}]) - @pytest.mark.parametrize( - "joint_attr", [None, "X_pca", {"x_attr": "obsm", "x_key": "X_pca", "y_attr": "obsm", "y_key": "X_pca"}] - ) + @pytest.mark.parametrize("joint_attr", [None, "X_pca", {"attr": "obsm", "key": "X_pca"}]) def test_prepare(self, adata_mapping: AnnData, sc_attr: Mapping[str, str], joint_attr: Optional[Mapping[str, str]]): adataref, adatasp = _adata_spatial_split(adata_mapping) expected_keys = {(i, "ref") for i in adatasp.obs.batch.cat.categories} diff --git a/tests/problems/time/test_lineage_problem.py b/tests/problems/time/test_lineage_problem.py index 404e558bb..6cbc57fac 100644 --- a/tests/problems/time/test_lineage_problem.py +++ b/tests/problems/time/test_lineage_problem.py @@ -25,7 +25,7 @@ def test_barcodes_pipeline(self, adata_time_barcodes: AnnData): problem = LineageProblem(adata=adata_time_barcodes) problem = problem.prepare( time_key="time", - lineage_attr={"attr": "obsm", "key": "barcodes", "tag": "cost", "cost": "barcode_distance"}, + lineage_attr={"attr": "obsm", "key": "barcodes", "tag": "cost_matrix", "cost": "barcode_distance"}, policy="sequential", ) problem = problem.solve() @@ -48,7 +48,7 @@ def test_trees_pipeline(self, adata_time_trees: AnnData): expected_keys = [(0, 1), (1, 2)] problem = LineageProblem(adata=adata_time_trees) problem = problem.prepare( - time_key="time", lineage_attr={"attr": "uns", "key": "trees", "tag": "cost", "cost": "leaf_distance"} + time_key="time", lineage_attr={"attr": "uns", "key": "trees", "tag": "cost_matrix", "cost": "leaf_distance"} ) problem = problem.solve(max_iterations=10) @@ -71,7 +71,7 @@ def test_pass_arguments(self, adata_time_barcodes: AnnData, args_to_check: Mappi problem = problem.prepare( time_key="time", policy="sequential", - lineage_attr={"attr": "obsm", "key": "barcodes", "tag": "cost", "cost": "barcode_distance"}, + lineage_attr={"attr": "obsm", "key": "barcodes", "tag": "cost_matrix", "cost": "barcode_distance"}, ) problem = problem.solve(**args_to_check) diff --git a/tests/problems/time/test_mixins.py b/tests/problems/time/test_mixins.py index 05ab9dd01..a0fc9ee3d 100644 --- a/tests/problems/time/test_mixins.py +++ b/tests/problems/time/test_mixins.py @@ -157,7 +157,7 @@ def test_compute_interpolated_distance_pipeline(self, gt_temporal_adata: AnnData key, subset=[(key_1, key_2), (key_2, key_3), (key_1, key_3)], policy="explicit", - callback_kwargs={"n_comps": 50}, + xy_callback_kwargs={"n_comps": 50}, ) assert set(problem.problems.keys()) == {(key_1, key_2), (key_2, key_3), (key_1, key_3)} problem[key_1, key_2]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_10_105"]) @@ -187,7 +187,7 @@ def test_compute_interpolated_distance_regression(self, gt_temporal_adata: AnnDa subset=[(key_1, key_2), (key_2, key_3), (key_1, key_3)], policy="explicit", scale_cost="mean", - callback_kwargs={"n_comps": 50}, + xy_callback_kwargs={"n_comps": 50}, ) assert set(problem.problems.keys()) == {(key_1, key_2), (key_2, key_3), (key_1, key_3)} problem[key_1, key_2]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_10_105"]) @@ -215,7 +215,7 @@ def test_compute_time_point_distances_regression(self, gt_temporal_adata: AnnDat subset=[(key_1, key_2), (key_2, key_3), (key_1, key_3)], policy="explicit", scale_cost="mean", - callback_kwargs={"n_comps": 50}, + xy_callback_kwargs={"n_comps": 50}, ) assert set(problem.problems.keys()) == {(key_1, key_2), (key_2, key_3), (key_1, key_3)} problem[key_1, key_2]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_10_105"]) @@ -241,7 +241,7 @@ def test_compute_batch_distances_regression(self, gt_temporal_adata: AnnData): subset=[(key_1, key_2), (key_2, key_3), (key_1, key_3)], policy="explicit", scale_cost="mean", - callback_kwargs={"n_comps": 50}, + xy_callback_kwargs={"n_comps": 50}, ) assert set(problem.problems.keys()) == {(key_1, key_2), (key_2, key_3), (key_1, key_3)} problem[key_1, key_2]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_10_105"]) @@ -264,7 +264,7 @@ def test_compute_random_distance_regression(self, gt_temporal_adata: AnnData): subset=[(key_1, key_2), (key_2, key_3), (key_1, key_3)], policy="explicit", scale_cost="mean", - callback_kwargs={"n_comps": 50}, + xy_callback_kwargs={"n_comps": 50}, ) assert set(problem.problems.keys()) == {(key_1, key_2), (key_2, key_3), (key_1, key_3)} diff --git a/tests/problems/time/test_temporal_problem.py b/tests/problems/time/test_temporal_problem.py index 082306f2a..82c9c5ab5 100644 --- a/tests/problems/time/test_temporal_problem.py +++ b/tests/problems/time/test_temporal_problem.py @@ -205,7 +205,7 @@ def test_result_compares_to_wot(self, gt_temporal_adata: AnnData): key, subset=[(key_1, key_2), (key_2, key_3), (key_1, key_3)], policy="explicit", - callback_kwargs={"n_comps": 50}, + xy_callback_kwargs={"n_comps": 50}, ) tp = tp.solve(epsilon=eps, scale_cost="mean", tau_a=lam1 / (lam1 + eps), tau_b=lam2 / (lam2 + eps))