diff --git a/jaxfg2/_factor_graph.py b/jaxfg2/_factor_graph.py index fc1e1ac..8c2fb70 100644 --- a/jaxfg2/_factor_graph.py +++ b/jaxfg2/_factor_graph.py @@ -5,7 +5,7 @@ import functools import inspect import linecache -from typing import Callable, Hashable, Iterable, Mapping, Self, cast +from typing import Callable, Hashable, Iterable, Literal, Mapping, Self, cast import jax import jax_dataclasses as jdc @@ -14,20 +14,25 @@ from jax import numpy as jnp from loguru import logger -from ._sparse_matrices import SparseCooCoordinates, SparseCooMatrix +from ._sparse_matrices import ( + SparseCooCoordinates, + SparseCooMatrix, + SparseCsrCoordinates, +) from ._variables import Var, VarTypeOrdering, VarValues, sort_and_stack_vars @jdc.pytree_dataclass class StackedFactorGraph: - stacked_factors: tuple[Factor, ...] - jacobian_coords: SparseCooCoordinates - tangent_ordering: jdc.Static[VarTypeOrdering] - residual_dim: jdc.Static[int] + _stacked_factors: tuple[Factor, ...] + _jacobian_coords_coo: SparseCooCoordinates + _jacobian_coords_csr: SparseCsrCoordinates + _tangent_ordering: jdc.Static[VarTypeOrdering] + _residual_dim: jdc.Static[int] def compute_residual_vector(self, vals: VarValues) -> jax.Array: residual_slices = list[jax.Array]() - for stacked_factor in self.stacked_factors: + for stacked_factor in self._stacked_factors: stacked_residual_slice = jax.vmap( lambda args: stacked_factor.compute_residual(vals, *args) )(stacked_factor.args) @@ -35,9 +40,9 @@ def compute_residual_vector(self, vals: VarValues) -> jax.Array: residual_slices.append(stacked_residual_slice.reshape((-1,))) return jnp.concatenate(residual_slices, axis=0) - def _compute_jacobian_wrt_tangent(self, vals: VarValues) -> SparseCooMatrix: + def _compute_jacobian_values(self, vals: VarValues) -> jax.Array: jac_vals = [] - for factor in self.stacked_factors: + for factor in self._stacked_factors: # Shape should be: (num_variables, len(group), single_residual_dim, var.tangent_dim). def compute_jac_with_perturb(factor: Factor) -> jax.Array: val_subset = vals._get_subset( @@ -45,18 +50,22 @@ def compute_jac_with_perturb(factor: Factor) -> jax.Array: var_type: jnp.searchsorted(vals.ids_from_type[var_type], ids) for var_type, ids in factor.sorted_ids_from_var_type.items() }, - self.tangent_ordering, + self._tangent_ordering, ) # Shape should be: (single_residual_dim, vars * tangent_dim). return ( # Getting the Jacobian of... jax.jacrev - if factor.residual_dim < val_subset._get_tangent_dim() + if ( + factor.jacobian_mode == "auto" + and factor.residual_dim < val_subset._get_tangent_dim() + or factor.jacobian_mode == "reverse" + ) else jax.jacfwd )( # The residual function, with respect to to some local delta. lambda tangent: factor.compute_residual( - val_subset._retract(tangent, self.tangent_ordering), + val_subset._retract(tangent, self._tangent_ordering), *factor.args, ) )(jnp.zeros((val_subset._get_tangent_dim(),))) @@ -70,12 +79,8 @@ def compute_jac_with_perturb(factor: Factor) -> jax.Array: ) jac_vals.append(stacked_jac.flatten()) jac_vals = jnp.concatenate(jac_vals, axis=0) - assert jac_vals.shape == (self.jacobian_coords.rows.shape[0],) - return SparseCooMatrix( - values=jac_vals, - coords=self.jacobian_coords, - shape=(self.residual_dim, vals._get_tangent_dim()), - ) + assert jac_vals.shape == (self._jacobian_coords_coo.rows.shape[0],) + return jac_vals @staticmethod def make( @@ -117,7 +122,7 @@ def make( # Fields we want to populate. stacked_factors: list[Factor] = [] - jacobian_coords: list[SparseCooCoordinates] = [] + jacobian_coords: list[tuple[jax.Array, jax.Array]] = [] # Create storage layout: this describes which parts of our tangent # vector is allocated to each variable. @@ -126,10 +131,10 @@ def make( vars_from_var_type = dict[type[Var], list[Var]]() for var in vars: vars_from_var_type.setdefault(type(var), []).append(var) - counter = 0 + tangent_offset = 0 for var_type, vars_one_type in vars_from_var_type.items(): - tangent_start_from_var_type[var_type] = counter - counter += var_type.tangent_dim * len(vars_one_type) + tangent_start_from_var_type[var_type] = tangent_offset + tangent_offset += var_type.tangent_dim * len(vars_one_type) # Create ordering helper. tangent_ordering = VarTypeOrdering( @@ -180,18 +185,31 @@ def make( jnp.arange(len(group))[:, None, None] * stacked_factor.residual_dim ) rows = rows + residual_offset - jacobian_coords.append(SparseCooCoordinates(rows.flatten(), cols.flatten())) + jacobian_coords.append((rows.flatten(), cols.flatten())) residual_offset += stacked_factor.residual_dim * len(group) - jacobian_coords_concat: SparseCooCoordinates = jax.tree_map( - lambda *arrays: jnp.concatenate(arrays, axis=0), *jacobian_coords + jacobian_coords_concat: SparseCooCoordinates = SparseCooCoordinates( + *jax.tree_map( + lambda *arrays: jnp.concatenate(arrays, axis=0), *jacobian_coords + ), + shape=(residual_offset, tangent_offset), ) logger.info("Done!") return StackedFactorGraph( - stacked_factors=tuple(stacked_factors), - jacobian_coords=jacobian_coords_concat, - tangent_ordering=tangent_ordering, - residual_dim=residual_offset, + _stacked_factors=tuple(stacked_factors), + _jacobian_coords_coo=jacobian_coords_concat, + _jacobian_coords_csr=SparseCsrCoordinates( + indices=jacobian_coords_concat.cols, + indptr=cast( + jax.Array, + jnp.searchsorted( + jacobian_coords_concat.rows, jnp.arange(residual_offset + 1) + ), + ), + shape=(residual_offset, tangent_offset), + ), + _tangent_ordering=tangent_ordering, + _residual_dim=residual_offset, ) @@ -208,11 +226,13 @@ class Factor[*Args]: num_vars: jdc.Static[int] sorted_ids_from_var_type: dict[type[Var], jax.Array] residual_dim: jdc.Static[int] + jacobian_mode: jdc.Static[Literal["auto", "forward", "reverse"]] @staticmethod def make[*Args_]( compute_residual: Callable[[VarValues, *Args_], jax.Array], args: tuple[*Args_], + jacobian_mode: Literal["auto", "forward", "reverse"] = "auto", ) -> Factor[*Args_]: """Construct a factor for our factor graph.""" # If we see two functions with the same signature, we always use the @@ -223,13 +243,14 @@ def make[*Args_]( Factor._get_function_signature(compute_residual), compute_residual ), ) - return Factor._make_impl(compute_residual, args) + return Factor._make_impl(compute_residual, args, jacobian_mode) @staticmethod @jdc.jit def _make_impl[*Args_]( compute_residual: jdc.Static[Callable[[VarValues, *Args_], jax.Array]], args: tuple[*Args_], + jacobian_mode: jdc.Static[Literal["auto", "forward", "reverse"]], ) -> Factor[*Args_]: """Construct a factor for our factor graph.""" @@ -267,6 +288,7 @@ def _make_impl[*Args_]( tuple(cast(Var, args[i]) for i in variable_indices) ), residual_dim=_residual_dim_cache[residual_dim_cache_key], + jacobian_mode=jacobian_mode, ) @staticmethod diff --git a/jaxfg2/_solvers.py b/jaxfg2/_solvers.py index 6860f3e..12ecd03 100644 --- a/jaxfg2/_solvers.py +++ b/jaxfg2/_solvers.py @@ -8,13 +8,14 @@ import jax.flatten_util import jax_dataclasses as jdc import scipy +import scipy.sparse import sksparse.cholmod from jax import numpy as jnp from jaxfg2.utils import jax_log from ._factor_graph import StackedFactorGraph -from ._sparse_matrices import SparseCooMatrix +from ._sparse_matrices import SparseCooMatrix, SparseCsrMatrix from ._variables import VarTypeOrdering, VarValues # Linear solvers. @@ -25,13 +26,8 @@ @jdc.pytree_dataclass class CholmodSolver: def solve( - self, - A: SparseCooMatrix, - ATb: jax.Array, - lambd: float | jax.Array, - iterations: int | jax.Array, + self, A: SparseCsrMatrix, ATb: jax.Array, lambd: float | jax.Array ) -> jax.Array: - del iterations return jax.pure_callback( self._solve_on_host, ATb, # Result shape/dtype. @@ -43,14 +39,16 @@ def solve( def _solve_on_host( self, - A: SparseCooMatrix, + A: SparseCsrMatrix, ATb: jax.Array, lambd: float | jax.Array, ) -> jax.Array: - A_T = A.T - A_T_scipy = A_T.as_scipy_coo_matrix().tocsc(copy=False) + # Matrix is transposed when we convert CSR to CSC. + A_T_scipy = scipy.sparse.csc_matrix( + (A.values, A.coords.indices, A.coords.indptr), shape=A.coords.shape[::-1] + ) - # Cache sparsity pattern analysis + # Cache sparsity pattern analysis. self_hash = object.__hash__(self) if self_hash not in _cholmod_analyze_cache: _cholmod_analyze_cache[self_hash] = sksparse.cholmod.analyze_AAt(A_T_scipy) @@ -107,7 +105,8 @@ def jacobi_preconditioner(x): # https://en.wikipedia.org/wiki/Conjugate_gradient_method#Convergence_properties maxiter=len(initial_x), tol=cast( - float, jnp.maximum(self.tolerance, self.inexact_step_eta / (iterations + 1)) + float, + jnp.maximum(self.tolerance, self.inexact_step_eta / (iterations + 1)), ) if self.inexact_step_eta is not None else self.tolerance, @@ -256,14 +255,21 @@ def _initialize_state( def _step( self, graph: StackedFactorGraph, state: NonlinearSolverState ) -> NonlinearSolverState: - A = graph._compute_jacobian_wrt_tangent(state.vals) - ATb = -(A.T @ state.residual_vector) + jac_values = graph._compute_jacobian_values(state.vals) + A_coo = SparseCooMatrix(jac_values, graph._jacobian_coords_coo) + ATb = -(A_coo.T @ state.residual_vector) - tangent = self.linear_solver.solve( - A, ATb, lambd=0.0, iterations=state.iterations - ) - vals = state.vals._retract(tangent, graph.tangent_ordering) + if isinstance(self.linear_solver, ConjugateGradientSolver): + tangent = self.linear_solver.solve( + A_coo, ATb, lambd=0.0, iterations=state.iterations + ) + elif isinstance(self.linear_solver, CholmodSolver): + A_csr = SparseCsrMatrix(jac_values, graph._jacobian_coords_csr) + tangent = self.linear_solver.solve(A_csr, ATb, lambd=0.0) + else: + assert False + vals = state.vals._retract(tangent, graph._tangent_ordering) if self.verbose: jax_log( "Gauss-Newton step #{i}: cost={cost:.4f}", @@ -279,7 +285,7 @@ def _step( state, cost_updated=state_next.cost, tangent=tangent, - tangent_ordering=graph.tangent_ordering, + tangent_ordering=graph._tangent_ordering, ATb=ATb, ) return state_next diff --git a/jaxfg2/_sparse_matrices.py b/jaxfg2/_sparse_matrices.py index 6b21fe6..24a3208 100644 --- a/jaxfg2/_sparse_matrices.py +++ b/jaxfg2/_sparse_matrices.py @@ -1,27 +1,24 @@ import jax import jax_dataclasses as jdc -import scipy.sparse from jax import numpy as jnp @jdc.pytree_dataclass class SparseCsrCoordinates: - row_starts: jax.Array - """Index into `cols` for the start of each row.""" - cols: jax.Array + indices: jax.Array """Column indices of non-zero entries. Shape should be `(*, N)`.""" + indptr: jax.Array + """Index of start to each row. Shape should be `(*, num_rows)`.""" + shape: jdc.Static[tuple[int, int]] @jdc.pytree_dataclass class SparseCsrMatrix: - """Sparse matrix in COO form.""" + """Data structure for sparse CSR matrices.""" values: jax.Array """Non-zero matrix values. Shape should be `(*, N)`.""" coords: SparseCsrCoordinates - """Row and column indices of non-zero entries. Shapes should be `(*, N)`.""" - shape: jdc.Static[tuple[int, int]] - """Shape of matrix.""" @jdc.pytree_dataclass @@ -30,6 +27,8 @@ class SparseCooCoordinates: """Row indices of non-zero entries. Shape should be `(*, N)`.""" cols: jax.Array """Column indices of non-zero entries. Shape should be `(*, N)`.""" + shape: jdc.Static[tuple[int, int]] + """Shape of matrix.""" @jdc.pytree_dataclass @@ -40,25 +39,22 @@ class SparseCooMatrix: """Non-zero matrix values. Shape should be `(*, N)`.""" coords: SparseCooCoordinates """Row and column indices of non-zero entries. Shapes should be `(*, N)`.""" - shape: jdc.Static[tuple[int, int]] - """Shape of matrix.""" def __matmul__(self, other: jax.Array): """Compute `Ax`, where `x` is a 1D vector.""" assert other.shape == ( - self.shape[1], + self.coords.shape[1], ), "Inner product only supported for 1D vectors!" return ( - jnp.zeros(self.shape[0], dtype=other.dtype) + jnp.zeros(self.coords.shape[0], dtype=other.dtype) .at[self.coords.rows] .add(self.values * other[self.coords.cols]) ) def as_dense(self) -> jnp.ndarray: """Convert to a dense JAX array.""" - # TODO: untested return ( - jnp.zeros(self.shape) + jnp.zeros(self.coords.shape) .at[self.coords.rows, self.coords.cols] .set(self.values) ) @@ -66,19 +62,12 @@ def as_dense(self) -> jnp.ndarray: @property def T(self): """Return transpose of our sparse matrix.""" - h, w = self.shape + h, w = self.coords.shape return SparseCooMatrix( values=self.values, coords=SparseCooCoordinates( rows=self.coords.cols, cols=self.coords.rows, + shape=(w, h), ), - shape=(w, h), - ) - - def as_scipy_coo_matrix(self) -> scipy.sparse.coo_matrix: - """Convert to a sparse scipy matrix.""" - assert len(self.values.shape) == 1 - return scipy.sparse.coo_matrix( - (self.values, (self.coords.rows, self.coords.cols)), shape=self.shape ) diff --git a/scripts/_g2o_utils.py b/scripts/_g2o_utils.py index 93ab08a..c81a8a2 100644 --- a/scripts/_g2o_utils.py +++ b/scripts/_g2o_utils.py @@ -2,13 +2,12 @@ import dataclasses import pathlib -from typing import Any, cast +from typing import cast import jax import jaxfg2 import jaxlie import numpy as onp -from jax import numpy as jnp from tqdm.auto import tqdm @@ -85,6 +84,7 @@ def parse_g2o(path: pathlib.Path, pose_count_limit: int = 100000) -> G2OData: between, cast(jax.Array, sqrt_precision_matrix), ), + jacobian_mode="forward", ) factors.append(factor) @@ -127,13 +127,26 @@ def parse_g2o(path: pathlib.Path, pose_count_limit: int = 100000) -> G2OData: sqrt_precision_matrix = onp.linalg.cholesky(precision_matrix).T factor = jaxfg2.Factor.make( - between_residual, + # Passing in arrays like sqrt_precision_matrix as input makes + # it possible for jaxfg vectorize factors. + ( + lambda values, + T_world_a, + T_world_b, + between, + sqrt_precision_matrix: sqrt_precision_matrix + @ ( + (values[T_world_a].inverse() @ values[T_world_b]).inverse() + @ between + ).log() + ), args=( pose_variables[before_index], pose_variables[after_index], between, cast(jax.Array, sqrt_precision_matrix), ), + jacobian_mode="forward", ) factors.append(factor) else: @@ -145,6 +158,7 @@ def parse_g2o(path: pathlib.Path, pose_count_limit: int = 100000) -> G2OData: var_values[start_pose].inverse() @ initial_poses[0] ).log(), args=(pose_variables[0],), + jacobian_mode="reverse", ) factors.append(factor) diff --git a/scripts/pose_graph_g2o.py b/scripts/pose_graph_g2o.py index b888576..1ea3448 100755 --- a/scripts/pose_graph_g2o.py +++ b/scripts/pose_graph_g2o.py @@ -5,14 +5,11 @@ python pose_graph_g2o.py --help """ -import dataclasses -import enum + import pathlib -from typing import Dict, Optional import jax import jaxfg2 -import matplotlib.pyplot as plt import tyro import _g2o_utils @@ -32,7 +29,9 @@ def main( jax.block_until_ready(graph) with jaxfg2.utils.stopwatch("Making solver"): - solver = jaxfg2.GaussNewtonSolver(verbose=True) #, linear_solver=jaxfg2.ConjugateGradientSolver()) + solver = jaxfg2.GaussNewtonSolver( + verbose=True + ) # , linear_solver=jaxfg2.ConjugateGradientSolver()) initial_vals = jaxfg2.VarValues.make(g2o.pose_vars, g2o.initial_poses) with jaxfg2.utils.stopwatch("Running solve"): diff --git a/scripts/pose_graph_simple.py b/scripts/pose_graph_simple.py index 8f92ef8..c5be39b 100644 --- a/scripts/pose_graph_simple.py +++ b/scripts/pose_graph_simple.py @@ -1,19 +1,17 @@ """Simple pose graph example with two pose variables and three factors: - ┌────────┐ ┌────────┐ - │ Pose 0 ├───Between───┤ Pose 1 │ - └───┬────┘ └────┬───┘ - │ │ - │ │ - Prior Prior +┌────────┐ ┌────────┐ +│ Pose 0 ├───Between───┤ Pose 1 │ +└───┬────┘ └────┬───┘ + │ │ + │ │ + Prior Prior """ -from typing import List import jaxfg2 import jaxlie -from jax import numpy as jnp # Create variables: each variable object represents something that we want to solve for. #