Nits, add conjugate gradient
brentyi committed Jul 6, 2024
1 parent 186f979 commit a4f2899
Showing 7 changed files with 148 additions and 84 deletions.
8 changes: 5 additions & 3 deletions jaxfg2/__init__.py
@@ -1,10 +1,12 @@
from . import utils as utils
+from ._factor_graph import Factor as Factor
+from ._factor_graph import StackedFactorGraph as StackedFactorGraph
from ._lie_group_variables import SE2Var as SE2Var
from ._lie_group_variables import SE3Var as SE3Var
from ._lie_group_variables import SO2Var as SO2Var
from ._lie_group_variables import SO3Var as SO3Var
-from ._factor_graph import Factor as Factor
-from ._factor_graph import StackedFactorGraph as StackedFactorGraph
-from ._solvers import GaussNewtonSolver
+from ._solvers import CholmodSolver as CholmodSolver
+from ._solvers import ConjugateGradientSolver as ConjugateGradientSolver
+from ._solvers import GaussNewtonSolver as GaussNewtonSolver
from ._variables import Var as Var
from ._variables import VarValues as VarValues
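
The re-sorted exports above make the new solver part of the public API. A minimal usage sketch (hypothetical: it assumes GaussNewtonSolver accepts the linear_solver field defined on NonlinearSolver in _solvers.py below, and that both solver classes take these keyword arguments):

import jaxfg2

# Swap the default CHOLMOD direct solver (host-side, exact) for the new
# iterative, Jacobi-preconditioned conjugate gradient solver.
solver = jaxfg2.GaussNewtonSolver(
    linear_solver=jaxfg2.ConjugateGradientSolver(tolerance=1e-5),
)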
68 changes: 24 additions & 44 deletions jaxfg2/_factor_graph.py
@@ -17,45 +17,12 @@
from ._variables import Var, VarTypeOrdering, VarValues, sort_and_stack_vars


-@dataclasses.dataclass(frozen=True)
-class TangentStorageLayout:
-    """Our tangent vector will be represented as a single flattened 1D vector.
-    How should it be laid out?"""
-
-    counts: Mapping[type[Var], int]
-    start_indices: Mapping[type[Var], int]
-    ordering: VarTypeOrdering
-
-    @staticmethod
-    def make(vars: Iterable[Var]) -> TangentStorageLayout:
-        counts: dict[type[Var], int] = {}
-        for var in vars:
-            var_type = type(var)
-            if isinstance(var.id, int) or len(var.id.shape) == 0:
-                counts[var_type] = counts.get(var_type, 0) + 1
-            else:
-                assert len(var.id.shape) == 1
-                counts[var_type] = counts.get(var_type, 0) + var.id.shape[0]
-
-        i = 0
-        start_indices: dict[type[Var], int] = {}
-        for var_type, count in counts.items():
-            start_indices[var_type] = i
-            i += var_type.parameter_dim * count
-
-        return TangentStorageLayout(
-            counts=frozendict(counts),
-            start_indices=frozendict(start_indices),
-            ordering=VarTypeOrdering(tuple(start_indices.keys())),
-        )


@jdc.pytree_dataclass
class StackedFactorGraph:
stacked_factors: tuple[Factor, ...]
stacked_factors_var_indices: tuple[dict[type[Var], jax.Array], ...]
jacobian_coords: SparseCooCoordinates
-    tangent_layout: jdc.Static[TangentStorageLayout]
+    tangent_ordering: jdc.Static[VarTypeOrdering]
residual_dim: jdc.Static[int]

def compute_residual_vector(self, vals: VarValues) -> jax.Array:
@@ -77,9 +44,7 @@ def _compute_jacobian_wrt_tangent(self, vals: VarValues) -> SparseCooMatrix:
def compute_jac_with_perturb(
factor: Factor, indices_from_type: dict[type[Var], jax.Array]
) -> jax.Array:
-val_subset = vals._get_subset(
-indices_from_type, self.tangent_layout.ordering
-)
+val_subset = vals._get_subset(indices_from_type, self.tangent_ordering)
# Shape should be: (single_residual_dim, vars * tangent_dim).
return (
# Getting the Jacobian of...
@@ -89,7 +54,7 @@ def compute_jac_with_perturb(
)(
# The residual function, with respect to some local delta.
lambda tangent: factor.compute_residual(
-val_subset._retract(tangent, self.tangent_layout.ordering),
+val_subset._retract(tangent, self.tangent_ordering),
*factor.args,
)
)(jnp.zeros((val_subset._get_tangent_dim(),)))
@@ -146,8 +111,7 @@ def make(
)

# Record factor and variables.
-factors_from_group.setdefault(group_key, [])
-factors_from_group[group_key].append(factor)
+factors_from_group.setdefault(group_key, []).append(factor)

# Fields we want to populate.
stacked_factors: list[Factor] = []
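
The collapsed line above relies on dict.setdefault returning the stored list, so insert-default-and-append happens in one expression:

# Same pattern on a plain dict, with hypothetical keys and values.
groups: dict[str, list[int]] = {}
groups.setdefault("a", []).append(1)
groups.setdefault("a", []).append(2)
assert groups == {"a": [1, 2]}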
@@ -156,7 +120,23 @@

# Create storage layout: this describes which parts of our tangent
# vector are allocated to each variable.
-        tangent_layout = TangentStorageLayout.make(vars)
+        tangent_start_from_var_type = dict[type[Var], int]()
+
+        vars_from_var_type = dict[type[Var], list[Var]]()
+        for var in vars:
+            vars_from_var_type.setdefault(type(var), []).append(var)
+        counter = 0
+        for var_type, vars_one_type in vars_from_var_type.items():
+            tangent_start_from_var_type[var_type] = counter
+            counter += var_type.tangent_dim * len(vars_one_type)
+
+        # Create ordering helper.
+        tangent_ordering = VarTypeOrdering(
+            {
+                var_type: i
+                for i, var_type in enumerate(tangent_start_from_var_type.keys())
+            }
+        )

# Sort variable IDs.
sorted_ids_from_var_type = sort_and_stack_vars(vars)
@@ -187,7 +167,7 @@

subset_indices_from_type = dict[type[Var], jax.Array]()

-for var_type, ids in tangent_layout.ordering.ordered_dict_items(
+for var_type, ids in tangent_ordering.ordered_dict_items(
stacked_factor.sorted_ids_from_var_type
):
logger.info("Making initial grid")
@@ -210,7 +190,7 @@
logger.info("Computing tangent")
# Variable index => indices into the tangent vector.
tangent_start_indices = (
-tangent_layout.start_indices[var_type]
+tangent_start_from_var_type[var_type]
+ var_indices * var_type.tangent_dim
)
assert tangent_start_indices.shape == (len(group), ids.shape[-1])
@@ -257,7 +237,7 @@ def make(
stacked_factors=tuple(stacked_factors),
stacked_factors_var_indices=tuple(stacked_factor_var_indices),
jacobian_coords=jacobian_coords_concat,
-tangent_layout=tangent_layout,
+tangent_ordering=tangent_ordering,
residual_dim=residual_offset,
)

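The inlined layout code above replaces the deleted TangentStorageLayout class: each variable type gets one contiguous block of the flattened tangent vector, starting at tangent_start_from_var_type[var_type]. A self-contained sketch of the same bookkeeping, with hypothetical variable types and tangent dimensions:

# Hypothetical population: 3 SO(3) variables (tangent_dim=3), then 2 SE(3)
# variables (tangent_dim=6), mirroring the loop in StackedFactorGraph.make().
counts = {"SO3Var": (3, 3), "SE3Var": (2, 6)}  # type -> (count, tangent_dim)

tangent_start_from_var_type: dict[str, int] = {}
counter = 0
for var_type, (count, tangent_dim) in counts.items():
    tangent_start_from_var_type[var_type] = counter
    counter += tangent_dim * count

assert tangent_start_from_var_type == {"SO3Var": 0, "SE3Var": 9}
assert counter == 21  # total tangent vector dimension
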
4 changes: 0 additions & 4 deletions jaxfg2/_lie_group_variables.py
@@ -4,7 +4,6 @@
from ._variables import Var


-@jdc.pytree_dataclass
class SO2Var(
Var[jaxlie.SO2],
default=jaxlie.SO2.identity(),
@@ -14,7 +13,6 @@ class SO2Var(
...


-@jdc.pytree_dataclass
class SO3Var(
Var[jaxlie.SO3],
default=jaxlie.SO3.identity(),
@@ -24,7 +22,6 @@ class SO3Var(
...


-@jdc.pytree_dataclass
class SE2Var(
Var[jaxlie.SE2],
default=jaxlie.SE2.identity(),
@@ -34,7 +31,6 @@ class SE2Var(
...


-@jdc.pytree_dataclass
class SE3Var(
Var[jaxlie.SE3],
default=jaxlie.SE3.identity(),
69 changes: 61 additions & 8 deletions jaxfg2/_solvers.py
@@ -2,19 +2,20 @@

import abc
import functools
-from typing import Hashable, override
+from typing import Hashable, cast, override

import jax
import jax.flatten_util
import jax_dataclasses as jdc
import scipy
import sksparse.cholmod
from jax import numpy as jnp

from jaxfg2.utils import jax_log

-from ._factor_graph import StackedFactorGraph, TangentStorageLayout
+from ._factor_graph import StackedFactorGraph
from ._sparse_matrices import SparseCooMatrix
-from ._variables import VarValues
+from ._variables import VarTypeOrdering, VarValues

# Linear solvers.

@@ -63,6 +64,58 @@ def _solve_on_host(
return _cholmod_analyze_cache[self_hash].solve_A(ATb)


+@jdc.pytree_dataclass
+class ConjugateGradientSolver:
+    tolerance: float = 1e-5
+    inexact_step_eta: float | None = 1e-2
+    """Forcing sequence parameter for inexact Newton steps. CG tolerance is set to
+    `eta / iteration #`.
+    For reference, see "An Inexact Levenberg-Marquardt Method for Large Sparse
+    Nonlinear Least Squares", Wright & Holt 1985."""
+
+    def solve(
+        self,
+        A: SparseCooMatrix,
+        ATb: jax.Array,
+        lambd: float | jax.Array,
+        iterations: int | jax.Array,
+    ) -> jnp.ndarray:
+        assert len(A.values.shape) == 1, "A.values should be 1D"
+        assert len(ATb.shape) == 1, "ATb should be 1D!"
+
+        initial_x = jnp.zeros(ATb.shape)
+
+        # Get diagonals of ATA, for regularization + Jacobi preconditioning.
+        ATA_diagonals = jnp.zeros_like(initial_x).at[A.coords.cols].add(A.values**2)
+
+        # Form the normal equations.
+        def ATA_function(x: jax.Array):
+            ATAx = A.T @ (A @ x)
+
+            # Scale-invariant regularization.
+            return ATAx + lambd * ATA_diagonals * x
+
+        def jacobi_preconditioner(x):
+            return x / ATA_diagonals
+
+        # Solve with conjugate gradient.
+        solution_values, _ = jax.scipy.sparse.linalg.cg(
+            A=ATA_function,
+            b=ATb,
+            x0=initial_x,
+            # https://en.wikipedia.org/wiki/Conjugate_gradient_method#Convergence_properties
+            maxiter=len(initial_x),
+            tol=cast(
+                float, jnp.maximum(self.tolerance, self.inexact_step_eta / (iterations + 1))
+            )
+            if self.inexact_step_eta is not None
+            else self.tolerance,
+            M=jacobi_preconditioner,
+        )
+        return solution_values
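
The new class solves the damped normal equations (A^T A + lambd * diag(A^T A)) x = A^T b iteratively, with diag(A^T A) doubling as a Jacobi preconditioner. A self-contained sketch of the same pattern on a small dense stand-in problem, checked against a direct solve (all values are illustrative):

import jax.scipy.sparse.linalg
from jax import numpy as jnp

A = jnp.array([[2.0, 0.0], [1.0, 3.0], [0.0, 1.0]])  # dense stand-in Jacobian
b = jnp.array([1.0, 2.0, 3.0])
lambd = 1e-3

ATb = A.T @ b
ATA_diagonals = jnp.sum(A**2, axis=0)  # diag(A^T A)

def ATA_function(x):
    # Damped normal equations, with the same scale-invariant regularization.
    return A.T @ (A @ x) + lambd * ATA_diagonals * x

x_cg, _ = jax.scipy.sparse.linalg.cg(
    A=ATA_function,
    b=ATb,
    M=lambda x: x / ATA_diagonals,  # Jacobi preconditioner
    tol=1e-8,
)
x_direct = jnp.linalg.solve(A.T @ A + lambd * jnp.diag(ATA_diagonals), ATb)
assert jnp.allclose(x_cg, x_direct, atol=1e-4)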


# Nonlinear solve utils.


@@ -91,7 +144,7 @@ def check_convergence(
state_prev: NonlinearSolverState,
cost_updated: jax.Array,
tangent: jax.Array,
-tangent_layout: TangentStorageLayout,
+tangent_ordering: VarTypeOrdering,
ATb: jax.Array,
) -> jax.Array:
"""Check for convergence!"""
@@ -109,7 +162,7 @@
jnp.max(
flat_vals
- jax.flatten_util.ravel_pytree(
-state_prev.vals._retract(ATb, tangent_layout.ordering)
+state_prev.vals._retract(ATb, tangent_ordering)
)[0]
)
< self.gradient_tolerance,
@@ -150,7 +203,7 @@ class NonlinearSolverState:

@jdc.pytree_dataclass
class NonlinearSolver[TState: NonlinearSolverState]:
-linear_solver: CholmodSolver = CholmodSolver()
+linear_solver: CholmodSolver | ConjugateGradientSolver = CholmodSolver()
verbose: jdc.Static[bool] = True
"""Set to `True` to enable printing."""

@@ -209,7 +262,7 @@ def _step(
tangent = self.linear_solver.solve(
A, ATb, lambd=0.0, iterations=state.iterations
)
-vals = state.vals._retract(tangent, graph.tangent_layout.ordering)
+vals = state.vals._retract(tangent, graph.tangent_ordering)

if self.verbose:
jax_log(
@@ -226,7 +279,7 @@
state,
cost_updated=state_next.cost,
tangent=tangent,
-tangent_layout=graph.tangent_layout,
+tangent_ordering=graph.tangent_ordering,
ATb=ATb,
)
return state_next
1 change: 1 addition & 0 deletions jaxfg2/_sparse_matrices.py
@@ -58,6 +58,7 @@ def T(self):

def as_scipy_coo_matrix(self) -> scipy.sparse.coo_matrix:
"""Convert to a sparse scipy matrix."""
+assert len(self.values.shape) == 1
return scipy.sparse.coo_matrix(
(self.values, (self.coords.rows, self.coords.cols)), shape=self.shape
)
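
The new assertion guards against batched (2D) values reaching scipy, which expects flat COO triplets. For reference, the triplet convention with hypothetical values:

import numpy as onp
import scipy.sparse

# Two nonzeros of a 2x3 matrix: M[0, 1] = 5.0 and M[1, 2] = 7.0.
values = onp.array([5.0, 7.0])
rows = onp.array([0, 1])
cols = onp.array([1, 2])
assert len(values.shape) == 1  # same guard as the added line above

M = scipy.sparse.coo_matrix((values, (rows, cols)), shape=(2, 3))
assert M.toarray()[0, 1] == 5.0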
(Diffs for the 2 remaining changed files are not shown.)
