Skip to content

Commit

Permalink
Fix conjugate gradient solver
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Jul 11, 2024
1 parent 3448bb5 commit 758bb77
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 14 deletions.
4 changes: 2 additions & 2 deletions examples/pose_graph_g2o.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ def main(
initial_vals = jaxls.VarValues.make(g2o.pose_vars, g2o.initial_poses)

with jaxls.utils.stopwatch("Running solve"):
solution_vals = graph.solve(initial_vals, trust_region=None)
solution_vals = graph.solve(initial_vals, trust_region=None, linear_solver=jaxls.ConjugateGradientLinearSolver())

with jaxls.utils.stopwatch("Running solve (again)"):
solution_vals = graph.solve(initial_vals, trust_region=None)
solution_vals = graph.solve(initial_vals, trust_region=None, linear_solver=jaxls.ConjugateGradientLinearSolver())

# Plot
plt.figure()
Expand Down
18 changes: 6 additions & 12 deletions src/jaxls/_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,26 +92,23 @@ def _solve(
ATb: jax.Array,
lambd: float | jax.Array,
iterations: int | jax.Array,
) -> jnp.ndarray:
) -> jax.Array:
assert len(ATb.shape) == 1, "ATb should be 1D!"

initial_x = jnp.zeros(ATb.shape)

# Get diagonals of ATA, for regularization + Jacobi preconditioning
# Get diagonals of ATA for preconditioning.
ATA_diagonals = (
jnp.zeros_like(initial_x).at[A_coo.indices[1]].add(A_coo.data**2)
jnp.zeros_like(initial_x).at[A_coo.indices[:, 1]].add(A_coo.data**2)
)

# Form normal equation
# Form normal equation.
def ATA_function(x: jax.Array):
ATAx = A_coo.T @ (A_coo @ x)
ATAx = A_coo.transpose() @ (A_coo @ x)
# We could also use (lambd * ATA_diagonals * x) for
# scale-invariance. But this is hard to match with CHOLMOD.
return ATAx + lambd * x

def jacobi_preconditioner(x):
return x / ATA_diagonals

# Solve with conjugate gradient.
solution_values, _ = jax.scipy.sparse.linalg.cg(
A=ATA_function,
Expand All @@ -125,14 +122,11 @@ def jacobi_preconditioner(x):
)
if self.inexact_step_eta is not None
else self.tolerance,
M=jacobi_preconditioner,
M=lambda x: x / ATA_diagonals, # Jacobi preconditioner.
)
return solution_values


# Nonlinear solve utils.


# Nonlinear solvers.


Expand Down

0 comments on commit 758bb77

Please sign in to comment.