diff --git a/src/moscot/problems/base/_mixins.py b/src/moscot/problems/base/_mixins.py index 8343e5c6b..603ae9f35 100644 --- a/src/moscot/problems/base/_mixins.py +++ b/src/moscot/problems/base/_mixins.py @@ -61,7 +61,6 @@ def _apply( def _interpolate_transport( self: "AnalysisMixinProtocol[K, B]", path: Sequence[Tuple[K, K]], - forward: bool = True, scale_by_marginals: bool = True, ) -> LinearOperator: ... @@ -352,7 +351,6 @@ def _interpolate_transport( self: AnalysisMixinProtocol[K, B], # TODO(@giovp): rename this to 'explicit_steps', pass to policy.plan() and reintroduce (source_key, target_key) path: Sequence[Tuple[K, K]], - forward: bool = True, scale_by_marginals: bool = True, **_: Any, ) -> LinearOperator: @@ -361,9 +359,7 @@ def _interpolate_transport( assert isinstance(self._policy, SubsetPolicy) # TODO(@MUCDK, @giovp, discuss what exactly this function should do, seems like it could be more generic) fst, *rest = path - return self.solutions[fst].chain( - [self.solutions[r] for r in rest], forward=forward, scale_by_marginals=scale_by_marginals - ) + return self.solutions[fst].chain([self.solutions[r] for r in rest], scale_by_marginals=scale_by_marginals) def _flatten(self: AnalysisMixinProtocol[K, B], data: Dict[K, ArrayLike], *, key: Optional[str]) -> ArrayLike: tmp = np.full(len(self.adata), np.nan) diff --git a/src/moscot/problems/space/_mixins.py b/src/moscot/problems/space/_mixins.py index 5ac008105..8522c0842 100644 --- a/src/moscot/problems/space/_mixins.py +++ b/src/moscot/problems/space/_mixins.py @@ -52,12 +52,14 @@ def _interpolate_scheme( # type:ignore[empty-body] @staticmethod def _affine( # type:ignore[empty-body] - tmap: LinearOperator, tgt: ArrayLike, src: ArrayLike + tmap: LinearOperator, src: ArrayLike, tgt: ArrayLike ) -> Tuple[ArrayLike, ArrayLike]: ... @staticmethod - def _warp(tmap: LinearOperator, _: ArrayLike, src: ArrayLike) -> Tuple[ArrayLike, None]: # type:ignore[empty-body] + def _warp( # type: ignore[empty-body] + tmap: LinearOperator, src: ArrayLike, _: ArrayLike + ) -> Tuple[ArrayLike, Optional[ArrayLike]]: ... def _cell_transition( @@ -122,12 +124,6 @@ def _interpolate_scheme( reference_ = reference full_steps = self._policy._graph starts = set(chain.from_iterable(full_steps)) - set(reference_) # type: ignore[arg-type] - fwd_steps, bwd_steps = {}, {} - for start in starts: - try: - fwd_steps[(start, reference)] = self._policy.plan(start=start, end=reference) - except NetworkXNoPath: - bwd_steps[(reference, start)] = self._policy.plan(start=reference, end=start) if mode == AlignmentMode.AFFINE: _transport = self._affine @@ -136,23 +132,21 @@ def _interpolate_scheme( else: raise NotImplementedError(f"Alignment mode `{mode!r}` is not yet implemented.") - if len(fwd_steps): - for (start, _), path in fwd_steps.items(): - tmap = self._interpolate_transport(path=path, scale_by_marginals=True, forward=True) - transport_maps[start], transport_metadata[start] = _transport( - tmap, self._subset_spatial(start, spatial_key=spatial_key), src, forward=True - ) + steps = {} + for start in starts: + try: + steps[start, reference, True] = self._policy.plan(start=start, end=reference) + except NetworkXNoPath: + steps[reference, start, False] = self._policy.plan(start=reference, end=start) - if len(bwd_steps): - for (_, end), path in bwd_steps.items(): - tmap = self._interpolate_transport(path=path, scale_by_marginals=True, forward=True) - transport_maps[end], transport_metadata[end] = _transport( - tmap, self._subset_spatial(end, spatial_key=spatial_key), src, forward=False - ) + for (start, end, forward), path in steps.items(): + tmap = self._interpolate_transport(path=path, scale_by_marginals=True) + # make `tmap` to have shape `(m, n_ref)` and apply it to `src` of shape `(n_ref, d)` + key, tmap = (start, tmap) if forward else (end, tmap.T) + spatial_data = self._subset_spatial(key, spatial_key=spatial_key) + transport_maps[key], transport_metadata[key] = _transport(tmap, src=src, tgt=spatial_data) - if mode == "affine": - return transport_maps, transport_metadata - return transport_maps, None + return transport_maps, (transport_metadata if mode == "affine" else None) @d.dedent def align( @@ -282,18 +276,17 @@ def _subset_spatial( ) -> ArrayLike: if spatial_key is None: spatial_key = self.spatial_key - return self.adata[self.adata.obs[self._policy._subset_key] == k].obsm[spatial_key].astype(np.float_, copy=True) + return self.adata[self.adata.obs[self._policy._subset_key] == k].obsm[spatial_key].astype(float, copy=True) @staticmethod def _affine( - tmap: LinearOperator, tgt: ArrayLike, src: ArrayLike, forward: bool = True, *args: Any + tmap: LinearOperator, + src: ArrayLike, + tgt: ArrayLike, ) -> Tuple[ArrayLike, ArrayLike]: """Affine transformation.""" tgt -= tgt.mean(0) - if forward: - out = tmap.dot(src) - else: - out = tmap.T.dot(src) + out = tmap @ src H = tgt.T.dot(out) U, _, Vt = svd(H) R = Vt.T.dot(U.T) @@ -301,11 +294,10 @@ def _affine( return tgt, R @staticmethod - def _warp(tmap: LinearOperator, _: ArrayLike, src: ArrayLike, forward: bool = True) -> Tuple[ArrayLike, None]: + def _warp(tmap: LinearOperator, src: ArrayLike, tgt: ArrayLike) -> Tuple[ArrayLike, None]: """Warp transformation.""" - if forward: - return tmap.dot(src), None - return tmap.T.dot(src), None + del tgt + return tmap @ src, None class SpatialMappingMixin(AnalysisMixin[K, B]): diff --git a/src/moscot/solvers/_output.py b/src/moscot/solvers/_output.py index 6d5de06c5..8a37afb53 100644 --- a/src/moscot/solvers/_output.py +++ b/src/moscot/solvers/_output.py @@ -124,13 +124,11 @@ def pull(self, x: ArrayLike, scale_by_marginals: bool = False) -> ArrayLike: x = self._scale_by_marginals(x, forward=False) return self._apply(x, forward=False) - def as_linear_operator(self, *, forward: bool, scale_by_marginals: bool = False) -> LinearOperator: + def as_linear_operator(self, scale_by_marginals: bool = False) -> LinearOperator: """Transform :attr:`transport_matrix` into a linear operator. Parameters ---------- - forward - If `True`, convert the :meth:`push` operator, else the :meth:`pull` operator. scale_by_marginals Whether to scale by marginals. @@ -140,20 +138,17 @@ def as_linear_operator(self, *, forward: bool, scale_by_marginals: bool = False) """ push = partial(self.push, scale_by_marginals=scale_by_marginals) pull = partial(self.pull, scale_by_marginals=scale_by_marginals) - mv, rmv = (pull, push) if forward else (push, pull) # please do not change this line - return LinearOperator(shape=self.shape, dtype=self.a.dtype, matvec=mv, rmatvec=rmv) + # push: a @ X (rmatvec) + # pull: X @ a (matvec) + return LinearOperator(shape=self.shape, dtype=self.dtype, matvec=pull, rmatvec=push) - def chain( - self, outputs: Iterable["BaseSolverOutput"], forward: bool, scale_by_marginals: bool = False - ) -> LinearOperator: + def chain(self, outputs: Iterable["BaseSolverOutput"], scale_by_marginals: bool = False) -> LinearOperator: """Chain subsequent applications of :attr:`transport_matrix`. Parameters ---------- outputs Sequence of transport matrices to chain. - forward - If `True`, chain the :meth:`push` operator, else the :meth:`pull` operator. scale_by_marginals Whether to scale by marginals. @@ -161,9 +156,9 @@ def chain( ------- The chained transport matrices as a linear operator. """ - op = self.as_linear_operator(forward=forward, scale_by_marginals=scale_by_marginals) + op = self.as_linear_operator(scale_by_marginals) for out in outputs: - op *= out.as_linear_operator(forward=forward, scale_by_marginals=scale_by_marginals) + op *= out.as_linear_operator(scale_by_marginals) return op diff --git a/tests/analysis_mixins/test_base_analysis.py b/tests/analysis_mixins/test_base_analysis.py index 73eb4df30..b5a10bcb2 100644 --- a/tests/analysis_mixins/test_base_analysis.py +++ b/tests/analysis_mixins/test_base_analysis.py @@ -76,9 +76,7 @@ def test_interpolate_transport(self, gt_temporal_adata: AnnData, forward: bool, problem[(10.0, 10.5)]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_10_105"]) problem[(10.5, 11.0)]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_105_11"]) problem[(10.0, 11.0)]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_10_11"]) - tmap = problem._interpolate_transport( - [(10, 11)], forward=forward, scale_by_marginals=True, explicit_steps=[(10.0, 11.0)] - ) + tmap = problem._interpolate_transport([(10, 11)], scale_by_marginals=True, explicit_steps=[(10.0, 11.0)]) assert isinstance(tmap, LinearOperator) # TODO(@MUCDK) add regression test after discussing with @giovp what this function should be