Skip to content

Commit

Permalink
Explicitly jit the solvers (#433)
Browse files: browse the repository at this point in the history
  • Loading branch information
michalk8 authored and lucaeyring committed Mar 15, 2023
1 parent 0ffc730 commit cee2466
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 37 deletions.
14 changes: 11 additions & 3 deletions src/moscot/backends/ott/_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,19 @@ def __call__(self, **kwargs: Any) -> CostFn:


class OTTJaxSolver(OTSolver[OTTOutput]):
"""Base class for :mod:`ott` solvers :cite:`cuturi2022optimal`."""
"""Base class for :mod:`ott` solvers :cite:`cuturi2022optimal`.
def __init__(self):
Parameters
----------
jit
Whether to jit the :attr:`solver`.
"""

def __init__(self, jit: bool = True):
super().__init__()
self._solver: Optional[Union[Sinkhorn, LRSinkhorn, GromovWasserstein]] = None
self._problem: Optional[Union[LinearProblem, QuadraticProblem]] = None
self._jit = jit

def _create_geometry(
self,
Expand Down Expand Up @@ -88,7 +95,8 @@ def _solve( # type: ignore[override]
prob: Union[LinearProblem, QuadraticProblem],
**kwargs: Any,
) -> OTTOutput:
out = self.solver(prob, **kwargs)
solver = jax.jit(self.solver) if self._jit else self._solver
out = solver(prob, **kwargs) # type: ignore[misc]
return OTTOutput(out)

@staticmethod
Expand Down
51 changes: 20 additions & 31 deletions tests/backends/ott/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ class TestSinkhorn:
@pytest.mark.parametrize("jit", [False, True])
@pytest.mark.parametrize("eps", [None, 1e-2, 1e-1])
def test_matches_ott(self, x: Geom_t, eps: Optional[float], jit: bool) -> None:
gt = sinkhorn(PointCloud(x, epsilon=eps), jit=jit)
fn = jax.jit(sinkhorn) if jit else sinkhorn
gt = fn(PointCloud(x, epsilon=eps))
solver = SinkhornSolver(jit=jit)
assert solver.xy is None
assert isinstance(solver.solver, Sinkhorn)
Expand Down Expand Up @@ -66,18 +67,16 @@ class TestGW:
@pytest.mark.parametrize("eps", [5e-2, 1e-2, 1e-1])
def test_matches_ott(self, x: Geom_t, y: Geom_t, eps: Optional[float], jit: bool) -> None:
thresh = 1e-2
kwargs = {"epsilon": eps, "threshold": thresh, "jit": jit}
kwargs["tags"] = {"x": "point_cloud", "y": "point_cloud"}
gt = gromov_wasserstein(
PointCloud(x, epsilon=eps), PointCloud(y, epsilon=eps), threshold=thresh, jit=jit, epsilon=eps
)
pc_x, pc_y = PointCloud(x, epsilon=eps), PointCloud(y, epsilon=eps)
fn = jax.jit(gromov_wasserstein, static_argnames=["threshold", "epsilon"]) if jit else gromov_wasserstein
gt = fn(pc_x, pc_y, threshold=thresh, epsilon=eps)

solver = GWSolver(**kwargs)
solver = GWSolver(jit=jit, epsilon=eps, threshold=thresh)
assert isinstance(solver.solver, GromovWasserstein)
assert solver.x is None
assert solver.y is None

pred = solver(x=x, y=y, **kwargs)
pred = solver(x=x, y=y, tags={"x": "point_cloud", "y": "point_cloud"})

assert solver.rank == -1
assert not solver.is_low_rank
Expand All @@ -89,16 +88,13 @@ def test_matches_ott(self, x: Geom_t, y: Geom_t, eps: Optional[float], jit: bool
@pytest.mark.parametrize("eps", [5e-1, 1])
def test_epsilon(self, x_cost: jnp.ndarray, y_cost: jnp.ndarray, eps: Optional[float]) -> None:
thresh = 1e-3
kwargs = {"epsilon": eps, "threshold": thresh}
kwargs["tags"] = {"x": Tag.COST_MATRIX, "y": Tag.COST_MATRIX}

problem = QuadraticProblem(
geom_xx=Geometry(cost_matrix=x_cost, epsilon=eps), geom_yy=Geometry(cost_matrix=y_cost, epsilon=eps)
)
gt = GromovWasserstein(epsilon=eps, threshold=thresh)(problem)
solver = GWSolver(**kwargs)
solver = GWSolver(epsilon=eps, threshold=thresh)

pred = solver(x=x_cost, y=y_cost, **kwargs)
pred = solver(x=x_cost, y=y_cost, tags={"x": Tag.COST_MATRIX, "y": Tag.COST_MATRIX})

assert pred.rank == -1
assert solver.rank == -1
Expand All @@ -109,14 +105,12 @@ def test_epsilon(self, x_cost: jnp.ndarray, y_cost: jnp.ndarray, eps: Optional[f
@pytest.mark.parametrize("rank", [-1, 7])
def test_rank(self, x: Geom_t, y: Geom_t, rank: int) -> None:
thresh, eps = 1e-2, 1e-2
kwargs = {"epsilon": eps, "threshold": thresh, "rank": rank}
kwargs["tags"] = {"x": "point_cloud", "y": "point_cloud"}

gt = GromovWasserstein(epsilon=eps, rank=rank, threshold=thresh)(
QuadraticProblem(PointCloud(x, epsilon=eps), PointCloud(y, epsilon=eps))
)
solver = GWSolver(**kwargs)
pred = solver(x=x, y=y, **kwargs)

solver = GWSolver(rank=rank, epsilon=eps, threshold=thresh)
pred = solver(x=x, y=y, tags={"x": "point_cloud", "y": "point_cloud"})

assert solver.rank == rank
assert pred.rank == rank
Expand All @@ -128,23 +122,22 @@ class TestFGW:
@pytest.mark.parametrize("eps", [1e-2, 1e-1, 5e-1])
def test_matches_ott(self, x: Geom_t, y: Geom_t, xy: Geom_t, eps: Optional[float], alpha: float) -> None:
thresh = 1e-2
kwargs = {"epsilon": eps, "threshold": thresh, "alpha": alpha}
kwargs["tags"] = {"x": "point_cloud", "y": "point_cloud", "xy": "point_cloud"}
xx, yy = xy

gt = gromov_wasserstein(
geom_xx=PointCloud(x, epsilon=eps),
geom_yy=PointCloud(y, epsilon=eps),
geom_xy=PointCloud(xy[0], xy[1], epsilon=eps),
geom_xy=PointCloud(xx, yy, epsilon=eps),
fused_penalty=FGWSolver._alpha_to_fused_penalty(alpha),
epsilon=eps,
threshold=thresh,
)

solver = FGWSolver(**kwargs)
solver = FGWSolver(epsilon=eps, threshold=thresh)
assert isinstance(solver.solver, GromovWasserstein)
assert solver.xy is None

pred = solver(x=x, y=y, xy=xy, **kwargs)
pred = solver(x=x, y=y, xy=xy, alpha=alpha, tags={"x": "point_cloud", "y": "point_cloud", "xy": "point_cloud"})

assert solver.rank == -1
assert pred.rank == -1
Expand All @@ -157,9 +150,6 @@ def test_alpha(self, x: Geom_t, y: Geom_t, xy: Geom_t, alpha: float) -> None:
thresh, eps = 5e-2, 1e-1
xx, yy = xy

kwargs = {"epsilon": eps, "threshold": thresh, "alpha": alpha}
kwargs["tags"] = {"x": "point_cloud", "y": "point_cloud", "xy": "point_cloud"}

gt = gromov_wasserstein(
geom_xx=PointCloud(x, epsilon=eps),
geom_yy=PointCloud(y, epsilon=eps),
Expand All @@ -168,8 +158,8 @@ def test_alpha(self, x: Geom_t, y: Geom_t, xy: Geom_t, alpha: float) -> None:
epsilon=eps,
threshold=thresh,
)
solver = FGWSolver(**kwargs)
pred = solver(x=x, y=y, xy=xy, **kwargs)
solver = FGWSolver(epsilon=eps, threshold=thresh)
pred = solver(x=x, y=y, xy=xy, alpha=alpha, tags={"x": "point_cloud", "y": "point_cloud", "xy": "point_cloud"})

assert not solver.is_low_rank
assert pred.rank == -1
Expand All @@ -181,7 +171,6 @@ def test_epsilon(
self, x_cost: jnp.ndarray, y_cost: jnp.ndarray, xy_cost: jnp.ndarray, eps: Optional[float]
) -> None:
thresh, alpha = 5e-1, 0.66
kwargs = {"epsilon": eps, "threshold": thresh, "alpha": alpha}

problem = QuadraticProblem(
geom_xx=Geometry(cost_matrix=x_cost, epsilon=eps),
Expand All @@ -191,13 +180,13 @@ def test_epsilon(
)
gt = GromovWasserstein(epsilon=eps, threshold=thresh)(problem)

solver = FGWSolver(**kwargs)
solver = FGWSolver(epsilon=eps, threshold=thresh)
pred = solver(
x=x_cost,
y=y_cost,
xy=xy_cost,
alpha=alpha,
tags={"x": Tag.COST_MATRIX, "y": Tag.COST_MATRIX, "xy": Tag.COST_MATRIX},
**kwargs,
)

assert pred.rank == -1
Expand Down
3 changes: 0 additions & 3 deletions tests/problems/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,7 @@ def adata_time_with_tmap(adata_time: AnnData) -> AnnData:
"threshold": "threshold",
"min_iterations": "min_iterations",
"max_iterations": "max_iterations",
"initializer": "quad_initializer",
"initializer_kwargs": "kwargs_init",
"jit": "jit",
"warm_start": "_warm_start",
"initializer": "quad_initializer",
}
Expand Down Expand Up @@ -225,7 +223,6 @@ def adata_time_with_tmap(adata_time: AnnData) -> AnnData:
"max_iterations": "max_iterations",
"initializer": "initializer",
"initializer_kwargs": "kwargs_init",
"jit": "jit",
}

lr_sinkhorn_solver_args = sinkhorn_solver_args.copy()
Expand Down

0 comments on commit cee2466

Please sign in to comment.