From a4f2899a2b36c736b1306d0fb918649a027e9de5 Mon Sep 17 00:00:00 2001
From: Brent Yi
Date: Sun, 7 Jul 2024 03:58:24 +0900
Subject: [PATCH] Nits, add conjugate gradient

---
 jaxfg2/__init__.py             |  8 ++--
 jaxfg2/_factor_graph.py        | 68 ++++++++++------------------
 jaxfg2/_lie_group_variables.py |  4 --
 jaxfg2/_solvers.py             | 69 +++++++++++++++++++++++++----
 jaxfg2/_sparse_matrices.py     |  1 +
 jaxfg2/_variables.py           | 80 ++++++++++++++++++++----------
 scripts/pose_graph_g2o.py      |  2 +-
 7 files changed, 148 insertions(+), 84 deletions(-)

diff --git a/jaxfg2/__init__.py b/jaxfg2/__init__.py
index 2406f53..fb6f173 100644
--- a/jaxfg2/__init__.py
+++ b/jaxfg2/__init__.py
@@ -1,10 +1,12 @@
 from . import utils as utils
+from ._factor_graph import Factor as Factor
+from ._factor_graph import StackedFactorGraph as StackedFactorGraph
 from ._lie_group_variables import SE2Var as SE2Var
 from ._lie_group_variables import SE3Var as SE3Var
 from ._lie_group_variables import SO2Var as SO2Var
 from ._lie_group_variables import SO3Var as SO3Var
-from ._factor_graph import Factor as Factor
-from ._factor_graph import StackedFactorGraph as StackedFactorGraph
-from ._solvers import GaussNewtonSolver
+from ._solvers import CholmodSolver as CholmodSolver
+from ._solvers import ConjugateGradientSolver as ConjugateGradientSolver
+from ._solvers import GaussNewtonSolver as GaussNewtonSolver
 from ._variables import Var as Var
 from ._variables import VarValues as VarValues
diff --git a/jaxfg2/_factor_graph.py b/jaxfg2/_factor_graph.py
index 7cecd0d..879de16 100644
--- a/jaxfg2/_factor_graph.py
+++ b/jaxfg2/_factor_graph.py
@@ -17,45 +17,12 @@
 from ._variables import Var, VarTypeOrdering, VarValues, sort_and_stack_vars
 
 
-@dataclasses.dataclass(frozen=True)
-class TangentStorageLayout:
-    """Our tangent vector will be represented as a single flattened 1D vector.
-    How should it be laid out?"""
-
-    counts: Mapping[type[Var], int]
-    start_indices: Mapping[type[Var], int]
-    ordering: VarTypeOrdering
-
-    @staticmethod
-    def make(vars: Iterable[Var]) -> TangentStorageLayout:
-        counts: dict[type[Var], int] = {}
-        for var in vars:
-            var_type = type(var)
-            if isinstance(var.id, int) or len(var.id.shape) == 0:
-                counts[var_type] = counts.get(var_type, 0) + 1
-            else:
-                assert len(var.id.shape) == 1
-                counts[var_type] = counts.get(var_type, 0) + var.id.shape[0]
-
-        i = 0
-        start_indices: dict[type[Var], int] = {}
-        for var_type, count in counts.items():
-            start_indices[var_type] = i
-            i += var_type.parameter_dim * count
-
-        return TangentStorageLayout(
-            counts=frozendict(counts),
-            start_indices=frozendict(start_indices),
-            ordering=VarTypeOrdering(tuple(start_indices.keys())),
-        )
-
-
 @jdc.pytree_dataclass
 class StackedFactorGraph:
     stacked_factors: tuple[Factor, ...]
     stacked_factors_var_indices: tuple[dict[type[Var], jax.Array], ...]
     jacobian_coords: SparseCooCoordinates
-    tangent_layout: jdc.Static[TangentStorageLayout]
+    tangent_ordering: jdc.Static[VarTypeOrdering]
     residual_dim: jdc.Static[int]
 
     def compute_residual_vector(self, vals: VarValues) -> jax.Array:
@@ -77,9 +44,7 @@ def _compute_jacobian_wrt_tangent(self, vals: VarValues) -> SparseCooMatrix:
 
         def compute_jac_with_perturb(
            factor: Factor, indices_from_type: dict[type[Var], jax.Array]
         ) -> jax.Array:
-            val_subset = vals._get_subset(
-                indices_from_type, self.tangent_layout.ordering
-            )
+            val_subset = vals._get_subset(indices_from_type, self.tangent_ordering)
             # Shape should be: (single_residual_dim, vars * tangent_dim).
             return (  # Getting the Jacobian of...
@@ -89,7 +54,7 @@ def compute_jac_with_perturb(
             )(  # The residual function, with respect to to some local delta.
                 lambda tangent: factor.compute_residual(
-                    val_subset._retract(tangent, self.tangent_layout.ordering),
+                    val_subset._retract(tangent, self.tangent_ordering),
                     *factor.args,
                 )
             )(jnp.zeros((val_subset._get_tangent_dim(),)))
@@ -146,8 +111,7 @@ def make(
             )
 
             # Record factor and variables.
-            factors_from_group.setdefault(group_key, [])
-            factors_from_group[group_key].append(factor)
+            factors_from_group.setdefault(group_key, []).append(factor)
 
         # Fields we want to populate.
         stacked_factors: list[Factor] = []
@@ -156,7 +120,23 @@ def make(
 
         # Create storage layout: this describes which parts of our tangent
         # vector is allocated to each variable.
-        tangent_layout = TangentStorageLayout.make(vars)
+        tangent_start_from_var_type = dict[type[Var], int]()
+
+        vars_from_var_type = dict[type[Var], list[Var]]()
+        for var in vars:
+            vars_from_var_type.setdefault(type(var), []).append(var)
+        counter = 0
+        for var_type, vars_one_type in vars_from_var_type.items():
+            tangent_start_from_var_type[var_type] = counter
+            counter += var_type.tangent_dim * len(vars_one_type)
+
+        # Create ordering helper.
+        tangent_ordering = VarTypeOrdering(
+            {
+                var_type: i
+                for i, var_type in enumerate(tangent_start_from_var_type.keys())
+            }
+        )
 
         # Sort variable IDs.
         sorted_ids_from_var_type = sort_and_stack_vars(vars)
@@ -187,7 +167,7 @@ def make(
 
             subset_indices_from_type = dict[type[Var], jax.Array]()
 
-            for var_type, ids in tangent_layout.ordering.ordered_dict_items(
+            for var_type, ids in tangent_ordering.ordered_dict_items(
                 stacked_factor.sorted_ids_from_var_type
             ):
                 logger.info("Making initial grid")
@@ -210,7 +190,7 @@ def make(
                 logger.info("Computing tangent")
                 # Variable index => indices into the tangent vector.
                 tangent_start_indices = (
-                    tangent_layout.start_indices[var_type]
+                    tangent_start_from_var_type[var_type]
                     + var_indices * var_type.tangent_dim
                 )
                 assert tangent_start_indices.shape == (len(group), ids.shape[-1])
@@ -257,7 +237,7 @@ def make(
             stacked_factors=tuple(stacked_factors),
             stacked_factors_var_indices=tuple(stacked_factor_var_indices),
             jacobian_coords=jacobian_coords_concat,
-            tangent_layout=tangent_layout,
+            tangent_ordering=tangent_ordering,
             residual_dim=residual_offset,
         )
 
diff --git a/jaxfg2/_lie_group_variables.py b/jaxfg2/_lie_group_variables.py
index 1da4690..018cd53 100644
--- a/jaxfg2/_lie_group_variables.py
+++ b/jaxfg2/_lie_group_variables.py
@@ -4,7 +4,6 @@
 from ._variables import Var
 
 
-@jdc.pytree_dataclass
 class SO2Var(
     Var[jaxlie.SO2],
     default=jaxlie.SO2.identity(),
@@ -14,7 +13,6 @@ class SO2Var(
     ...
 
 
-@jdc.pytree_dataclass
 class SO3Var(
     Var[jaxlie.SO3],
     default=jaxlie.SO3.identity(),
@@ -24,7 +22,6 @@ class SO3Var(
     ...
 
 
-@jdc.pytree_dataclass
 class SE2Var(
     Var[jaxlie.SE2],
     default=jaxlie.SE2.identity(),
@@ -34,7 +31,6 @@ class SE2Var(
     ...
 
 
-@jdc.pytree_dataclass
 class SE3Var(
     Var[jaxlie.SE3],
     default=jaxlie.SE3.identity(),
diff --git a/jaxfg2/_solvers.py b/jaxfg2/_solvers.py
index 65aa87c..6860f3e 100644
--- a/jaxfg2/_solvers.py
+++ b/jaxfg2/_solvers.py
@@ -2,19 +2,20 @@
 
 import abc
 import functools
-from typing import Hashable, override
+from typing import Hashable, cast, override
 
 import jax
 import jax.flatten_util
 import jax_dataclasses as jdc
+import scipy
 import sksparse.cholmod
 from jax import numpy as jnp
 
 from jaxfg2.utils import jax_log
 
-from ._factor_graph import StackedFactorGraph, TangentStorageLayout
+from ._factor_graph import StackedFactorGraph
 from ._sparse_matrices import SparseCooMatrix
-from ._variables import VarValues
+from ._variables import VarTypeOrdering, VarValues
 
 
 # Linear solvers.
@@ -63,6 +64,58 @@ def _solve_on_host(
         return _cholmod_analyze_cache[self_hash].solve_A(ATb)
 
 
+@jdc.pytree_dataclass
+class ConjugateGradientSolver:
+    tolerance: float = 1e-5
+    inexact_step_eta: float | None = 1e-2
+    """Forcing sequence parameter for inexact Newton steps. The CG tolerance for each
+    outer iteration is set to `max(tolerance, inexact_step_eta / (iteration # + 1))`;
+    set to `None` to always use `tolerance` directly.
+
+    For reference, see "An Inexact Levenberg-Marquardt Method for Large Sparse
+    Nonlinear Least Squares", Wright & Holt 1983."""
+
+    def solve(
+        self,
+        A: SparseCooMatrix,
+        ATb: jax.Array,
+        lambd: float | jax.Array,
+        iterations: int | jax.Array,
+    ) -> jnp.ndarray:
+        assert len(A.values.shape) == 1, "A.values should be 1D"
+        assert len(ATb.shape) == 1, "ATb should be 1D!"
+
+        initial_x = jnp.zeros(ATb.shape)
+
+        # Get diagonals of ATA, for regularization + Jacobi preconditioning.
+        ATA_diagonals = jnp.zeros_like(initial_x).at[A.coords.cols].add(A.values**2)
+
+        # Form normal equation.
+        def ATA_function(x: jax.Array):
+            ATAx = A.T @ (A @ x)
+
+            # Scale-invariant regularization.
+            return ATAx + lambd * ATA_diagonals * x
+
+        def jacobi_preconditioner(x):
+            return x / ATA_diagonals
+
+        # Solve with conjugate gradient.
+        solution_values, _ = jax.scipy.sparse.linalg.cg(
+            A=ATA_function,
+            b=ATb,
+            x0=initial_x,
+            # https://en.wikipedia.org/wiki/Conjugate_gradient_method#Convergence_properties
+            maxiter=len(initial_x),
+            tol=cast(
+                float, jnp.maximum(self.tolerance, self.inexact_step_eta / (iterations + 1))
+            )
+            if self.inexact_step_eta is not None
+            else self.tolerance,
+            M=jacobi_preconditioner,
+        )
+        return solution_values
+
+
 # Nonlinear solve utils.
 
 
@@ -91,7 +144,7 @@ def check_convergence(
         state_prev: NonlinearSolverState,
         cost_updated: jax.Array,
         tangent: jax.Array,
-        tangent_layout: TangentStorageLayout,
+        tangent_ordering: VarTypeOrdering,
         ATb: jax.Array,
     ) -> jax.Array:
         """Check for convergence!"""
@@ -109,7 +162,7 @@ def check_convergence(
             jnp.max(
                 flat_vals
                 - jax.flatten_util.ravel_pytree(
-                    state_prev.vals._retract(ATb, tangent_layout.ordering)
+                    state_prev.vals._retract(ATb, tangent_ordering)
                )[0]
            )
            < self.gradient_tolerance,
@@ -150,7 +203,7 @@ class NonlinearSolverState:
 
 @jdc.pytree_dataclass
 class NonlinearSolver[TState: NonlinearSolverState]:
-    linear_solver: CholmodSolver = CholmodSolver()
+    linear_solver: CholmodSolver | ConjugateGradientSolver = CholmodSolver()
     verbose: jdc.Static[bool] = True
     """Set to `True` to enable printing."""
 
@@ -209,7 +262,7 @@ def _step(
         tangent = self.linear_solver.solve(
             A, ATb, lambd=0.0, iterations=state.iterations
         )
-        vals = state.vals._retract(tangent, graph.tangent_layout.ordering)
+        vals = state.vals._retract(tangent, graph.tangent_ordering)
 
         if self.verbose:
             jax_log(
@@ -226,7 +279,7 @@ def _step(
             state,
             cost_updated=state_next.cost,
             tangent=tangent,
-            tangent_layout=graph.tangent_layout,
+            tangent_ordering=graph.tangent_ordering,
             ATb=ATb,
         )
         return state_next
diff --git a/jaxfg2/_sparse_matrices.py b/jaxfg2/_sparse_matrices.py
index d420e08..291c7d3 100644
--- a/jaxfg2/_sparse_matrices.py
+++ b/jaxfg2/_sparse_matrices.py
@@ -58,6 +58,7 @@ def T(self):
 
     def as_scipy_coo_matrix(self) -> scipy.sparse.coo_matrix:
         """Convert to a sparse scipy matrix."""
+        assert len(self.values.shape) == 1
         return scipy.sparse.coo_matrix(
             (self.values, (self.coords.rows, self.coords.cols)), shape=self.shape
         )
diff --git a/jaxfg2/_variables.py b/jaxfg2/_variables.py
index b169af5..bd0d077 100644
--- a/jaxfg2/_variables.py
+++ b/jaxfg2/_variables.py
@@ -1,62 +1,94 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import Any, Callable, ClassVar, Iterable, Self, overload, override
+from typing import Any, Callable, ClassVar, Iterable, Self, cast, overload, override
 
 import jax
 import jax_dataclasses as jdc
 import numpy as onp
+from jax import flatten_util
 from jax import numpy as jnp
 
 
 @dataclass(frozen=True)
 class VarTypeOrdering:
-    var_types: tuple[type[Var], ...]
+    order_from_type: dict[type[Var], int]
 
     def ordered_dict_items[T](
         self,
         var_type_mapping: dict[type[Var], T],
     ) -> list[tuple[type[Var], T]]:
         return sorted(
-            var_type_mapping.items(), key=lambda x: self.var_types.index(x[0])
+            var_type_mapping.items(), key=lambda x: self.order_from_type[x[0]]
         )
 
 
+class _Var[T]:
+    # Subclassing this is a hack: it lets us avoid ClassVar[] annotations, which
+    # prevent us from using generics, while also keeping these attributes from
+    # becoming fields on the Var() dataclass.
+    #
+    # https://github.com/python/typing/discussions/1424
+    default: T
+    """Default value for this variable."""
+    tangent_dim: int
+    """Dimension of the tangent space."""
+    retract_fn: Callable[[T, jax.Array], T]
+    """Retraction function for the manifold; Euclidean variables use vector addition."""
+
+
 @jdc.pytree_dataclass
-class Var[T]:
+class Var[T](_Var[T]):
     """A symbolic representation of an optimization variable."""
 
     id: int | jax.Array
 
-    # Class properties.
-    # type ignores are for generics: https://github.com/python/typing/discussions/1424
-    default: ClassVar[T]  # type: ignore
-    """Default value for this variable."""
-    parameter_dim: ClassVar[int]
-    """Number of parameters in this variable type."""
-    tangent_dim: ClassVar[int]
-    """Dimension of the tangent space."""
-    retract_fn: ClassVar[Callable[[T, jax.Array], T]]  # type: ignore
-    """Retraction function for the manifold. None for Euclidean space."""
+    @overload
+    def __init_subclass__(
+        cls,
+        default: T,
+        retract_fn: None = None,
+        tangent_dim: None = None,
+    ) -> None:
+        ...
+
+    @overload
     def __init_subclass__(
         cls,
         default: T,
-        tangent_dim: int | None,
-        retract_fn: Callable[[T, jax.Array], T],
+        retract_fn: Callable[[T, jax.Array], T],
+        tangent_dim: int,
+    ) -> None:
+        ...
+
+    def __init_subclass__(
+        cls,
+        default: T,
+        retract_fn: Callable[[T, jax.Array], T] | None = None,
+        tangent_dim: int | None = None,
     ) -> None:
         cls.default = default
-        cls.parameter_dim = int(
-            sum([onp.prod(leaf.size) for leaf in jax.tree.leaves(default)])
-        )
-        cls.tangent_dim = tangent_dim if tangent_dim is not None else cls.parameter_dim
-        cls.retract_fn = retract_fn  # type: ignore
+        if retract_fn is not None:
+            assert tangent_dim is not None
+            cls.tangent_dim = tangent_dim
+            cls.retract_fn = retract_fn
+        else:
+            assert tangent_dim is None
+            parameter_dim = int(
+                sum([onp.prod(leaf.size) for leaf in jax.tree.leaves(default)])
+            )
+            cls.tangent_dim = parameter_dim
+            cls.retract_fn = cls._euclidean_retract
+
         super().__init_subclass__()
 
-    @classmethod
-    def allocate(cls, count: int, start_id: int = 0) -> tuple[Self, ...]:
-        """Helper for allocating a sequence of variables."""
-        return tuple(cls(i) for i in range(start_id, start_id + count))
+        # Subclasses need to be registered as PyTrees.
+        jdc.pytree_dataclass(cls)
+
+    @staticmethod
+    def _euclidean_retract(pytree: T, delta: jax.Array) -> T:
+        # Euclidean retraction.
+        flat, unravel = flatten_util.ravel_pytree(pytree)
+        return cast(T, jax.tree_map(jnp.add, pytree, unravel(delta)))
 
 
 @jdc.pytree_dataclass
diff --git a/scripts/pose_graph_g2o.py b/scripts/pose_graph_g2o.py
index a000ef6..c62c705 100755
--- a/scripts/pose_graph_g2o.py
+++ b/scripts/pose_graph_g2o.py
@@ -32,7 +32,7 @@ def main(
     jax.block_until_ready(graph)
 
     with jaxfg2.utils.stopwatch("Making solver"):
-        solver = jaxfg2.GaussNewtonSolver(verbose=True)
+        solver = jaxfg2.GaussNewtonSolver(verbose=True, linear_solver=jaxfg2.ConjugateGradientSolver())
 
     initial_vals = jaxfg2.VarValues.make(g2o.pose_vars, g2o.initial_poses)
     with jaxfg2.utils.stopwatch("Running solve"):
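
Usage sketch (illustrative, not part of the patch): the snippet below shows how the keyword-based `Var` subclassing API and the new `ConjugateGradientSolver` fit together. `PointVar` is a hypothetical Euclidean variable type, and the `retract_fn` / `tangent_dim` arguments in the manifold case are assumptions modeled on the `SE2Var`-style definitions in `_lie_group_variables.py`, whose exact values are not visible in this diff.

    import jax
    import jax.numpy as jnp
    import jaxlie

    import jaxfg2
    from jaxfg2 import Var

    # Euclidean variable type: only `default` is required. The tangent dimension is
    # inferred from the flattened parameter size, and retraction falls back to plain
    # vector addition.
    class PointVar(Var[jax.Array], default=jnp.zeros(3)):
        ...

    # Manifold variable type: pass a retraction function and the tangent dimension
    # explicitly (assumed values, mirroring SE2Var).
    class MySE2Var(
        Var[jaxlie.SE2],
        default=jaxlie.SE2.identity(),
        retract_fn=jaxlie.manifold.rplus,
        tangent_dim=jaxlie.SE2.tangent_dim,
    ):
        ...

    # Variables are identified by integer IDs (replaces the removed allocate() helper).
    points = tuple(PointVar(i) for i in range(100))

    # Gauss-Newton outer loop with the new Jacobi-preconditioned CG linear solver.
    solver = jaxfg2.GaussNewtonSolver(
        verbose=True,
        linear_solver=jaxfg2.ConjugateGradientSolver(tolerance=1e-5, inexact_step_eta=1e-2),
    )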