Skip to content

Commit

Permalink
Shorter + correct Jacobian index calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Jul 7, 2024
1 parent a4f2899 commit e066104
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 77 deletions.
135 changes: 59 additions & 76 deletions jaxfg2/_factor_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import dataclasses
import dis
import functools
import inspect
import linecache
from typing import Callable, Hashable, Iterable, Mapping, Self, cast
Expand All @@ -20,7 +21,6 @@
@jdc.pytree_dataclass
class StackedFactorGraph:
stacked_factors: tuple[Factor, ...]
stacked_factors_var_indices: tuple[dict[type[Var], jax.Array], ...]
jacobian_coords: SparseCooCoordinates
tangent_ordering: jdc.Static[VarTypeOrdering]
residual_dim: jdc.Static[int]
Expand All @@ -37,14 +37,16 @@ def compute_residual_vector(self, vals: VarValues) -> jax.Array:

def _compute_jacobian_wrt_tangent(self, vals: VarValues) -> SparseCooMatrix:
jac_vals = []
for factor, subset_indices in zip(
self.stacked_factors, self.stacked_factors_var_indices
):
for factor in self.stacked_factors:
# Shape should be: (num_variables, len(group), single_residual_dim, var.tangent_dim).
def compute_jac_with_perturb(
factor: Factor, indices_from_type: dict[type[Var], jax.Array]
) -> jax.Array:
val_subset = vals._get_subset(indices_from_type, self.tangent_ordering)
def compute_jac_with_perturb(factor: Factor) -> jax.Array:
val_subset = vals._get_subset(
{
var_type: jnp.searchsorted(vals.ids_from_type[var_type], ids)
for var_type, ids in factor.sorted_ids_from_var_type.items()
},
self.tangent_ordering,
)
# Shape should be: (single_residual_dim, vars * tangent_dim).
return (
# Getting the Jacobian of...
Expand All @@ -59,7 +61,7 @@ def compute_jac_with_perturb(
)
)(jnp.zeros((val_subset._get_tangent_dim(),)))

