From e066104e0a355b88a918f6b52ae7657795a725d0 Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Sun, 7 Jul 2024 15:53:39 +0900 Subject: [PATCH] Shorter + correct Jacobian index calculation --- jaxfg2/_factor_graph.py | 135 ++++++++++++++++--------------------- jaxfg2/_sparse_matrices.py | 20 ++++++ scripts/pose_graph_g2o.py | 2 +- 3 files changed, 80 insertions(+), 77 deletions(-) diff --git a/jaxfg2/_factor_graph.py b/jaxfg2/_factor_graph.py index 879de16..fc1e1ac 100644 --- a/jaxfg2/_factor_graph.py +++ b/jaxfg2/_factor_graph.py @@ -2,6 +2,7 @@ import dataclasses import dis +import functools import inspect import linecache from typing import Callable, Hashable, Iterable, Mapping, Self, cast @@ -20,7 +21,6 @@ @jdc.pytree_dataclass class StackedFactorGraph: stacked_factors: tuple[Factor, ...] - stacked_factors_var_indices: tuple[dict[type[Var], jax.Array], ...] jacobian_coords: SparseCooCoordinates tangent_ordering: jdc.Static[VarTypeOrdering] residual_dim: jdc.Static[int] @@ -37,14 +37,16 @@ def compute_residual_vector(self, vals: VarValues) -> jax.Array: def _compute_jacobian_wrt_tangent(self, vals: VarValues) -> SparseCooMatrix: jac_vals = [] - for factor, subset_indices in zip( - self.stacked_factors, self.stacked_factors_var_indices - ): + for factor in self.stacked_factors: # Shape should be: (num_variables, len(group), single_residual_dim, var.tangent_dim). - def compute_jac_with_perturb( - factor: Factor, indices_from_type: dict[type[Var], jax.Array] - ) -> jax.Array: - val_subset = vals._get_subset(indices_from_type, self.tangent_ordering) + def compute_jac_with_perturb(factor: Factor) -> jax.Array: + val_subset = vals._get_subset( + { + var_type: jnp.searchsorted(vals.ids_from_type[var_type], ids) + for var_type, ids in factor.sorted_ids_from_var_type.items() + }, + self.tangent_ordering, + ) # Shape should be: (single_residual_dim, vars * tangent_dim). return ( # Getting the Jacobian of... 
@@ -59,7 +61,7 @@ def compute_jac_with_perturb( ) )(jnp.zeros((val_subset._get_tangent_dim(),))) - stacked_jac = jax.vmap(compute_jac_with_perturb)(factor, subset_indices) + stacked_jac = jax.vmap(compute_jac_with_perturb)(factor) (num_factor,) = factor._get_batch_axes() assert stacked_jac.shape == ( num_factor, @@ -115,7 +117,6 @@ def make( # Fields we want to populate. stacked_factors: list[Factor] = [] - stacked_factor_var_indices: list[dict[type[Var], jax.Array]] = [] jacobian_coords: list[SparseCooCoordinates] = [] # Create storage layout: this describes which parts of our tangent @@ -162,72 +163,25 @@ def make( # # These should be N pairs of (row, col) indices, where rows correspond to # residual indices and columns correspond to tangent vector indices. - single_residual_dim = stacked_factor.residual_dim - stacked_residual_dim = single_residual_dim * len(group) - - subset_indices_from_type = dict[type[Var], jax.Array]() - - for var_type, ids in tangent_ordering.ordered_dict_items( - stacked_factor.sorted_ids_from_var_type - ): - logger.info("Making initial grid") - # Jacobian of a single residual vector with respect to a - # single variable. Upper-left corner at (0, 0). - jac_coords = onp.mgrid[:single_residual_dim, : var_type.tangent_dim] - assert jac_coords.shape == ( - 2, - single_residual_dim, - var_type.tangent_dim, + rows, cols = jax.vmap( + functools.partial( + Factor._compute_block_sparse_jacobian_indices, + tangent_ordering=tangent_ordering, + sorted_ids_from_var_type=sorted_ids_from_var_type, + tangent_start_from_var_type=tangent_start_from_var_type, ) - - logger.info("Getting indices") - - # Get index of each variable, which is based on the sorted IDs. - var_indices = jnp.searchsorted(sorted_ids_from_var_type[var_type], ids) - subset_indices_from_type[var_type] = cast(jax.Array, var_indices) - assert var_indices.shape == (len(group), ids.shape[-1]) - - logger.info("Computing tangent") - # Variable index => indices into the tangent vector. 
- tangent_start_indices = ( - tangent_start_from_var_type[var_type] - + var_indices * var_type.tangent_dim - ) - assert tangent_start_indices.shape == (len(group), ids.shape[-1]) - - logger.info("Broadcasting") - jac_coords = jnp.broadcast_to( - jac_coords[:, None, :, None, :], - ( - 2, - len(group), - single_residual_dim, - ids.shape[-1], - var_type.tangent_dim, - ), - ) - logger.info( - "Computed indices for Jacobian block with shape {}", - jac_coords.shape, - ) - jacobian_coords.append( - SparseCooCoordinates( - rows=( - jac_coords[0] - + ( - onp.arange(len(group)) * single_residual_dim - + residual_offset - )[:, None, None, None] - ).flatten(), - # Offset the column indices by the start index within the - # flattened tangent vector. - cols=( - jac_coords[1] + tangent_start_indices[:, None, :, None] - ).flatten(), - ) - ) - stacked_factor_var_indices.append(subset_indices_from_type) - residual_offset += stacked_residual_dim + )(stacked_factor) + assert ( + rows.shape + == cols.shape + == (len(group), stacked_factor.residual_dim, rows.shape[-1]) + ) + rows = rows + ( + jnp.arange(len(group))[:, None, None] * stacked_factor.residual_dim + ) + rows = rows + residual_offset + jacobian_coords.append(SparseCooCoordinates(rows.flatten(), cols.flatten())) + residual_offset += stacked_factor.residual_dim * len(group) jacobian_coords_concat: SparseCooCoordinates = jax.tree_map( lambda *arrays: jnp.concatenate(arrays, axis=0), *jacobian_coords @@ -235,7 +189,6 @@ def make( logger.info("Done!") return StackedFactorGraph( stacked_factors=tuple(stacked_factors), - stacked_factors_var_indices=tuple(stacked_factor_var_indices), jacobian_coords=jacobian_coords_concat, tangent_ordering=tangent_ordering, residual_dim=residual_offset, @@ -331,3 +284,33 @@ def _get_function_signature(func: Callable) -> Hashable: def _get_batch_axes(self) -> tuple[int, ...]: return next(iter(self.sorted_ids_from_var_type.values())).shape[:-1] + + def _compute_block_sparse_jacobian_indices( + self: 
Factor, + tangent_ordering: VarTypeOrdering, + sorted_ids_from_var_type: dict[type[Var], jax.Array], + tangent_start_from_var_type: dict[type[Var], int], + ) -> tuple[jax.Array, jax.Array]: + """Compute row and column indices for block-sparse Jacobian of shape + (residual dim, total tangent dim). Residual indices will start at row=0.""" + col_indices = list[jax.Array]() + for var_type, ids in tangent_ordering.ordered_dict_items( + self.sorted_ids_from_var_type + ): + var_indices = jnp.searchsorted(sorted_ids_from_var_type[var_type], ids) + tangent_start = tangent_start_from_var_type[var_type] + tangent_indices = ( + onp.arange(tangent_start, tangent_start + var_type.tangent_dim)[None, :] + + var_indices[:, None] * var_type.tangent_dim + ) + assert tangent_indices.shape == ( + var_indices.shape[0], + var_type.tangent_dim, + ) + col_indices.append(cast(jax.Array, tangent_indices).flatten()) + rows, cols = jnp.meshgrid( + jnp.arange(self.residual_dim), + jnp.concatenate(col_indices, axis=0), + indexing="ij", + ) + return rows, cols diff --git a/jaxfg2/_sparse_matrices.py b/jaxfg2/_sparse_matrices.py index 291c7d3..6b21fe6 100644 --- a/jaxfg2/_sparse_matrices.py +++ b/jaxfg2/_sparse_matrices.py @@ -4,6 +4,26 @@ from jax import numpy as jnp +@jdc.pytree_dataclass +class SparseCsrCoordinates: + row_starts: jax.Array + """Index into `cols` for the start of each row.""" + cols: jax.Array + """Column indices of non-zero entries. Shape should be `(*, N)`.""" + + +@jdc.pytree_dataclass +class SparseCsrMatrix: + """Sparse matrix in CSR form.""" + + values: jax.Array + """Non-zero matrix values. Shape should be `(*, N)`.""" + coords: SparseCsrCoordinates + """Row and column indices of non-zero entries. 
Shapes should be `(*, N)`.""" + shape: jdc.Static[tuple[int, int]] + """Shape of matrix.""" + + @jdc.pytree_dataclass class SparseCooCoordinates: rows: jax.Array diff --git a/scripts/pose_graph_g2o.py b/scripts/pose_graph_g2o.py index c62c705..b888576 100755 --- a/scripts/pose_graph_g2o.py +++ b/scripts/pose_graph_g2o.py @@ -32,7 +32,7 @@ def main( jax.block_until_ready(graph) with jaxfg2.utils.stopwatch("Making solver"): - solver = jaxfg2.GaussNewtonSolver(verbose=True, linear_solver=jaxfg2.ConjugateGradientSolver()) + solver = jaxfg2.GaussNewtonSolver(verbose=True) #, linear_solver=jaxfg2.ConjugateGradientSolver()) initial_vals = jaxfg2.VarValues.make(g2o.pose_vars, g2o.initial_poses) with jaxfg2.utils.stopwatch("Running solve"):