Skip to content

Commit

Permalink
Simplify linear operator (#431)
Browse files Browse the repository at this point in the history
* Simplify linear operator

* Simplify `align`, fix test
  • Loading branch information
michalk8 authored and lucaeyring committed Mar 15, 2023
1 parent 4f1d26e commit 0ffc730
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 53 deletions.
6 changes: 1 addition & 5 deletions src/moscot/problems/base/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ def _apply(
def _interpolate_transport(
self: "AnalysisMixinProtocol[K, B]",
path: Sequence[Tuple[K, K]],
forward: bool = True,
scale_by_marginals: bool = True,
) -> LinearOperator:
...
Expand Down Expand Up @@ -352,7 +351,6 @@ def _interpolate_transport(
self: AnalysisMixinProtocol[K, B],
# TODO(@giovp): rename this to 'explicit_steps', pass to policy.plan() and reintroduce (source_key, target_key)
path: Sequence[Tuple[K, K]],
forward: bool = True,
scale_by_marginals: bool = True,
**_: Any,
) -> LinearOperator:
Expand All @@ -361,9 +359,7 @@ def _interpolate_transport(
assert isinstance(self._policy, SubsetPolicy)
# TODO(@MUCDK, @giovp, discuss what exactly this function should do, seems like it could be more generic)
fst, *rest = path
return self.solutions[fst].chain(
[self.solutions[r] for r in rest], forward=forward, scale_by_marginals=scale_by_marginals
)
return self.solutions[fst].chain([self.solutions[r] for r in rest], scale_by_marginals=scale_by_marginals)

def _flatten(self: AnalysisMixinProtocol[K, B], data: Dict[K, ArrayLike], *, key: Optional[str]) -> ArrayLike:
tmp = np.full(len(self.adata), np.nan)
Expand Down
58 changes: 25 additions & 33 deletions src/moscot/problems/space/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,14 @@ def _interpolate_scheme( # type:ignore[empty-body]

@staticmethod
def _affine( # type:ignore[empty-body]
tmap: LinearOperator, tgt: ArrayLike, src: ArrayLike
tmap: LinearOperator, src: ArrayLike, tgt: ArrayLike
) -> Tuple[ArrayLike, ArrayLike]:
...

@staticmethod
def _warp(tmap: LinearOperator, _: ArrayLike, src: ArrayLike) -> Tuple[ArrayLike, None]: # type:ignore[empty-body]
def _warp( # type: ignore[empty-body]
tmap: LinearOperator, src: ArrayLike, _: ArrayLike
) -> Tuple[ArrayLike, Optional[ArrayLike]]:
...

def _cell_transition(
Expand Down Expand Up @@ -122,12 +124,6 @@ def _interpolate_scheme(
reference_ = reference
full_steps = self._policy._graph
starts = set(chain.from_iterable(full_steps)) - set(reference_) # type: ignore[arg-type]
fwd_steps, bwd_steps = {}, {}
for start in starts:
try:
fwd_steps[(start, reference)] = self._policy.plan(start=start, end=reference)
except NetworkXNoPath:
bwd_steps[(reference, start)] = self._policy.plan(start=reference, end=start)

if mode == AlignmentMode.AFFINE:
_transport = self._affine
Expand All @@ -136,23 +132,21 @@ def _interpolate_scheme(
else:
raise NotImplementedError(f"Alignment mode `{mode!r}` is not yet implemented.")

if len(fwd_steps):
for (start, _), path in fwd_steps.items():
tmap = self._interpolate_transport(path=path, scale_by_marginals=True, forward=True)
transport_maps[start], transport_metadata[start] = _transport(
tmap, self._subset_spatial(start, spatial_key=spatial_key), src, forward=True
)
steps = {}
for start in starts:
try:
steps[start, reference, True] = self._policy.plan(start=start, end=reference)
except NetworkXNoPath:
steps[reference, start, False] = self._policy.plan(start=reference, end=start)

if len(bwd_steps):
for (_, end), path in bwd_steps.items():
tmap = self._interpolate_transport(path=path, scale_by_marginals=True, forward=True)
transport_maps[end], transport_metadata[end] = _transport(
tmap, self._subset_spatial(end, spatial_key=spatial_key), src, forward=False
)
for (start, end, forward), path in steps.items():
tmap = self._interpolate_transport(path=path, scale_by_marginals=True)
# make `tmap` to have shape `(m, n_ref)` and apply it to `src` of shape `(n_ref, d)`
key, tmap = (start, tmap) if forward else (end, tmap.T)
spatial_data = self._subset_spatial(key, spatial_key=spatial_key)
transport_maps[key], transport_metadata[key] = _transport(tmap, src=src, tgt=spatial_data)

if mode == "affine":
return transport_maps, transport_metadata
return transport_maps, None
return transport_maps, (transport_metadata if mode == "affine" else None)

@d.dedent
def align(
Expand Down Expand Up @@ -282,30 +276,28 @@ def _subset_spatial(
) -> ArrayLike:
if spatial_key is None:
spatial_key = self.spatial_key
return self.adata[self.adata.obs[self._policy._subset_key] == k].obsm[spatial_key].astype(np.float_, copy=True)
return self.adata[self.adata.obs[self._policy._subset_key] == k].obsm[spatial_key].astype(float, copy=True)

@staticmethod
def _affine(
tmap: LinearOperator, tgt: ArrayLike, src: ArrayLike, forward: bool = True, *args: Any
tmap: LinearOperator,
src: ArrayLike,
tgt: ArrayLike,
) -> Tuple[ArrayLike, ArrayLike]:
"""Affine transformation."""
tgt -= tgt.mean(0)
if forward:
out = tmap.dot(src)
else:
out = tmap.T.dot(src)
out = tmap @ src
H = tgt.T.dot(out)
U, _, Vt = svd(H)
R = Vt.T.dot(U.T)
tgt = R.dot(tgt.T).T
return tgt, R

@staticmethod
def _warp(tmap: LinearOperator, _: ArrayLike, src: ArrayLike, forward: bool = True) -> Tuple[ArrayLike, None]:
def _warp(tmap: LinearOperator, src: ArrayLike, tgt: ArrayLike) -> Tuple[ArrayLike, None]:
"""Warp transformation."""
if forward:
return tmap.dot(src), None
return tmap.T.dot(src), None
del tgt
return tmap @ src, None


class SpatialMappingMixin(AnalysisMixin[K, B]):
Expand Down
19 changes: 7 additions & 12 deletions src/moscot/solvers/_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,13 +124,11 @@ def pull(self, x: ArrayLike, scale_by_marginals: bool = False) -> ArrayLike:
x = self._scale_by_marginals(x, forward=False)
return self._apply(x, forward=False)

def as_linear_operator(self, *, forward: bool, scale_by_marginals: bool = False) -> LinearOperator:
def as_linear_operator(self, scale_by_marginals: bool = False) -> LinearOperator:
"""Transform :attr:`transport_matrix` into a linear operator.
Parameters
----------
forward
If `True`, convert the :meth:`push` operator, else the :meth:`pull` operator.
scale_by_marginals
Whether to scale by marginals.
Expand All @@ -140,30 +138,27 @@ def as_linear_operator(self, *, forward: bool, scale_by_marginals: bool = False)
"""
push = partial(self.push, scale_by_marginals=scale_by_marginals)
pull = partial(self.pull, scale_by_marginals=scale_by_marginals)
mv, rmv = (pull, push) if forward else (push, pull) # please do not change this line
return LinearOperator(shape=self.shape, dtype=self.a.dtype, matvec=mv, rmatvec=rmv)
# push: a @ X (rmatvec)
# pull: X @ a (matvec)
return LinearOperator(shape=self.shape, dtype=self.dtype, matvec=pull, rmatvec=push)

def chain(
self, outputs: Iterable["BaseSolverOutput"], forward: bool, scale_by_marginals: bool = False
) -> LinearOperator:
def chain(self, outputs: Iterable["BaseSolverOutput"], scale_by_marginals: bool = False) -> LinearOperator:
"""Chain subsequent applications of :attr:`transport_matrix`.
Parameters
----------
outputs
Sequence of transport matrices to chain.
forward
If `True`, chain the :meth:`push` operator, else the :meth:`pull` operator.
scale_by_marginals
Whether to scale by marginals.
Returns
-------
The chained transport matrices as a linear operator.
"""
op = self.as_linear_operator(forward=forward, scale_by_marginals=scale_by_marginals)
op = self.as_linear_operator(scale_by_marginals)
for out in outputs:
op *= out.as_linear_operator(forward=forward, scale_by_marginals=scale_by_marginals)
op *= out.as_linear_operator(scale_by_marginals)

return op

Expand Down
4 changes: 1 addition & 3 deletions tests/analysis_mixins/test_base_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,7 @@ def test_interpolate_transport(self, gt_temporal_adata: AnnData, forward: bool,
problem[(10.0, 10.5)]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_10_105"])
problem[(10.5, 11.0)]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_105_11"])
problem[(10.0, 11.0)]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_10_11"])
tmap = problem._interpolate_transport(
[(10, 11)], forward=forward, scale_by_marginals=True, explicit_steps=[(10.0, 11.0)]
)
tmap = problem._interpolate_transport([(10, 11)], scale_by_marginals=True, explicit_steps=[(10.0, 11.0)])

assert isinstance(tmap, LinearOperator)
# TODO(@MUCDK) add regression test after discussing with @giovp what this function should be
Expand Down

0 comments on commit 0ffc730

Please sign in to comment.