stacked_jac = jax.vmap(compute_jac_with_perturb)(factor, subset_indices)
stacked_jac = jax.vmap(compute_jac_with_perturb)(factor)
(num_factor,) = factor._get_batch_axes()
assert stacked_jac.shape == (
num_factor,
Expand Down Expand Up @@ -115,7 +117,6 @@ def make(

# Fields we want to populate.
stacked_factors: list[Factor] = []
stacked_factor_var_indices: list[dict[type[Var], jax.Array]] = []
jacobian_coords: list[SparseCooCoordinates] = []

# Create storage layout: this describes which parts of our tangent
Expand Down Expand Up @@ -162,80 +163,32 @@ def make(
#
# These should be N pairs of (row, col) indices, where rows correspond to
# residual indices and columns correspond to tangent vector indices.
single_residual_dim = stacked_factor.residual_dim
stacked_residual_dim = single_residual_dim * len(group)

subset_indices_from_type = dict[type[Var], jax.Array]()

for var_type, ids in tangent_ordering.ordered_dict_items(
stacked_factor.sorted_ids_from_var_type
):
logger.info("Making initial grid")
# Jacobian of a single residual vector with respect to a
# single variable. Upper-left corner at (0, 0).
jac_coords = onp.mgrid[:single_residual_dim, : var_type.tangent_dim]
assert jac_coords.shape == (
2,
single_residual_dim,
var_type.tangent_dim,
rows, cols = jax.vmap(
functools.partial(
Factor._compute_block_sparse_jacobian_indices,
tangent_ordering=tangent_ordering,
sorted_ids_from_var_type=sorted_ids_from_var_type,
tangent_start_from_var_type=tangent_start_from_var_type,
)

logger.info("Getting indices")

# Get index of each variable, which is based on the sorted IDs.
var_indices = jnp.searchsorted(sorted_ids_from_var_type[var_type], ids)
subset_indices_from_type[var_type] = cast(jax.Array, var_indices)
assert var_indices.shape == (len(group), ids.shape[-1])

logger.info("Computing tangent")
# Variable index => indices into the tangent vector.
tangent_start_indices = (
tangent_start_from_var_type[var_type]
+ var_indices * var_type.tangent_dim
)
assert tangent_start_indices.shape == (len(group), ids.shape[-1])

logger.info("Broadcasting")
jac_coords = jnp.broadcast_to(
jac_coords[:, None, :, None, :],
(
2,
len(group),
single_residual_dim,
ids.shape[-1],
var_type.tangent_dim,
),
)
logger.info(
"Computed indices for Jacobian block with shape {}",
jac_coords.shape,
)
jacobian_coords.append(
SparseCooCoordinates(
rows=(
jac_coords[0]
+ (
onp.arange(len(group)) * single_residual_dim
+ residual_offset
)[:, None, None, None]
).flatten(),
# Offset the column indices by the start index within the
# flattened tangent vector.
cols=(
jac_coords[1] + tangent_start_indices[:, None, :, None]
).flatten(),
)
)
stacked_factor_var_indices.append(subset_indices_from_type)
residual_offset += stacked_residual_dim
)(stacked_factor)
assert (
rows.shape
== cols.shape
== (len(group), stacked_factor.residual_dim, rows.shape[-1])
)
rows = rows + (
jnp.arange(len(group))[:, None, None] * stacked_factor.residual_dim
)
rows = rows + residual_offset
jacobian_coords.append(SparseCooCoordinates(rows.flatten(), cols.flatten()))
residual_offset += stacked_factor.residual_dim * len(group)

jacobian_coords_concat: SparseCooCoordinates = jax.tree_map(
lambda *arrays: jnp.concatenate(arrays, axis=0), *jacobian_coords
)
logger.info("Done!")
return StackedFactorGraph(
stacked_factors=tuple(stacked_factors),
stacked_factors_var_indices=tuple(stacked_factor_var_indices),
jacobian_coords=jacobian_coords_concat,
tangent_ordering=tangent_ordering,
residual_dim=residual_offset,
Expand Down Expand Up @@ -331,3 +284,33 @@ def _get_function_signature(func: Callable) -> Hashable:
def _get_batch_axes(self) -> tuple[int, ...]:
return next(iter(self.sorted_ids_from_var_type.values())).shape[:-1]
def _compute_block_sparse_jacobian_indices(
self: Factor,
tangent_ordering: VarTypeOrdering,
sorted_ids_from_var_type: dict[type[Var], jax.Array],
tangent_start_from_var_type: dict[type[Var], int],
) -> tuple[jax.Array, jax.Array]:
"""Compute row and column indices for block-sparse Jacobian of shape
(residual dim, total tangent dim). Residual indices will start at row=0."""
col_indices = list[jax.Array]()
for var_type, ids in tangent_ordering.ordered_dict_items(
self.sorted_ids_from_var_type
):
var_indices = jnp.searchsorted(sorted_ids_from_var_type[var_type], ids)
tangent_start = tangent_start_from_var_type[var_type]
tangent_indices = (
onp.arange(tangent_start, tangent_start + var_type.tangent_dim)[None, :]
+ var_indices[:, None] * var_type.tangent_dim
)
assert tangent_indices.shape == (
var_indices.shape[0],
var_type.tangent_dim,
)
col_indices.append(cast(jax.Array, tangent_indices).flatten())
rows, cols = jnp.meshgrid(
jnp.arange(self.residual_dim),
jnp.concatenate(col_indices, axis=0),
indexing="ij",
)
return rows, cols
20 changes: 20 additions & 0 deletions jaxfg2/_sparse_matrices.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,26 @@
from jax import numpy as jnp


@jdc.pytree_dataclass
class SparseCsrCoordinates:
row_starts: jax.Array
"""Index into `cols` for the start of each row."""
cols: jax.Array
"""Column indices of non-zero entries. Shape should be `(*, N)`."""


@jdc.pytree_dataclass
class SparseCsrMatrix:
"""Sparse matrix in COO form."""

values: jax.Array
"""Non-zero matrix values. Shape should be `(*, N)`."""
coords: SparseCsrCoordinates
"""Row and column indices of non-zero entries. Shapes should be `(*, N)`."""
shape: jdc.Static[tuple[int, int]]
"""Shape of matrix."""


@jdc.pytree_dataclass
class SparseCooCoordinates:
rows: jax.Array
Expand Down
2 changes: 1 addition & 1 deletion scripts/pose_graph_g2o.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def main(
jax.block_until_ready(graph)

with jaxfg2.utils.stopwatch("Making solver"):
solver = jaxfg2.GaussNewtonSolver(verbose=True, linear_solver=jaxfg2.ConjugateGradientSolver())
solver = jaxfg2.GaussNewtonSolver(verbose=True) #, linear_solver=jaxfg2.ConjugateGradientSolver())
initial_vals = jaxfg2.VarValues.make(g2o.pose_vars, g2o.initial_poses)

with jaxfg2.utils.stopwatch("Running solve"):
Expand Down

0 comments on commit e066104

Please sign in to comment.