From 80d0e494a87db6a34948e2da8c88e50ccb583d00 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Thu, 8 Dec 2022 13:44:08 +0100 Subject: [PATCH 01/14] rename tag cost to cost_matrix --- src/moscot/_constants/_enum.py | 4 ++-- src/moscot/problems/time/_lineage.py | 2 +- src/moscot/solvers/_base_solver.py | 2 +- src/moscot/solvers/_output.py | 2 +- src/moscot/solvers/_tagged_array.py | 2 +- tests/problems/base/test_compound_problem.py | 2 +- tests/problems/generic/test_fgw_problem.py | 4 ++-- tests/problems/generic/test_gw_problem.py | 4 ++-- tests/problems/time/test_lineage_problem.py | 6 +++--- 9 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/moscot/_constants/_enum.py b/src/moscot/_constants/_enum.py index da76458a2..42b0e48d9 100644 --- a/src/moscot/_constants/_enum.py +++ b/src/moscot/_constants/_enum.py @@ -33,7 +33,7 @@ def wrapper(*args: Any, **kwargs: Any) -> "ErrorFormatterABC": return wrapper -class ABCEnumMeta(EnumMeta, ABCMeta): # noqa: B024 +class ABCEnumMeta(EnumMeta, ABCMeta): def __call__(cls, *args: Any, **kwargs: Any) -> Any: if getattr(cls, "__error_format__", None) is None: raise TypeError(f"Can't instantiate class `{cls.__name__}` without `__error_format__` class attribute.") @@ -45,7 +45,7 @@ def __new__(cls, clsname: str, superclasses: Tuple[type], attributedict: Dict[st return res -class ErrorFormatterABC(ABC): # noqa: B024 +class ErrorFormatterABC(ABC): """Mixin class that formats invalid value when constructing an enum.""" __error_format__ = "Invalid option `{0}` for `{1}`. Valid options are: `{2}`." diff --git a/src/moscot/problems/time/_lineage.py b/src/moscot/problems/time/_lineage.py index 2a10cded5..c630605c8 100644 --- a/src/moscot/problems/time/_lineage.py +++ b/src/moscot/problems/time/_lineage.py @@ -350,7 +350,7 @@ def prepare( lineage_attr.setdefault("attr", "obsp") lineage_attr.setdefault("key", "cost_matrices") lineage_attr.setdefault("cost", "custom") - lineage_attr.setdefault("tag", "cost") + lineage_attr.setdefault("tag", "cost_matrix") x = y = lineage_attr diff --git a/src/moscot/solvers/_base_solver.py b/src/moscot/solvers/_base_solver.py index d4fe6d857..efd3fe39b 100644 --- a/src/moscot/solvers/_base_solver.py +++ b/src/moscot/solvers/_base_solver.py @@ -147,7 +147,7 @@ def __call__(self, **kwargs: Any) -> O: @d.get_sections(base="OTSolver", sections=["Parameters", "Raises"]) @d.dedent -class OTSolver(TagConverterMixin, BaseSolver[O], ABC): # noqa: B024 +class OTSolver(TagConverterMixin, BaseSolver[O], ABC): """Base class for optimal transport solvers.""" def __call__( diff --git a/src/moscot/solvers/_output.py b/src/moscot/solvers/_output.py index 5e65f9c96..6d5de06c5 100644 --- a/src/moscot/solvers/_output.py +++ b/src/moscot/solvers/_output.py @@ -209,7 +209,7 @@ def __str__(self) -> str: return f"{self.__class__.__name__}[{self._format_params(str)}]" -class MatrixSolverOutput(BaseSolverOutput, ABC): # noqa: B024 +class MatrixSolverOutput(BaseSolverOutput, ABC): """Optimal transport output with materialized :attr:`transport_matrix`. Parameters diff --git a/src/moscot/solvers/_tagged_array.py b/src/moscot/solvers/_tagged_array.py index 8d057e066..2e205561e 100644 --- a/src/moscot/solvers/_tagged_array.py +++ b/src/moscot/solvers/_tagged_array.py @@ -27,7 +27,7 @@ def get_cost_function(cost: str, *, backend: Literal["ott"] = "ott", **kwargs: A class Tag(ModeEnum): """Tag used to interpret array-like data in :class:`moscot.solvers.TaggedArray`.""" - COST_MATRIX = "cost" #: Cost matrix. + COST_MATRIX = "cost_matrix" #: Cost matrix. KERNEL = "kernel" #: Kernel matrix. POINT_CLOUD = "point_cloud" #: Point cloud. diff --git a/tests/problems/base/test_compound_problem.py b/tests/problems/base/test_compound_problem.py index 5c5820975..b1a110cfe 100644 --- a/tests/problems/base/test_compound_problem.py +++ b/tests/problems/base/test_compound_problem.py @@ -100,7 +100,7 @@ def test_different_passings_linear(self, adata_with_cost_matrix: AnnData): p1_tmap = p1[0, 1].solution.transport_matrix p2 = Problem(adata_with_cost_matrix) - p2 = p2.prepare(key="batch", xy={"attr": "uns", "key": 0, "cost": "custom", "tag": "cost"}) + p2 = p2.prepare(key="batch", xy={"attr": "uns", "key": 0, "cost": "custom", "tag": "cost_matrix"}) p2 = p2.solve(epsilon=epsilon) p2_tmap = p2[0, 1].solution.transport_matrix diff --git a/tests/problems/generic/test_fgw_problem.py b/tests/problems/generic/test_fgw_problem.py index b918eeedc..011b36f4e 100644 --- a/tests/problems/generic/test_fgw_problem.py +++ b/tests/problems/generic/test_fgw_problem.py @@ -8,9 +8,9 @@ from anndata import AnnData -from moscot.problems.base import OTProblem +from moscot.problems.base import OTProblem # type:ignore[attr-defined] from moscot.solvers._output import BaseSolverOutput -from moscot.problems.generic import FGWProblem +from moscot.problems.generic import FGWProblem # type:ignore[attr-defined] from tests.problems.conftest import ( fgw_args_1, fgw_args_2, diff --git a/tests/problems/generic/test_gw_problem.py b/tests/problems/generic/test_gw_problem.py index 333c07ea2..78e8e95b4 100644 --- a/tests/problems/generic/test_gw_problem.py +++ b/tests/problems/generic/test_gw_problem.py @@ -8,9 +8,9 @@ from anndata import AnnData -from moscot.problems.base import OTProblem +from moscot.problems.base import OTProblem # type:ignore[attr-defined] from moscot.solvers._output import BaseSolverOutput -from moscot.problems.generic import GWProblem +from moscot.problems.generic import GWProblem # type:ignore[attr-defined] from tests.problems.conftest import ( gw_args_1, gw_args_2, diff --git a/tests/problems/time/test_lineage_problem.py b/tests/problems/time/test_lineage_problem.py index 404e558bb..6cbc57fac 100644 --- a/tests/problems/time/test_lineage_problem.py +++ b/tests/problems/time/test_lineage_problem.py @@ -25,7 +25,7 @@ def test_barcodes_pipeline(self, adata_time_barcodes: AnnData): problem = LineageProblem(adata=adata_time_barcodes) problem = problem.prepare( time_key="time", - lineage_attr={"attr": "obsm", "key": "barcodes", "tag": "cost", "cost": "barcode_distance"}, + lineage_attr={"attr": "obsm", "key": "barcodes", "tag": "cost_matrix", "cost": "barcode_distance"}, policy="sequential", ) problem = problem.solve() @@ -48,7 +48,7 @@ def test_trees_pipeline(self, adata_time_trees: AnnData): expected_keys = [(0, 1), (1, 2)] problem = LineageProblem(adata=adata_time_trees) problem = problem.prepare( - time_key="time", lineage_attr={"attr": "uns", "key": "trees", "tag": "cost", "cost": "leaf_distance"} + time_key="time", lineage_attr={"attr": "uns", "key": "trees", "tag": "cost_matrix", "cost": "leaf_distance"} ) problem = problem.solve(max_iterations=10) @@ -71,7 +71,7 @@ def test_pass_arguments(self, adata_time_barcodes: AnnData, args_to_check: Mappi problem = problem.prepare( time_key="time", policy="sequential", - lineage_attr={"attr": "obsm", "key": "barcodes", "tag": "cost", "cost": "barcode_distance"}, + lineage_attr={"attr": "obsm", "key": "barcodes", "tag": "cost_matrix", "cost": "barcode_distance"}, ) problem = problem.solve(**args_to_check) From d1b514f9891bf40b4b80817aa4823c800fc8be59 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Thu, 8 Dec 2022 14:04:37 +0100 Subject: [PATCH 02/14] fix renaming --- tests/problems/base/test_general_problem.py | 18 +++++++++--------- tests/problems/generic/test_fgw_problem.py | 16 +++++++++------- tests/problems/generic/test_gw_problem.py | 8 ++++---- .../problems/generic/test_sinkhorn_problem.py | 4 ++-- 4 files changed, 24 insertions(+), 22 deletions(-) diff --git a/tests/problems/base/test_general_problem.py b/tests/problems/base/test_general_problem.py index a4eb4113f..6276d6e35 100644 --- a/tests/problems/base/test_general_problem.py +++ b/tests/problems/base/test_general_problem.py @@ -47,8 +47,8 @@ def test_passing_scale(self, adata_x: AnnData, scale_cost: str): np.testing.assert_allclose(gt.matrix, sol.transport_matrix, rtol=RTOL, atol=ATOL) - @pytest.mark.parametrize("tag", ["cost", "kernel"]) - def test_set_xy(self, adata_x: AnnData, adata_y: AnnData, tag: Literal["cost", "kernel"]): + @pytest.mark.parametrize("tag", ["cost_matrix", "kernel"]) + def test_set_xy(self, adata_x: AnnData, adata_y: AnnData, tag: Literal["cost_matrix", "kernel"]): rng = np.random.RandomState(42) prob = OTProblem(adata_x, adata_y) prob = prob.prepare( @@ -66,8 +66,8 @@ def test_set_xy(self, adata_x: AnnData, adata_y: AnnData, tag: Literal["cost", " prob = prob.solve(max_iterations=5) # TODO(@MUCDK) once fixed in OTT-JAX test for scale_cost np.testing.assert_equal(prob.xy.data_src, cost_matrix.to_numpy()) - @pytest.mark.parametrize("tag", ["cost", "kernel"]) - def test_set_x(self, adata_x: AnnData, adata_y: AnnData, tag: Literal["cost", "kernel"]): + @pytest.mark.parametrize("tag", ["cost_matrix", "kernel"]) + def test_set_x(self, adata_x: AnnData, adata_y: AnnData, tag: Literal["cost_matrix", "kernel"]): rng = np.random.RandomState(42) prob = OTProblem(adata_x, adata_y) prob = prob.prepare( @@ -85,8 +85,8 @@ def test_set_x(self, adata_x: AnnData, adata_y: AnnData, tag: Literal["cost", "k prob = prob.solve(max_iterations=5) # TODO(@MUCDK) once fixed in OTT-JAX test for scale_cost np.testing.assert_equal(prob.x.data_src, cost_matrix.to_numpy()) - @pytest.mark.parametrize("tag", ["cost", "kernel"]) - def test_set_y(self, adata_x: AnnData, adata_y: AnnData, tag: Literal["cost", "kernel"]): + @pytest.mark.parametrize("tag", ["cost_matrix", "kernel"]) + def test_set_y(self, adata_x: AnnData, adata_y: AnnData, tag: Literal["cost_matrix", "kernel"]): rng = np.random.RandomState(42) prob = OTProblem(adata_x, adata_y) prob = prob.prepare( @@ -115,7 +115,7 @@ def test_set_xy_change_problem_kind(self, adata_x: AnnData, adata_y: AnnData): cm = rng.uniform(1, 10, size=(adata_x.n_obs, adata_y.n_obs)) cost_matrix = pd.DataFrame(index=adata_x.obs_names, columns=adata_y.obs_names, data=cm) - prob.set_xy(cost_matrix, tag="cost") + prob.set_xy(cost_matrix, tag="cost_matrix") assert prob.problem_kind == ProblemKind.QUAD_FUSED @@ -129,11 +129,11 @@ def test_set_x_change_problem_kind(self, adata_x: AnnData, adata_y: AnnData): cm = rng.uniform(1, 10, size=(adata_x.n_obs, adata_x.n_obs)) cost_matrix = pd.DataFrame(index=adata_x.obs_names, columns=adata_x.obs_names, data=cm) - prob.set_x(cost_matrix, tag="cost") + prob.set_x(cost_matrix, tag="cost_matrix") cm = rng.uniform(1, 10, size=(adata_y.n_obs, adata_y.n_obs)) cost_matrix = pd.DataFrame(index=adata_y.obs_names, columns=adata_y.obs_names, data=cm) - prob.set_y(cost_matrix, tag="cost") + prob.set_y(cost_matrix, tag="cost_matrix") assert prob.problem_kind == ProblemKind.QUAD_FUSED diff --git a/tests/problems/generic/test_fgw_problem.py b/tests/problems/generic/test_fgw_problem.py index 011b36f4e..acd6924c8 100644 --- a/tests/problems/generic/test_fgw_problem.py +++ b/tests/problems/generic/test_fgw_problem.py @@ -116,7 +116,9 @@ def test_pass_arguments(self, adata_space_rotate: AnnData, args_to_check: Mappin assert getattr(geom, val) == args_to_check[arg] @pytest.mark.fast() - @pytest.mark.parametrize("cost", [("sq_euclidean", SqEuclidean), ("euclidean", Euclidean), ("cosine", Cosine)]) + @pytest.mark.parametrize( + "cost_matrix", [("sq_euclidean", SqEuclidean), ("euclidean", Euclidean), ("cosine", Cosine)] + ) def test_prepare_costs(self, adata_time: AnnData, cost: Tuple[str, Any]): problem = FGWProblem(adata=adata_time) problem = problem.prepare( @@ -131,8 +133,8 @@ def test_prepare_costs(self, adata_time: AnnData, cost: Tuple[str, Any]): assert isinstance(problem[0, 1].x.cost, cost[1]) assert isinstance(problem[0, 1].y.cost, cost[1]) - @pytest.mark.parametrize("tag", ["cost", "kernel"]) - def test_set_xy(self, adata_time: AnnData, tag: Literal["cost", "kernel"]): + @pytest.mark.parametrize("tag", ["cost_matrix", "kernel"]) + def test_set_xy(self, adata_time: AnnData, tag: Literal["cost_matrix", "kernel"]): rng = np.random.RandomState(42) adata_time = adata_time[adata_time.obs["time"].isin((0, 1))].copy() problem = FGWProblem(adata=adata_time) @@ -159,8 +161,8 @@ def test_set_xy(self, adata_time: AnnData, tag: Literal["cost", "kernel"]): assert isinstance(problem[0, 1].xy.data_src, np.ndarray) assert problem[0, 1].xy.data_tgt is None - @pytest.mark.parametrize("tag", ["cost", "kernel"]) - def test_set_x(self, adata_time: AnnData, tag: Literal["cost", "kernel"]): + @pytest.mark.parametrize("tag", ["cost_matrix", "kernel"]) + def test_set_x(self, adata_time: AnnData, tag: Literal["cost_matrix", "kernel"]): rng = np.random.RandomState(42) adata_time = adata_time[adata_time.obs["time"].isin((0, 1))].copy() problem = FGWProblem(adata=adata_time) @@ -186,8 +188,8 @@ def test_set_x(self, adata_time: AnnData, tag: Literal["cost", "kernel"]): assert isinstance(problem[0, 1].x.data_src, np.ndarray) assert problem[0, 1].x.data_tgt is None - @pytest.mark.parametrize("tag", ["cost", "kernel"]) - def test_set_y(self, adata_time: AnnData, tag: Literal["cost", "kernel"]): + @pytest.mark.parametrize("tag", ["cost_matrix", "kernel"]) + def test_set_y(self, adata_time: AnnData, tag: Literal["cost_matrix", "kernel"]): rng = np.random.RandomState(42) adata_time = adata_time[adata_time.obs["time"].isin((0, 1))].copy() problem = FGWProblem(adata=adata_time) diff --git a/tests/problems/generic/test_gw_problem.py b/tests/problems/generic/test_gw_problem.py index 78e8e95b4..9925dbb97 100644 --- a/tests/problems/generic/test_gw_problem.py +++ b/tests/problems/generic/test_gw_problem.py @@ -116,8 +116,8 @@ def test_prepare_costs(self, adata_time: AnnData, cost: Tuple[str, Any]): assert isinstance(problem[0, 1].x.cost, cost[1]) assert isinstance(problem[0, 1].y.cost, cost[1]) - @pytest.mark.parametrize("tag", ["cost", "kernel"]) - def test_set_x(self, adata_time: AnnData, tag: Literal["cost", "kernel"]): + @pytest.mark.parametrize("tag", ["cost_matrix", "kernel"]) + def test_set_x(self, adata_time: AnnData, tag: Literal["cost_matrix", "kernel"]): rng = np.random.RandomState(42) adata_time = adata_time[adata_time.obs["time"].isin((0, 1))].copy() problem = GWProblem(adata=adata_time) @@ -142,8 +142,8 @@ def test_set_x(self, adata_time: AnnData, tag: Literal["cost", "kernel"]): assert isinstance(problem[0, 1].x.data_src, np.ndarray) assert problem[0, 1].x.data_tgt is None - @pytest.mark.parametrize("tag", ["cost", "kernel"]) - def test_set_y(self, adata_time: AnnData, tag: Literal["cost", "kernel"]): + @pytest.mark.parametrize("tag", ["cost_matrix", "kernel"]) + def test_set_y(self, adata_time: AnnData, tag: Literal["cost_matrix", "kernel"]): rng = np.random.RandomState(42) adata_time = adata_time[adata_time.obs["time"].isin((0, 1))].copy() problem = GWProblem(adata=adata_time) diff --git a/tests/problems/generic/test_sinkhorn_problem.py b/tests/problems/generic/test_sinkhorn_problem.py index 1328eb0f8..1e680ca25 100644 --- a/tests/problems/generic/test_sinkhorn_problem.py +++ b/tests/problems/generic/test_sinkhorn_problem.py @@ -55,8 +55,8 @@ def test_solve_balanced(self, adata_time: AnnData): # type: ignore[no-untyped-d assert isinstance(subsol, BaseSolverOutput) assert key in expected_keys - @pytest.mark.parametrize("tag", ["cost", "kernel"]) - def test_set_xy(self, adata_time: AnnData, tag: Literal["cost", "kernel"]): + @pytest.mark.parametrize("tag", ["cost_matrix", "kernel"]) + def test_set_xy(self, adata_time: AnnData, tag: Literal["cost_matrix", "kernel"]): rng = np.random.RandomState(42) adata_time = adata_time[adata_time.obs["time"].isin((0, 1))].copy() problem = SinkhornProblem(adata=adata_time) From c5f43c2ccae81cfced4c8fc8066109482edb8223 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Thu, 8 Dec 2022 14:21:36 +0100 Subject: [PATCH 03/14] fix renaming --- tests/problems/generic/test_fgw_problem.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/problems/generic/test_fgw_problem.py b/tests/problems/generic/test_fgw_problem.py index acd6924c8..e8f4f3d72 100644 --- a/tests/problems/generic/test_fgw_problem.py +++ b/tests/problems/generic/test_fgw_problem.py @@ -117,7 +117,7 @@ def test_pass_arguments(self, adata_space_rotate: AnnData, args_to_check: Mappin @pytest.mark.fast() @pytest.mark.parametrize( - "cost_matrix", [("sq_euclidean", SqEuclidean), ("euclidean", Euclidean), ("cosine", Cosine)] + "cost", [("sq_euclidean", SqEuclidean), ("euclidean", Euclidean), ("cosine", Cosine)] ) def test_prepare_costs(self, adata_time: AnnData, cost: Tuple[str, Any]): problem = FGWProblem(adata=adata_time) From 8f9ec76d6d5be8646151802c451a591f9f9f1661 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Thu, 8 Dec 2022 18:14:12 +0100 Subject: [PATCH 04/14] [CI skip], adapt callback --- src/moscot/problems/_utils.py | 8 +- src/moscot/problems/base/_base_problem.py | 31 ++++--- src/moscot/problems/base/_compound_problem.py | 88 +++++++++++++------ src/moscot/problems/generic/_generic.py | 4 +- tests/analysis_mixins/test_base_analysis.py | 8 +- tests/problems/base/test_compound_problem.py | 8 +- tests/problems/time/test_mixins.py | 10 +-- tests/problems/time/test_temporal_problem.py | 2 +- 8 files changed, 100 insertions(+), 59 deletions(-) diff --git a/src/moscot/problems/_utils.py b/src/moscot/problems/_utils.py index 6b6ebaaed..e369b9efa 100644 --- a/src/moscot/problems/_utils.py +++ b/src/moscot/problems/_utils.py @@ -70,11 +70,11 @@ def handle_joint_attr( joint_attr: Optional[Union[str, Mapping[str, Any]]], kwargs: Any ) -> Tuple[Optional[Mapping[str, Any]], Dict[str, Any]]: if joint_attr is None: - if "callback" not in kwargs: - kwargs["callback"] = "local-pca" + if "xy_callback" not in kwargs: + kwargs["xy_callback"] = "local-pca" else: - kwargs["callback"] = kwargs["callback"] - kwargs["callback_kwargs"] = {**kwargs.get("callback_kwargs", {}), **{"return_linear": True}} + kwargs["xy_callback"] = kwargs["xy_callback"] + kwargs["xy_callback_kwargs"] = {**kwargs.get("xy_callback_kwargs", {})} return None, kwargs if isinstance(joint_attr, str): xy = { diff --git a/src/moscot/problems/base/_base_problem.py b/src/moscot/problems/base/_base_problem.py index 8c89912f2..e4be971dd 100644 --- a/src/moscot/problems/base/_base_problem.py +++ b/src/moscot/problems/base/_base_problem.py @@ -418,10 +418,10 @@ def pull( @staticmethod def _local_pca_callback( + term: Literal["x", "y", "xy"], adata: AnnData, - adata_y: AnnData, + adata_y: Optional[AnnData] = None, layer: Optional[str] = None, - return_linear: bool = True, n_comps: int = 30, scale: bool = False, **kwargs: Any, @@ -434,13 +434,19 @@ def concat(x: ArrayLike, y: ArrayLike) -> ArrayLike: return np.vstack([x, y]) if layer is None: - x, y, msg = adata.X, adata_y.X, "adata.X" + x, y, msg = adata.X, adata_y.X if adata_y is not None else None, "adata.X" else: - x, y, msg = adata.layers[layer], adata_y.layers[layer], f"adata.layers[{layer!r}]" + x, y, msg = ( + adata.layers[layer], + adata_y.layers[layer] if adata_y is not None else None, + f"adata.layers[{layer!r}]", + ) scaler = StandardScaler() if scale else None - if return_linear: + if term == "xy": + if adata_y is None: + raise ValueError("When `term` is `xy` `adata_y` cannot be `None`.") n = x.shape[0] data = concat(x, y) if data.shape[1] <= n_comps: @@ -452,14 +458,13 @@ def concat(x: ArrayLike, y: ArrayLike) -> ArrayLike: if scaler is not None: data = scaler.fit_transform(data) return {"xy": TaggedArray(data[:n], data[n:], tag=Tag.POINT_CLOUD)} - - logger.info(f"Computing pca with `n_comps={n_comps}` for `x` and `y` using `{msg}`") - x = sc.pp.pca(x, n_comps=n_comps, **kwargs) - y = sc.pp.pca(y, n_comps=n_comps, **kwargs) - if scaler is not None: - x = scaler.fit_transform(x) - y = scaler.fit_transform(y) - return {"x": TaggedArray(x, tag=Tag.POINT_CLOUD), "y": TaggedArray(y, tag=Tag.POINT_CLOUD)} + if term in ("x", "y"): + logger.info(f"Computing pca with `n_comps={n_comps}` for `{term}` using `{msg}`") + x = sc.pp.pca(x, n_comps=n_comps, **kwargs) + if scaler is not None: + x = scaler.fit_transform(x) + return {term: TaggedArray(x, tag=Tag.POINT_CLOUD)} + raise ValueError(f"Expected `term` to be one of `x`, `y`, or `xy`, found `{term!r}`.") def _create_marginals( self, adata: AnnData, *, source: bool, data: Optional[Union[bool, str, ArrayLike]] = None, **kwargs: Any diff --git a/src/moscot/problems/base/_compound_problem.py b/src/moscot/problems/base/_compound_problem.py index dc11d4ab1..fe1ee083c 100644 --- a/src/moscot/problems/base/_compound_problem.py +++ b/src/moscot/problems/base/_compound_problem.py @@ -48,7 +48,7 @@ K = TypeVar("K", bound=Hashable) B = TypeVar("B", bound=OTProblem) -Callback_t = Callable[[AnnData, AnnData], Mapping[Literal["xy", "x", "y"], TaggedArray]] +Callback_t = Callable[[Literal["xy", "x", "y"], AnnData, AnnData], Mapping[Literal["xy", "x", "y"], TaggedArray]] ApplyOutput_t = Union[ArrayLike, Dict[K, ArrayLike]] # TODO(michalk8): future behavior # ApplyOutput_t = Union[ArrayLike, Dict[Tuple[K, K], ArrayLike]] @@ -97,8 +97,9 @@ def _valid_policies(self) -> Tuple[str, ...]: # TODO(michalk8): refactor me def _callback_handler( self, - src: K, - tgt: K, + term: Literal["x", "y", "xy"], + key_1: K, + key_2: K, problem: B, *, callback: Union[Literal["local-pca"], Callback_t], @@ -117,15 +118,19 @@ def verify_data(data: Mapping[Literal["xy", "x", "y"], TaggedArray]) -> None: if not callable(callback): raise TypeError("Callback is not a function.") - data = callback(problem.adata_src, problem.adata_tgt, **kwargs) + data = callback(term=term, adata=problem.adata_src, adata_y=problem.adata_tgt, **kwargs) verify_data(data) return data # TODO(michalk8): refactor me def _create_problems( self, - callback: Optional[Union[Literal["local-pca"], Callback_t]] = None, - callback_kwargs: Mapping[str, Any] = MappingProxyType({}), + xy_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None, + x_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None, + y_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None, + xy_callback_kwargs: Mapping[str, Any] = MappingProxyType({}), + x_callback_kwargs: Mapping[str, Any] = MappingProxyType({}), + y_callback_kwargs: Mapping[str, Any] = MappingProxyType({}), **kwargs: Any, ) -> Dict[Tuple[K, K], B]: from moscot.problems.base._birth_death import BirthDeathProblem @@ -143,11 +148,25 @@ def _create_problems( tgt_name = tgt problem = self._create_problem(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask) - if callback is not None: - data = self._callback_handler(src, tgt, problem, callback=callback, **callback_kwargs) - kws = {**kwargs, **data} # type: ignore[arg-type] + if xy_callback is not None: + xy_data = self._callback_handler( + term="xy", key_1=src, key_2=tgt, problem=problem, callback=xy_callback, **xy_callback_kwargs + ) else: - kws = kwargs + xy_data = {} + if x_callback is not None: + x_data = self._callback_handler( + term="x", key_1=src, key_2=tgt, problem=problem, callback=x_callback, **x_callback_kwargs + ) + else: + x_data = {} + if y_callback is not None: + y_data = self._callback_handler( + term="y", key_1=src, key_2=tgt, problem=problem, callback=y_callback, **y_callback_kwargs + ) + else: + y_data = {} + kws = {**kwargs, **xy_data, **x_data, **y_data} # type: ignore[arg-type] if isinstance(problem, BirthDeathProblem): kws["proliferation_key"] = self.proliferation_key # type: ignore[attr-defined] @@ -164,8 +183,12 @@ def prepare( policy: Policy_t = "sequential", subset: Optional[Sequence[Tuple[K, K]]] = None, reference: Optional[Any] = None, - callback: Optional[Union[Literal["local-pca"], Callback_t]] = None, - callback_kwargs: Mapping[str, Any] = MappingProxyType({}), + xy_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None, + x_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None, + y_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None, + xy_callback_kwargs: Mapping[str, Any] = MappingProxyType({}), + x_callback_kwargs: Mapping[str, Any] = MappingProxyType({}), + y_callback_kwargs: Mapping[str, Any] = MappingProxyType({}), **kwargs: Any, ) -> "BaseCompoundProblem[K,B]": """ @@ -201,7 +224,15 @@ def prepare( # TODO(michalk8): manager must be currently instantiated first, since `_create_problems` accesses the policy # when refactoring the callback, consider changing this self._problem_manager = ProblemManager(self, policy=policy) - problems = self._create_problems(callback=callback, callback_kwargs=callback_kwargs, **kwargs) + problems = self._create_problems( + xy_callback=xy_callback, + x_callback=x_callback, + y_callback=y_callback, + xy_callback_kwargs=xy_callback_kwargs, + x_callback_kwargs=x_callback_kwargs, + y_callback_kwargs=y_callback_kwargs, + **kwargs, + ) self._problem_manager.add_problems(problems) for p in self.problems.values(): @@ -584,8 +615,9 @@ def _create_policy( def _callback_handler( self, - src: K, - tgt: K, + term: Literal["xy", "x", "y"], + key_1: K, + key_2: K, problem: B, *, callback: Union[Literal["local-pca", "cost-matrix"], Callback_t], @@ -593,12 +625,14 @@ def _callback_handler( ) -> Mapping[Literal["xy", "x", "y"], TaggedArray]: # TODO(michalk8): better name? if callback == "cost-matrix": - return self._cost_matrix_callback(src, tgt, **kwargs) + return self._cost_matrix_callback(term=term, key_1=key_1, key_2=key_2, **kwargs) - return super()._callback_handler(src, tgt, problem, callback=callback, **kwargs) + return super()._callback_handler( + term=term, key_1=key_1, key_2=key_2, problem=problem, callback=callback, **kwargs + ) def _cost_matrix_callback( - self, src: K, tgt: K, *, key: str, return_linear: bool = True, **_: Any + self, term: Literal["xy", "x", "y"], *, key: str, key_1: K, key_2: Optional[K] = None, **_: Any ) -> Mapping[Literal["xy", "x", "y"], TaggedArray]: if TYPE_CHECKING: assert isinstance(self._policy, SubsetPolicy) @@ -608,16 +642,18 @@ def _cost_matrix_callback( except KeyError: raise KeyError(f"Unable to fetch data from `adata.obsp[{key!r}]`.") from None - src_mask = self._policy.create_mask(src, allow_empty=False) - tgt_mask = self._policy.create_mask(tgt, allow_empty=False) + mask = self._policy.create_mask(key_1, allow_empty=False) + + if term == "xy": + mask_2 = self._policy.create_mask(key_2, allow_empty=False) - if return_linear: - linear_cost_matrix = data[src_mask, :][:, tgt_mask] + linear_cost_matrix = data[mask, :][:, mask_2] if issparse(linear_cost_matrix): linear_cost_matrix = linear_cost_matrix.A return {"xy": TaggedArray(linear_cost_matrix, tag=Tag.COST_MATRIX)} - return { - "x": TaggedArray(data[src_mask, :][:, src_mask], tag=Tag.COST_MATRIX), - "y": TaggedArray(data[tgt_mask, :][:, tgt_mask], tag=Tag.COST_MATRIX), - } + if term in ("x", "y"): + self._policy.create_mask(key_1, allow_empty=False) + return {term: TaggedArray(data[mask, :][:, mask], tag=Tag.COST_MATRIX)} + + raise ValueError(f"Expected `term` to be one of `x`, `y`, or `xy`, found `{term!r}`.") diff --git a/src/moscot/problems/generic/_generic.py b/src/moscot/problems/generic/_generic.py index 179beec27..c44ee511a 100644 --- a/src/moscot/problems/generic/_generic.py +++ b/src/moscot/problems/generic/_generic.py @@ -64,8 +64,8 @@ def prepare( """ self.batch_key = key if joint_attr is None: - kwargs["callback"] = "local-pca" - kwargs["callback_kwargs"] = {**kwargs.get("callback_kwargs", {}), **{"return_linear": True}} + kwargs["xy_callback"] = "local-pca" + kwargs["xy_callback_kwargs"] = {**kwargs.get("callback_kwargs", {})} elif isinstance(joint_attr, str): kwargs["xy"] = { "x_attr": "obsm", diff --git a/tests/analysis_mixins/test_base_analysis.py b/tests/analysis_mixins/test_base_analysis.py index e03d036dc..306c6fb14 100644 --- a/tests/analysis_mixins/test_base_analysis.py +++ b/tests/analysis_mixins/test_base_analysis.py @@ -25,7 +25,7 @@ def test_sample_from_tmap_pipeline( source_dim = len(gt_temporal_adata[gt_temporal_adata.obs["day"] == 10]) target_dim = len(gt_temporal_adata[gt_temporal_adata.obs["day"] == 10.5]) problem = CompoundProblemWithMixin(gt_temporal_adata) - problem = problem.prepare("day", subset=[(10, 10.5)], policy="sequential", callback="local-pca") + problem = problem.prepare("day", subset=[(10, 10.5)], policy="sequential", xy_callback="local-pca") problem[10, 10.5]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_10_105"]) if interpolation_parameter is not None and not 0 <= interpolation_parameter <= 1: @@ -71,7 +71,7 @@ def test_sample_from_tmap_pipeline( def test_interpolate_transport(self, gt_temporal_adata: AnnData, forward: bool, scale_by_marginals: bool): problem = CompoundProblemWithMixin(gt_temporal_adata) problem = problem.prepare( - "day", subset=[(10, 10.5), (10.5, 11), (10, 11)], policy="explicit", callback="local-pca" + "day", subset=[(10, 10.5), (10.5, 11), (10, 11)], policy="explicit", xy_callback="local-pca" ) problem[(10.0, 10.5)]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_10_105"]) problem[(10.5, 11.0)]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_105_11"]) @@ -92,7 +92,7 @@ def test_cell_transition_aggregation_cell_forward(self, gt_temporal_adata: AnnDa key_2 = config["key_2"] config["key_3"] problem = CompoundProblemWithMixin(gt_temporal_adata) - problem = problem.prepare("day", subset=[(10, 10.5)], policy="explicit", callback="local-pca") + problem = problem.prepare("day", subset=[(10, 10.5)], policy="explicit", xy_callback="local-pca") assert set(problem.problems.keys()) == {(key_1, key_2)} problem[key_1, key_2]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_10_105"]) @@ -133,7 +133,7 @@ def test_cell_transition_aggregation_cell_backward(self, gt_temporal_adata: AnnD key_1 = config["key_1"] key_2 = config["key_2"] problem = CompoundProblemWithMixin(gt_temporal_adata) - problem = problem.prepare("day", subset=[(10, 10.5)], policy="explicit", callback="local-pca") + problem = problem.prepare("day", subset=[(10, 10.5)], policy="explicit", xy_callback="local-pca") assert set(problem.problems.keys()) == {(key_1, key_2)} problem[key_1, key_2]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_10_105"]) diff --git a/tests/problems/base/test_compound_problem.py b/tests/problems/base/test_compound_problem.py index b1a110cfe..cb2630963 100644 --- a/tests/problems/base/test_compound_problem.py +++ b/tests/problems/base/test_compound_problem.py @@ -65,8 +65,8 @@ def test_default_callback(self, adata_time: AnnData, mocker: MockerFixture, scal xy={"x_attr": "X", "y_attr": "X"}, key="time", policy="sequential", - callback="local-pca", - callback_kwargs=callback_kwargs, + xy_callback="local-pca", + xy_callback_kwargs=callback_kwargs, ) assert isinstance(problem, CompoundProblem) @@ -76,7 +76,7 @@ def test_default_callback(self, adata_time: AnnData, mocker: MockerFixture, scal @pytest.mark.fast() def test_custom_callback(self, adata_time: AnnData, mocker: MockerFixture): expected_keys = [(0, 1), (1, 2)] - spy = mocker.spy(TestCompoundProblem, "callback") + spy = mocker.spy(TestCompoundProblem, "xy_callback") problem = Problem(adata=adata_time) _ = problem.prepare( @@ -86,7 +86,7 @@ def test_custom_callback(self, adata_time: AnnData, mocker: MockerFixture): key="time", policy="sequential", callback=TestCompoundProblem.callback, - callback_kwargs={"sentinel": True}, + xy_callback_kwargs={"sentinel": True}, ) assert spy.call_count == len(expected_keys) diff --git a/tests/problems/time/test_mixins.py b/tests/problems/time/test_mixins.py index 05ab9dd01..a0fc9ee3d 100644 --- a/tests/problems/time/test_mixins.py +++ b/tests/problems/time/test_mixins.py @@ -157,7 +157,7 @@ def test_compute_interpolated_distance_pipeline(self, gt_temporal_adata: AnnData key, subset=[(key_1, key_2), (key_2, key_3), (key_1, key_3)], policy="explicit", - callback_kwargs={"n_comps": 50}, + xy_callback_kwargs={"n_comps": 50}, ) assert set(problem.problems.keys()) == {(key_1, key_2), (key_2, key_3), (key_1, key_3)} problem[key_1, key_2]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_10_105"]) @@ -187,7 +187,7 @@ def test_compute_interpolated_distance_regression(self, gt_temporal_adata: AnnDa subset=[(key_1, key_2), (key_2, key_3), (key_1, key_3)], policy="explicit", scale_cost="mean", - callback_kwargs={"n_comps": 50}, + xy_callback_kwargs={"n_comps": 50}, ) assert set(problem.problems.keys()) == {(key_1, key_2), (key_2, key_3), (key_1, key_3)} problem[key_1, key_2]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_10_105"]) @@ -215,7 +215,7 @@ def test_compute_time_point_distances_regression(self, gt_temporal_adata: AnnDat subset=[(key_1, key_2), (key_2, key_3), (key_1, key_3)], policy="explicit", scale_cost="mean", - callback_kwargs={"n_comps": 50}, + xy_callback_kwargs={"n_comps": 50}, ) assert set(problem.problems.keys()) == {(key_1, key_2), (key_2, key_3), (key_1, key_3)} problem[key_1, key_2]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_10_105"]) @@ -241,7 +241,7 @@ def test_compute_batch_distances_regression(self, gt_temporal_adata: AnnData): subset=[(key_1, key_2), (key_2, key_3), (key_1, key_3)], policy="explicit", scale_cost="mean", - callback_kwargs={"n_comps": 50}, + xy_callback_kwargs={"n_comps": 50}, ) assert set(problem.problems.keys()) == {(key_1, key_2), (key_2, key_3), (key_1, key_3)} problem[key_1, key_2]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_10_105"]) @@ -264,7 +264,7 @@ def test_compute_random_distance_regression(self, gt_temporal_adata: AnnData): subset=[(key_1, key_2), (key_2, key_3), (key_1, key_3)], policy="explicit", scale_cost="mean", - callback_kwargs={"n_comps": 50}, + xy_callback_kwargs={"n_comps": 50}, ) assert set(problem.problems.keys()) == {(key_1, key_2), (key_2, key_3), (key_1, key_3)} diff --git a/tests/problems/time/test_temporal_problem.py b/tests/problems/time/test_temporal_problem.py index 082306f2a..82c9c5ab5 100644 --- a/tests/problems/time/test_temporal_problem.py +++ b/tests/problems/time/test_temporal_problem.py @@ -205,7 +205,7 @@ def test_result_compares_to_wot(self, gt_temporal_adata: AnnData): key, subset=[(key_1, key_2), (key_2, key_3), (key_1, key_3)], policy="explicit", - callback_kwargs={"n_comps": 50}, + xy_callback_kwargs={"n_comps": 50}, ) tp = tp.solve(epsilon=eps, scale_cost="mean", tau_a=lam1 / (lam1 + eps), tau_b=lam2 / (lam2 + eps)) From 786317952e99f4411e946e2eb8c4cc96f7913c70 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Fri, 9 Dec 2022 09:46:27 +0100 Subject: [PATCH 05/14] incorporate requested changes --- src/moscot/problems/_utils.py | 4 +- src/moscot/problems/base/_base_problem.py | 2 +- src/moscot/problems/base/_compound_problem.py | 1 - src/moscot/problems/generic/_generic.py | 2 +- tests/problems/base/test_compound_problem.py | 38 +++++++++++++++---- 5 files changed, 34 insertions(+), 13 deletions(-) diff --git a/src/moscot/problems/_utils.py b/src/moscot/problems/_utils.py index e369b9efa..2882fef28 100644 --- a/src/moscot/problems/_utils.py +++ b/src/moscot/problems/_utils.py @@ -72,9 +72,7 @@ def handle_joint_attr( if joint_attr is None: if "xy_callback" not in kwargs: kwargs["xy_callback"] = "local-pca" - else: - kwargs["xy_callback"] = kwargs["xy_callback"] - kwargs["xy_callback_kwargs"] = {**kwargs.get("xy_callback_kwargs", {})} + kwargs["xy_callback_kwargs"] = kwargs["xy_callback_kwargs"].set_default({}) return None, kwargs if isinstance(joint_attr, str): xy = { diff --git a/src/moscot/problems/base/_base_problem.py b/src/moscot/problems/base/_base_problem.py index e4be971dd..e58a682b3 100644 --- a/src/moscot/problems/base/_base_problem.py +++ b/src/moscot/problems/base/_base_problem.py @@ -445,7 +445,7 @@ def concat(x: ArrayLike, y: ArrayLike) -> ArrayLike: scaler = StandardScaler() if scale else None if term == "xy": - if adata_y is None: + if y is None: raise ValueError("When `term` is `xy` `adata_y` cannot be `None`.") n = x.shape[0] data = concat(x, y) diff --git a/src/moscot/problems/base/_compound_problem.py b/src/moscot/problems/base/_compound_problem.py index fe1ee083c..c61175e5b 100644 --- a/src/moscot/problems/base/_compound_problem.py +++ b/src/moscot/problems/base/_compound_problem.py @@ -653,7 +653,6 @@ def _cost_matrix_callback( return {"xy": TaggedArray(linear_cost_matrix, tag=Tag.COST_MATRIX)} if term in ("x", "y"): - self._policy.create_mask(key_1, allow_empty=False) return {term: TaggedArray(data[mask, :][:, mask], tag=Tag.COST_MATRIX)} raise ValueError(f"Expected `term` to be one of `x`, `y`, or `xy`, found `{term!r}`.") diff --git a/src/moscot/problems/generic/_generic.py b/src/moscot/problems/generic/_generic.py index c44ee511a..a2c7711df 100644 --- a/src/moscot/problems/generic/_generic.py +++ b/src/moscot/problems/generic/_generic.py @@ -65,7 +65,7 @@ def prepare( self.batch_key = key if joint_attr is None: kwargs["xy_callback"] = "local-pca" - kwargs["xy_callback_kwargs"] = {**kwargs.get("callback_kwargs", {})} + kwargs["xy_callback_kwargs"] = kwargs["xy_callback_kwargs"].set_default({}) elif isinstance(joint_attr, str): kwargs["xy"] = { "x_attr": "obsm", diff --git a/tests/problems/base/test_compound_problem.py b/tests/problems/base/test_compound_problem.py index cb2630963..5289177ec 100644 --- a/tests/problems/base/test_compound_problem.py +++ b/tests/problems/base/test_compound_problem.py @@ -1,4 +1,4 @@ -from typing import Any, Tuple, Literal, Mapping +from typing import Any, Tuple, Literal, Mapping, Optional import os from pytest_mock import MockerFixture @@ -19,12 +19,36 @@ class TestCompoundProblem: @staticmethod - def callback( - adata: AnnData, adata_y: AnnData, sentinel: bool = False + def xy_callback( + term: Literal["x", "y", "xy"], + adata: AnnData, + adata_y: Optional[AnnData] = None, + sentinel: bool = False ) -> Mapping[Literal["xy", "x", "y"], TaggedArray]: assert sentinel assert isinstance(adata_y, AnnData) - return {"xy": TaggedArray(euclidean_distances(adata.X, adata_y.X), tag=Tag.COST_MATRIX)} + return {term: TaggedArray(euclidean_distances(adata.X, adata_y.X), tag=Tag.COST_MATRIX)} + + + def x_callback( + term: Literal["x", "y", "xy"], + adata: AnnData, + adata_y: Optional[AnnData] = None, + sentinel: bool = False + ) -> Mapping[Literal["xy", "x", "y"], TaggedArray]: + assert sentinel + assert isinstance(adata_y, AnnData) + return {term: TaggedArray(euclidean_distances(adata.X, adata_y.X), tag=Tag.COST_MATRIX)} + + def y_callback( + term: Literal["x", "y", "xy"], + adata: AnnData, + adata_y: Optional[AnnData] = None, + sentinel: bool = False + ) -> Mapping[Literal["xy", "x", "y"], TaggedArray]: + assert sentinel + assert isinstance(adata_y, AnnData) + return {term: TaggedArray(euclidean_distances(adata.X, adata_y.X), tag=Tag.COST_MATRIX)} def test_sc_pipeline(self, adata_time: AnnData): expected_keys = [(0, 1), (1, 2)] @@ -55,7 +79,7 @@ def test_sc_pipeline(self, adata_time: AnnData): @pytest.mark.fast() def test_default_callback(self, adata_time: AnnData, mocker: MockerFixture, scale: bool): subproblem = OTProblem(adata_time, adata_tgt=adata_time.copy()) - callback_kwargs = {"n_comps": 5, "scale": scale} + xy_callback_kwargs = {"n_comps": 5, "scale": scale} spy = mocker.spy(subproblem, "_local_pca_callback") problem = Problem(adata_time) @@ -66,12 +90,12 @@ def test_default_callback(self, adata_time: AnnData, mocker: MockerFixture, scal key="time", policy="sequential", xy_callback="local-pca", - xy_callback_kwargs=callback_kwargs, + xy_callback_kwargs=xy_callback_kwargs, ) assert isinstance(problem, CompoundProblem) assert isinstance(problem.problems, dict) - spy.assert_called_with(subproblem.adata_src, subproblem.adata_tgt, **callback_kwargs) + spy.assert_called_with(subproblem.adata_src, subproblem.adata_tgt, **xy_callback_kwargs) @pytest.mark.fast() def test_custom_callback(self, adata_time: AnnData, mocker: MockerFixture): From 30ded876de6785f2a41a4562a35b233451293dc0 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Fri, 9 Dec 2022 09:53:57 +0100 Subject: [PATCH 06/14] fix pre-commit --- src/moscot/problems/base/_compound_problem.py | 8 ++++++-- tests/problems/base/test_compound_problem.py | 16 +++------------- tests/problems/generic/test_fgw_problem.py | 4 +--- 3 files changed, 10 insertions(+), 18 deletions(-) diff --git a/src/moscot/problems/base/_compound_problem.py b/src/moscot/problems/base/_compound_problem.py index c61175e5b..81a08b790 100644 --- a/src/moscot/problems/base/_compound_problem.py +++ b/src/moscot/problems/base/_compound_problem.py @@ -48,7 +48,9 @@ K = TypeVar("K", bound=Hashable) B = TypeVar("B", bound=OTProblem) -Callback_t = Callable[[Literal["xy", "x", "y"], AnnData, AnnData], Mapping[Literal["xy", "x", "y"], TaggedArray]] +Callback_t = Callable[ + [Literal["xy", "x", "y"], AnnData, Optional[AnnData]], Mapping[Literal["xy", "x", "y"], TaggedArray] +] ApplyOutput_t = Union[ArrayLike, Dict[K, ArrayLike]] # TODO(michalk8): future behavior # ApplyOutput_t = Union[ArrayLike, Dict[Tuple[K, K], ArrayLike]] @@ -118,7 +120,7 @@ def verify_data(data: Mapping[Literal["xy", "x", "y"], TaggedArray]) -> None: if not callable(callback): raise TypeError("Callback is not a function.") - data = callback(term=term, adata=problem.adata_src, adata_y=problem.adata_tgt, **kwargs) + data = callback(term, problem.adata_src, problem.adata_tgt, **kwargs) verify_data(data) return data @@ -645,6 +647,8 @@ def _cost_matrix_callback( mask = self._policy.create_mask(key_1, allow_empty=False) if term == "xy": + if key_2 is None: + raise ValueError("If `term` is `xy`, `key_2` cannot be `None`.") mask_2 = self._policy.create_mask(key_2, allow_empty=False) linear_cost_matrix = data[mask, :][:, mask_2] diff --git a/tests/problems/base/test_compound_problem.py b/tests/problems/base/test_compound_problem.py index 5289177ec..9b8d61c7a 100644 --- a/tests/problems/base/test_compound_problem.py +++ b/tests/problems/base/test_compound_problem.py @@ -20,31 +20,21 @@ class TestCompoundProblem: @staticmethod def xy_callback( - term: Literal["x", "y", "xy"], - adata: AnnData, - adata_y: Optional[AnnData] = None, - sentinel: bool = False + term: Literal["x", "y", "xy"], adata: AnnData, adata_y: Optional[AnnData] = None, sentinel: bool = False ) -> Mapping[Literal["xy", "x", "y"], TaggedArray]: assert sentinel assert isinstance(adata_y, AnnData) return {term: TaggedArray(euclidean_distances(adata.X, adata_y.X), tag=Tag.COST_MATRIX)} - def x_callback( - term: Literal["x", "y", "xy"], - adata: AnnData, - adata_y: Optional[AnnData] = None, - sentinel: bool = False + term: Literal["x", "y", "xy"], adata: AnnData, adata_y: Optional[AnnData] = None, sentinel: bool = False ) -> Mapping[Literal["xy", "x", "y"], TaggedArray]: assert sentinel assert isinstance(adata_y, AnnData) return {term: TaggedArray(euclidean_distances(adata.X, adata_y.X), tag=Tag.COST_MATRIX)} def y_callback( - term: Literal["x", "y", "xy"], - adata: AnnData, - adata_y: Optional[AnnData] = None, - sentinel: bool = False + term: Literal["x", "y", "xy"], adata: AnnData, adata_y: Optional[AnnData] = None, sentinel: bool = False ) -> Mapping[Literal["xy", "x", "y"], TaggedArray]: assert sentinel assert isinstance(adata_y, AnnData) diff --git a/tests/problems/generic/test_fgw_problem.py b/tests/problems/generic/test_fgw_problem.py index e8f4f3d72..dbb003f46 100644 --- a/tests/problems/generic/test_fgw_problem.py +++ b/tests/problems/generic/test_fgw_problem.py @@ -116,9 +116,7 @@ def test_pass_arguments(self, adata_space_rotate: AnnData, args_to_check: Mappin assert getattr(geom, val) == args_to_check[arg] @pytest.mark.fast() - @pytest.mark.parametrize( - "cost", [("sq_euclidean", SqEuclidean), ("euclidean", Euclidean), ("cosine", Cosine)] - ) + @pytest.mark.parametrize("cost", [("sq_euclidean", SqEuclidean), ("euclidean", Euclidean), ("cosine", Cosine)]) def test_prepare_costs(self, adata_time: AnnData, cost: Tuple[str, Any]): problem = FGWProblem(adata=adata_time) problem = problem.prepare( From e1ea3b2fb2ef5a9c6f9f05f2a35f959581ddf7ce Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Fri, 9 Dec 2022 10:30:03 +0100 Subject: [PATCH 07/14] fix pre-commit --- src/moscot/problems/_utils.py | 4 ++-- src/moscot/problems/generic/_generic.py | 2 +- tests/problems/base/test_compound_problem.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/moscot/problems/_utils.py b/src/moscot/problems/_utils.py index 2882fef28..034738db0 100644 --- a/src/moscot/problems/_utils.py +++ b/src/moscot/problems/_utils.py @@ -67,12 +67,12 @@ def wrap_solve( def handle_joint_attr( - joint_attr: Optional[Union[str, Mapping[str, Any]]], kwargs: Any + joint_attr: Optional[Union[str, Mapping[str, Any]]], kwargs: Dict[Any, Any] ) -> Tuple[Optional[Mapping[str, Any]], Dict[str, Any]]: if joint_attr is None: if "xy_callback" not in kwargs: kwargs["xy_callback"] = "local-pca" - kwargs["xy_callback_kwargs"] = kwargs["xy_callback_kwargs"].set_default({}) + kwargs.setdefault("xy_callback_kwargs", {}) return None, kwargs if isinstance(joint_attr, str): xy = { diff --git a/src/moscot/problems/generic/_generic.py b/src/moscot/problems/generic/_generic.py index a2c7711df..30bc9edbb 100644 --- a/src/moscot/problems/generic/_generic.py +++ b/src/moscot/problems/generic/_generic.py @@ -65,7 +65,7 @@ def prepare( self.batch_key = key if joint_attr is None: kwargs["xy_callback"] = "local-pca" - kwargs["xy_callback_kwargs"] = kwargs["xy_callback_kwargs"].set_default({}) + kwargs.setdefault("xy_callback_kwargs", {}) elif isinstance(joint_attr, str): kwargs["xy"] = { "x_attr": "obsm", diff --git a/tests/problems/base/test_compound_problem.py b/tests/problems/base/test_compound_problem.py index 9b8d61c7a..a757def59 100644 --- a/tests/problems/base/test_compound_problem.py +++ b/tests/problems/base/test_compound_problem.py @@ -99,7 +99,7 @@ def test_custom_callback(self, adata_time: AnnData, mocker: MockerFixture): y={"attr": "X"}, key="time", policy="sequential", - callback=TestCompoundProblem.callback, + callback=TestCompoundProblem.callback_xy, xy_callback_kwargs={"sentinel": True}, ) From 7c41ef602d8659000e6e2c0b6fc01607f356e4aa Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Fri, 9 Dec 2022 11:16:32 +0100 Subject: [PATCH 08/14] fix test --- tests/problems/base/test_compound_problem.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/problems/base/test_compound_problem.py b/tests/problems/base/test_compound_problem.py index a757def59..21c921b3e 100644 --- a/tests/problems/base/test_compound_problem.py +++ b/tests/problems/base/test_compound_problem.py @@ -26,6 +26,7 @@ def xy_callback( assert isinstance(adata_y, AnnData) return {term: TaggedArray(euclidean_distances(adata.X, adata_y.X), tag=Tag.COST_MATRIX)} + @staticmethod def x_callback( term: Literal["x", "y", "xy"], adata: AnnData, adata_y: Optional[AnnData] = None, sentinel: bool = False ) -> Mapping[Literal["xy", "x", "y"], TaggedArray]: @@ -33,6 +34,7 @@ def x_callback( assert isinstance(adata_y, AnnData) return {term: TaggedArray(euclidean_distances(adata.X, adata_y.X), tag=Tag.COST_MATRIX)} + @staticmethod def y_callback( term: Literal["x", "y", "xy"], adata: AnnData, adata_y: Optional[AnnData] = None, sentinel: bool = False ) -> Mapping[Literal["xy", "x", "y"], TaggedArray]: @@ -85,7 +87,7 @@ def test_default_callback(self, adata_time: AnnData, mocker: MockerFixture, scal assert isinstance(problem, CompoundProblem) assert isinstance(problem.problems, dict) - spy.assert_called_with(subproblem.adata_src, subproblem.adata_tgt, **xy_callback_kwargs) + spy.assert_called_with("xy", subproblem.adata_src, subproblem.adata_tgt, **xy_callback_kwargs) @pytest.mark.fast() def test_custom_callback(self, adata_time: AnnData, mocker: MockerFixture): @@ -99,7 +101,7 @@ def test_custom_callback(self, adata_time: AnnData, mocker: MockerFixture): y={"attr": "X"}, key="time", policy="sequential", - callback=TestCompoundProblem.callback_xy, + callback=TestCompoundProblem.xy_callback, xy_callback_kwargs={"sentinel": True}, ) From 8074ee71cfa6329cd3fd4d060ec22fb44e6ef7ef Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Fri, 9 Dec 2022 11:19:43 +0100 Subject: [PATCH 09/14] fix test --- tests/problems/base/test_compound_problem.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/problems/base/test_compound_problem.py b/tests/problems/base/test_compound_problem.py index 21c921b3e..ad7091f59 100644 --- a/tests/problems/base/test_compound_problem.py +++ b/tests/problems/base/test_compound_problem.py @@ -96,12 +96,12 @@ def test_custom_callback(self, adata_time: AnnData, mocker: MockerFixture): problem = Problem(adata=adata_time) _ = problem.prepare( - xy={"x_attr": "X", "y_attr": "X"}, + xy=None, x={"attr": "X"}, y={"attr": "X"}, key="time", policy="sequential", - callback=TestCompoundProblem.xy_callback, + xy_callback=TestCompoundProblem.xy_callback, xy_callback_kwargs={"sentinel": True}, ) From 0d0f33a7a2444440c1c117ddfa8679a1e4f2f274 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Fri, 9 Dec 2022 11:21:24 +0100 Subject: [PATCH 10/14] add test for quad custom callback --- tests/problems/base/test_compound_problem.py | 24 +++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/tests/problems/base/test_compound_problem.py b/tests/problems/base/test_compound_problem.py index ad7091f59..b11eb7d3a 100644 --- a/tests/problems/base/test_compound_problem.py +++ b/tests/problems/base/test_compound_problem.py @@ -90,7 +90,7 @@ def test_default_callback(self, adata_time: AnnData, mocker: MockerFixture, scal spy.assert_called_with("xy", subproblem.adata_src, subproblem.adata_tgt, **xy_callback_kwargs) @pytest.mark.fast() - def test_custom_callback(self, adata_time: AnnData, mocker: MockerFixture): + def test_custom_callback_lin(self, adata_time: AnnData, mocker: MockerFixture): expected_keys = [(0, 1), (1, 2)] spy = mocker.spy(TestCompoundProblem, "xy_callback") @@ -107,6 +107,28 @@ def test_custom_callback(self, adata_time: AnnData, mocker: MockerFixture): assert spy.call_count == len(expected_keys) + @pytest.mark.fast() + def test_custom_callback_quad(self, adata_time: AnnData, mocker: MockerFixture): + expected_keys = [(0, 1), (1, 2)] + spy_x = mocker.spy(TestCompoundProblem, "x_callback") + spy_y = mocker.spy(TestCompoundProblem, "y_callback") + + problem = Problem(adata=adata_time) + _ = problem.prepare( + xy=None, + x={"attr": "X"}, + y={"attr": "X"}, + key="time", + policy="sequential", + x_callback=TestCompoundProblem.x_callback, + y_callback=TestCompoundProblem.y_callback, + x_callback_kwargs={"sentinel": True}, + y_callback_kwargs={"sentinel": True}, + ) + + assert spy_x.call_count == len(expected_keys) + assert spy_y.call_count == len(expected_keys) + def test_different_passings_linear(self, adata_with_cost_matrix: AnnData): epsilon = 5 xy = {"x_attr": "obsm", "x_key": "X_pca", "y_attr": "obsm", "y_key": "X_pca"} From f8910a522b0a61a4cef8d4bace7765e25de1fb15 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Fri, 9 Dec 2022 11:26:37 +0100 Subject: [PATCH 11/14] adapt kwargs for callback --- src/moscot/_docs/_docs.py | 38 +++++++++++++++---- src/moscot/problems/base/_compound_problem.py | 8 +++- 2 files changed, 36 insertions(+), 10 deletions(-) diff --git a/src/moscot/_docs/_docs.py b/src/moscot/_docs/_docs.py index b92a06c7f..29727d604 100644 --- a/src/moscot/_docs/_docs.py +++ b/src/moscot/_docs/_docs.py @@ -30,13 +30,31 @@ reference `reference` in :class:`moscot.problems._subset_policy.StarPolicy`. """ -_callback = """\ -callback - Custom callback applied to each distribution as pre-processing step. Examples are given in TODO Link Notebook. +_xy_callback = """\ +xy_callback + Custom callback applied to the linear term as pre-processing step. Examples are given in TODO Link Notebook. """ -_callback_kwargs = """\ -callback_kwargs - Keyword arguments for `callback`. +_xy_callback_kwargs = """\ +xy_callback_kwargs + Keyword arguments for `xy_callback`. +""" +_x_callback = """\ +x_callback + Custom callback applied to the source distribution of the quadratic term as pre-processing step. + Examples are given in TODO Link Notebook. +""" +_x_callback_kwargs = """\ +x_callback_kwargs + Keyword arguments for `x_callback`. +""" +_y_callback = """\ +y_callback + Custom callback applied to the target distribution of the quadratic term as pre-processing step. + Examples are given in TODO Link Notebook. +""" +_y_callback_kwargs = """\ +x_callback_kwargs + Keyword arguments for `y_callback`. """ _epsilon = """\ epsilon @@ -362,8 +380,12 @@ source=_source, target=_target, reference=_reference, - callback=_callback, - callback_kwargs=_callback_kwargs, + xy_callback=_xy_callback, + xy_callback_kwargs=_xy_callback_kwargs, + x_callback=_x_callback, + x_callback_kwargs=_x_callback_kwargs, + y_callback=_y_callback, + y_callback_kwargs=_y_callback_kwargs, epsilon=_epsilon, alpha=_alpha, tau_a=_tau_a, diff --git a/src/moscot/problems/base/_compound_problem.py b/src/moscot/problems/base/_compound_problem.py index 81a08b790..2d44309a6 100644 --- a/src/moscot/problems/base/_compound_problem.py +++ b/src/moscot/problems/base/_compound_problem.py @@ -202,8 +202,12 @@ def prepare( %(policy)s %(subset)s %(reference)s - %(callback)s - %(callback_kwargs)s + %(xy_callback)s + %(x_callback)s + %(y_callback)s + %(xy_callback_kwargs)s + %(x_callback_kwargs)s + %(y_callback_kwargs)s %(a)s %(b)s From 0ea9420a76952b8bfc6a1b5f4ced533e82ceeeb7 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Fri, 9 Dec 2022 13:31:04 +0100 Subject: [PATCH 12/14] fix typing --- src/moscot/problems/_utils.py | 16 ++++++++++++++-- src/moscot/problems/base/_compound_problem.py | 2 +- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/src/moscot/problems/_utils.py b/src/moscot/problems/_utils.py index 034738db0..0085c1e44 100644 --- a/src/moscot/problems/_utils.py +++ b/src/moscot/problems/_utils.py @@ -67,8 +67,8 @@ def wrap_solve( def handle_joint_attr( - joint_attr: Optional[Union[str, Mapping[str, Any]]], kwargs: Dict[Any, Any] -) -> Tuple[Optional[Mapping[str, Any]], Dict[str, Any]]: + joint_attr: Optional[Union[str, Mapping[str, Any]]], kwargs: Dict[str, Any] +) -> Tuple[Optional[Dict[str, Any]], Dict[str, Any]]: if joint_attr is None: if "xy_callback" not in kwargs: kwargs["xy_callback"] = "local-pca" @@ -83,6 +83,18 @@ def handle_joint_attr( } return xy, kwargs if isinstance(joint_attr, Mapping): + joint_attr = dict(joint_attr) + if "tag" not in joint_attr: + raise KeyError("When providing a `dict` as `joint_attr`, the key `tag` is required.") + if "key" not in joint_attr: + raise KeyError("When providing a `dict` as `joint_attr`, the key `key` is required.") + if joint_attr["tag"] == "cost_matrix": + if len(joint_attr) == 2 or kwargs.get("attr", None) == "obsp": + joint_attr.setdefault("cost", "custom") + joint_attr.setdefault("attr", "obsp") + kwargs["xy_callback"] = "cost-matrix" + kwargs.setdefault("xy_callback_kwargs", {"key": joint_attr["key"]}) + kwargs.setdefault("xy_callback_kwargs", {}) return joint_attr, kwargs raise TypeError(f"Expected `joint_attr` to be either `str` or `dict`, found `{type(joint_attr)}`.") diff --git a/src/moscot/problems/base/_compound_problem.py b/src/moscot/problems/base/_compound_problem.py index 2d44309a6..d1ec3f5d4 100644 --- a/src/moscot/problems/base/_compound_problem.py +++ b/src/moscot/problems/base/_compound_problem.py @@ -630,9 +630,9 @@ def _callback_handler( **kwargs: Any, ) -> Mapping[Literal["xy", "x", "y"], TaggedArray]: # TODO(michalk8): better name? + if callback == "cost-matrix": return self._cost_matrix_callback(term=term, key_1=key_1, key_2=key_2, **kwargs) - return super()._callback_handler( term=term, key_1=key_1, key_2=key_2, problem=problem, callback=callback, **kwargs ) From 778a8059d745747d44e1879406e4c92b233d0cba Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Fri, 9 Dec 2022 14:24:23 +0100 Subject: [PATCH 13/14] fix handle_joint_attr --- src/moscot/problems/_utils.py | 18 +++++++++++++----- tests/problems/space/test_alignment_problem.py | 6 ++---- tests/problems/space/test_mapping_problem.py | 4 +--- 3 files changed, 16 insertions(+), 12 deletions(-) diff --git a/src/moscot/problems/_utils.py b/src/moscot/problems/_utils.py index 0085c1e44..64ce1dc24 100644 --- a/src/moscot/problems/_utils.py +++ b/src/moscot/problems/_utils.py @@ -84,11 +84,19 @@ def handle_joint_attr( return xy, kwargs if isinstance(joint_attr, Mapping): joint_attr = dict(joint_attr) - if "tag" not in joint_attr: - raise KeyError("When providing a `dict` as `joint_attr`, the key `tag` is required.") - if "key" not in joint_attr: - raise KeyError("When providing a `dict` as `joint_attr`, the key `key` is required.") - if joint_attr["tag"] == "cost_matrix": + if "attr" in joint_attr and joint_attr["attr"] == "X": + return {"x_attr": "X", "y_attr": "X"}, kwargs + if "attr" in joint_attr and joint_attr["attr"] == "obsm": + if "key" not in joint_attr: + raise KeyError("`key` must be provided when `attr` is `obsm`.") + xy = { + "x_attr": "obsm", + "x_key": joint_attr["key"], + "y_attr": "obsm", + "y_key": joint_attr["key"], + } + return xy, kwargs + if joint_attr.get("tag", None) == "cost_matrix": if len(joint_attr) == 2 or kwargs.get("attr", None) == "obsp": joint_attr.setdefault("cost", "custom") joint_attr.setdefault("attr", "obsp") diff --git a/tests/problems/space/test_alignment_problem.py b/tests/problems/space/test_alignment_problem.py index a83bb4d3a..5ddb67539 100644 --- a/tests/problems/space/test_alignment_problem.py +++ b/tests/problems/space/test_alignment_problem.py @@ -25,9 +25,7 @@ class TestAlignmentProblem: @pytest.mark.fast() - @pytest.mark.parametrize( - "joint_attr", [{"x_attr": "X", "y_attr": "X"}] - ) # TODO(giovp): check that callback is correct + @pytest.mark.parametrize("joint_attr", [{"attr": "X"}]) # TODO(giovp): check that callback is correct def test_prepare_sequential(self, adata_space_rotate: AnnData, joint_attr: Optional[Mapping[str, Any]]): n_obs = adata_space_rotate.shape[0] // 3 # adata is made of 3 datasets n_var = adata_space_rotate.shape[1] @@ -112,7 +110,7 @@ def test_pass_arguments(self, adata_space_rotate: AnnData, args_to_check: Mappin adata_space_rotate = adata_space_rotate[adata_space_rotate.obs["batch"].isin(("0", "1"))] key = ("0", "1") problem = AlignmentProblem(adata=adata_space_rotate) - problem = problem.prepare(batch_key="batch", joint_attr={"x_attr": "X", "y_attr": "X"}) + problem = problem.prepare(batch_key="batch", joint_attr={"attr": "X"}) problem = problem.solve(**args_to_check) solver = problem[key].solver.solver diff --git a/tests/problems/space/test_mapping_problem.py b/tests/problems/space/test_mapping_problem.py index 8889275cd..b8fdfea55 100644 --- a/tests/problems/space/test_mapping_problem.py +++ b/tests/problems/space/test_mapping_problem.py @@ -27,9 +27,7 @@ class TestMappingProblem: @pytest.mark.fast() @pytest.mark.parametrize("sc_attr", [{"attr": "X"}, {"attr": "obsm", "key": "X_pca"}]) - @pytest.mark.parametrize( - "joint_attr", [None, "X_pca", {"x_attr": "obsm", "x_key": "X_pca", "y_attr": "obsm", "y_key": "X_pca"}] - ) + @pytest.mark.parametrize("joint_attr", [None, "X_pca", {"attr": "obsm", "key": "X_pca"}]) def test_prepare(self, adata_mapping: AnnData, sc_attr: Mapping[str, str], joint_attr: Optional[Mapping[str, str]]): adataref, adatasp = _adata_spatial_split(adata_mapping) expected_keys = {(i, "ref") for i in adatasp.obs.batch.cat.categories} From 8178df9941969d76c9b61e673c73cfbf21f1d40c Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Fri, 9 Dec 2022 15:23:21 +0100 Subject: [PATCH 14/14] incorporate requested changes --- src/moscot/problems/_utils.py | 10 ++--- src/moscot/problems/base/_base_problem.py | 2 +- src/moscot/problems/base/_compound_problem.py | 44 ++++++++++--------- 3 files changed, 29 insertions(+), 27 deletions(-) diff --git a/src/moscot/problems/_utils.py b/src/moscot/problems/_utils.py index 64ce1dc24..8417e1c9a 100644 --- a/src/moscot/problems/_utils.py +++ b/src/moscot/problems/_utils.py @@ -82,11 +82,11 @@ def handle_joint_attr( "y_key": joint_attr, } return xy, kwargs - if isinstance(joint_attr, Mapping): + if isinstance(joint_attr, Mapping): # input mapping does not distinguish between x and y as it's a shared space joint_attr = dict(joint_attr) - if "attr" in joint_attr and joint_attr["attr"] == "X": + if "attr" in joint_attr and joint_attr["attr"] == "X": # we have a point cloud return {"x_attr": "X", "y_attr": "X"}, kwargs - if "attr" in joint_attr and joint_attr["attr"] == "obsm": + if "attr" in joint_attr and joint_attr["attr"] == "obsm": # we have a point cloud if "key" not in joint_attr: raise KeyError("`key` must be provided when `attr` is `obsm`.") xy = { @@ -96,8 +96,8 @@ def handle_joint_attr( "y_key": joint_attr["key"], } return xy, kwargs - if joint_attr.get("tag", None) == "cost_matrix": - if len(joint_attr) == 2 or kwargs.get("attr", None) == "obsp": + if joint_attr.get("tag", None) == "cost_matrix": # if this is True we have custom cost matrix or moscot cost + if len(joint_attr) == 2 or kwargs.get("attr", None) == "obsp": # in this case we have a custom cost matrix joint_attr.setdefault("cost", "custom") joint_attr.setdefault("attr", "obsp") kwargs["xy_callback"] = "cost-matrix" diff --git a/src/moscot/problems/base/_base_problem.py b/src/moscot/problems/base/_base_problem.py index e58a682b3..cb77dcd22 100644 --- a/src/moscot/problems/base/_base_problem.py +++ b/src/moscot/problems/base/_base_problem.py @@ -458,7 +458,7 @@ def concat(x: ArrayLike, y: ArrayLike) -> ArrayLike: if scaler is not None: data = scaler.fit_transform(data) return {"xy": TaggedArray(data[:n], data[n:], tag=Tag.POINT_CLOUD)} - if term in ("x", "y"): + if term in ("x", "y"): # if we don't have a shared space, then adata_y is always None logger.info(f"Computing pca with `n_comps={n_comps}` for `{term}` using `{msg}`") x = sc.pp.pca(x, n_comps=n_comps, **kwargs) if scaler is not None: diff --git a/src/moscot/problems/base/_compound_problem.py b/src/moscot/problems/base/_compound_problem.py index d1ec3f5d4..8b8ca61ca 100644 --- a/src/moscot/problems/base/_compound_problem.py +++ b/src/moscot/problems/base/_compound_problem.py @@ -104,7 +104,7 @@ def _callback_handler( key_2: K, problem: B, *, - callback: Union[Literal["local-pca"], Callback_t], + callback: Optional[Union[Literal["local-pca"], Callback_t]] = None, **kwargs: Any, ) -> Mapping[Literal["xy", "x", "y"], TaggedArray]: def verify_data(data: Mapping[Literal["xy", "x", "y"], TaggedArray]) -> None: @@ -115,6 +115,8 @@ def verify_data(data: Mapping[Literal["xy", "x", "y"], TaggedArray]) -> None: if not isinstance(val, TaggedArray): raise TypeError(f"Expected value for `{key}` to be a `TaggedArray`, found `{type(val)}`.") + if callback is None: + return {} if callback == "local-pca": callback = problem._local_pca_callback @@ -150,24 +152,19 @@ def _create_problems( tgt_name = tgt problem = self._create_problem(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask) - if xy_callback is not None: - xy_data = self._callback_handler( - term="xy", key_1=src, key_2=tgt, problem=problem, callback=xy_callback, **xy_callback_kwargs - ) - else: - xy_data = {} - if x_callback is not None: - x_data = self._callback_handler( - term="x", key_1=src, key_2=tgt, problem=problem, callback=x_callback, **x_callback_kwargs - ) - else: - x_data = {} - if y_callback is not None: - y_data = self._callback_handler( - term="y", key_1=src, key_2=tgt, problem=problem, callback=y_callback, **y_callback_kwargs - ) - else: - y_data = {} + + xy_data = self._callback_handler( + term="xy", key_1=src, key_2=tgt, problem=problem, callback=xy_callback, **xy_callback_kwargs + ) + + x_data = self._callback_handler( + term="x", key_1=src, key_2=tgt, problem=problem, callback=x_callback, **x_callback_kwargs + ) + + y_data = self._callback_handler( + term="y", key_1=src, key_2=tgt, problem=problem, callback=y_callback, **y_callback_kwargs + ) + kws = {**kwargs, **xy_data, **x_data, **y_data} # type: ignore[arg-type] if isinstance(problem, BirthDeathProblem): @@ -626,7 +623,7 @@ def _callback_handler( key_2: K, problem: B, *, - callback: Union[Literal["local-pca", "cost-matrix"], Callback_t], + callback: Optional[Union[Literal["local-pca", "cost-matrix"], Callback_t]] = None, **kwargs: Any, ) -> Mapping[Literal["xy", "x", "y"], TaggedArray]: # TODO(michalk8): better name? @@ -657,10 +654,15 @@ def _cost_matrix_callback( linear_cost_matrix = data[mask, :][:, mask_2] if issparse(linear_cost_matrix): + logger.warning("Linear cost matrix being densified.") linear_cost_matrix = linear_cost_matrix.A return {"xy": TaggedArray(linear_cost_matrix, tag=Tag.COST_MATRIX)} if term in ("x", "y"): - return {term: TaggedArray(data[mask, :][:, mask], tag=Tag.COST_MATRIX)} + quad_cost_matrix = data[mask, :][:, mask_2] + if issparse(quad_cost_matrix): + logger.warning("Quadratic cost matrix being densified.") + quad_cost_matrix = quad_cost_matrix.A + return {term: TaggedArray(quad_cost_matrix, tag=Tag.COST_MATRIX)} raise ValueError(f"Expected `term` to be one of `x`, `y`, or `xy`, found `{term!r}`.")