Compute CSR sparse matrix
brentyi committed Jul 7, 2024
1 parent e066104 commit 8cb3069
Showing 6 changed files with 116 additions and 88 deletions.
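Background note (not part of the commit itself): the change makes a CSR (compressed sparse row) layout of the Jacobian available alongside the existing COO layout. As a quick reference, CSR stores the non-zero values in row-major order, the column index of each value, and a row pointer marking where each row starts:

import numpy as np
import scipy.sparse

dense = np.array(
    [
        [10.0, 0.0, 2.0],
        [0.0, 3.0, 0.0],
        [0.0, 0.0, 4.0],
    ]
)
csr = scipy.sparse.csr_matrix(dense)
print(csr.data)     # [10.  2.  3.  4.]   non-zero values, row-major
print(csr.indices)  # [0 2 1 2]           column index of each value
print(csr.indptr)   # [0 2 3 4]           row i spans data[indptr[i]:indptr[i + 1]]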
82 changes: 52 additions & 30 deletions jaxfg2/_factor_graph.py
@@ -5,7 +5,7 @@
import functools
import inspect
import linecache
from typing import Callable, Hashable, Iterable, Mapping, Self, cast
from typing import Callable, Hashable, Iterable, Literal, Mapping, Self, cast

import jax
import jax_dataclasses as jdc
@@ -14,49 +14,58 @@
from jax import numpy as jnp
from loguru import logger

from ._sparse_matrices import SparseCooCoordinates, SparseCooMatrix
from ._sparse_matrices import (
SparseCooCoordinates,
SparseCooMatrix,
SparseCsrCoordinates,
)
from ._variables import Var, VarTypeOrdering, VarValues, sort_and_stack_vars


@jdc.pytree_dataclass
class StackedFactorGraph:
stacked_factors: tuple[Factor, ...]
jacobian_coords: SparseCooCoordinates
tangent_ordering: jdc.Static[VarTypeOrdering]
residual_dim: jdc.Static[int]
_stacked_factors: tuple[Factor, ...]
_jacobian_coords_coo: SparseCooCoordinates
_jacobian_coords_csr: SparseCsrCoordinates
_tangent_ordering: jdc.Static[VarTypeOrdering]
_residual_dim: jdc.Static[int]

def compute_residual_vector(self, vals: VarValues) -> jax.Array:
residual_slices = list[jax.Array]()
for stacked_factor in self.stacked_factors:
for stacked_factor in self._stacked_factors:
stacked_residual_slice = jax.vmap(
lambda args: stacked_factor.compute_residual(vals, *args)
)(stacked_factor.args)
assert len(stacked_residual_slice.shape) == 2
residual_slices.append(stacked_residual_slice.reshape((-1,)))
return jnp.concatenate(residual_slices, axis=0)

def _compute_jacobian_wrt_tangent(self, vals: VarValues) -> SparseCooMatrix:
def _compute_jacobian_values(self, vals: VarValues) -> jax.Array:
jac_vals = []
for factor in self.stacked_factors:
for factor in self._stacked_factors:
# Shape should be: (num_variables, len(group), single_residual_dim, var.tangent_dim).
def compute_jac_with_perturb(factor: Factor) -> jax.Array:
val_subset = vals._get_subset(
{
var_type: jnp.searchsorted(vals.ids_from_type[var_type], ids)
for var_type, ids in factor.sorted_ids_from_var_type.items()
},
self.tangent_ordering,
self._tangent_ordering,
)
# Shape should be: (single_residual_dim, vars * tangent_dim).
return (
# Getting the Jacobian of...
jax.jacrev
if factor.residual_dim < val_subset._get_tangent_dim()
if (
factor.jacobian_mode == "auto"
and factor.residual_dim < val_subset._get_tangent_dim()
or factor.jacobian_mode == "reverse"
)
else jax.jacfwd
)(
# The residual function, with respect to some local delta.
lambda tangent: factor.compute_residual(
val_subset._retract(tangent, self.tangent_ordering),
val_subset._retract(tangent, self._tangent_ordering),
*factor.args,
)
)(jnp.zeros((val_subset._get_tangent_dim(),)))
Expand All @@ -70,12 +79,8 @@ def compute_jac_with_perturb(factor: Factor) -> jax.Array:
)
jac_vals.append(stacked_jac.flatten())
jac_vals = jnp.concatenate(jac_vals, axis=0)
assert jac_vals.shape == (self.jacobian_coords.rows.shape[0],)
return SparseCooMatrix(
values=jac_vals,
coords=self.jacobian_coords,
shape=(self.residual_dim, vals._get_tangent_dim()),
)
assert jac_vals.shape == (self._jacobian_coords_coo.rows.shape[0],)
return jac_vals

@staticmethod
def make(
@@ -117,7 +122,7 @@ def make(

# Fields we want to populate.
stacked_factors: list[Factor] = []
jacobian_coords: list[SparseCooCoordinates] = []
jacobian_coords: list[tuple[jax.Array, jax.Array]] = []

# Create storage layout: this describes which parts of our tangent
# vector are allocated to each variable.
@@ -126,10 +131,10 @@
vars_from_var_type = dict[type[Var], list[Var]]()
for var in vars:
vars_from_var_type.setdefault(type(var), []).append(var)
counter = 0
tangent_offset = 0
for var_type, vars_one_type in vars_from_var_type.items():
tangent_start_from_var_type[var_type] = counter
counter += var_type.tangent_dim * len(vars_one_type)
tangent_start_from_var_type[var_type] = tangent_offset
tangent_offset += var_type.tangent_dim * len(vars_one_type)

# Create ordering helper.
tangent_ordering = VarTypeOrdering(
@@ -180,18 +185,31 @@
jnp.arange(len(group))[:, None, None] * stacked_factor.residual_dim
)
rows = rows + residual_offset
jacobian_coords.append(SparseCooCoordinates(rows.flatten(), cols.flatten()))
jacobian_coords.append((rows.flatten(), cols.flatten()))
residual_offset += stacked_factor.residual_dim * len(group)

jacobian_coords_concat: SparseCooCoordinates = jax.tree_map(
lambda *arrays: jnp.concatenate(arrays, axis=0), *jacobian_coords
jacobian_coords_concat: SparseCooCoordinates = SparseCooCoordinates(
*jax.tree_map(
lambda *arrays: jnp.concatenate(arrays, axis=0), *jacobian_coords
),
shape=(residual_offset, tangent_offset),
)
logger.info("Done!")
return StackedFactorGraph(
stacked_factors=tuple(stacked_factors),
jacobian_coords=jacobian_coords_concat,
tangent_ordering=tangent_ordering,
residual_dim=residual_offset,
_stacked_factors=tuple(stacked_factors),
_jacobian_coords_coo=jacobian_coords_concat,
_jacobian_coords_csr=SparseCsrCoordinates(
indices=jacobian_coords_concat.cols,
indptr=cast(
jax.Array,
jnp.searchsorted(
jacobian_coords_concat.rows, jnp.arange(residual_offset + 1)
),
),
shape=(residual_offset, tangent_offset),
),
_tangent_ordering=tangent_ordering,
_residual_dim=residual_offset,
)


@@ -208,11 +226,13 @@ class Factor[*Args]:
num_vars: jdc.Static[int]
sorted_ids_from_var_type: dict[type[Var], jax.Array]
residual_dim: jdc.Static[int]
jacobian_mode: jdc.Static[Literal["auto", "forward", "reverse"]]

@staticmethod
def make[*Args_](
compute_residual: Callable[[VarValues, *Args_], jax.Array],
args: tuple[*Args_],
jacobian_mode: Literal["auto", "forward", "reverse"] = "auto",
) -> Factor[*Args_]:
"""Construct a factor for our factor graph."""
# If we see two functions with the same signature, we always use the
@@ -223,13 +243,14 @@ def make[*Args_](
Factor._get_function_signature(compute_residual), compute_residual
),
)
return Factor._make_impl(compute_residual, args)
return Factor._make_impl(compute_residual, args, jacobian_mode)
@staticmethod
@jdc.jit
def _make_impl[*Args_](
compute_residual: jdc.Static[Callable[[VarValues, *Args_], jax.Array]],
args: tuple[*Args_],
jacobian_mode: jdc.Static[Literal["auto", "forward", "reverse"]],
) -> Factor[*Args_]:
"""Construct a factor for our factor graph."""
@@ -267,6 +288,7 @@ def _make_impl[*Args_](
tuple(cast(Var, args[i]) for i in variable_indices)
),
residual_dim=_residual_dim_cache[residual_dim_cache_key],
jacobian_mode=jacobian_mode,
)
@staticmethod
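A note on the CSR index pointer built in StackedFactorGraph.make above: because the coordinate blocks are appended factor group by factor group, the COO row indices come out sorted, so jnp.searchsorted over them yields the CSR indptr directly. A minimal standalone sketch of that conversion (the example arrays are illustrative, not taken from the library):

import jax.numpy as jnp

# Row-sorted COO coordinates for a 3x3 matrix with 6 non-zeros.
rows = jnp.array([0, 0, 1, 2, 2, 2])
cols = jnp.array([0, 2, 1, 0, 1, 2])
num_rows = 3

# CSR keeps the column indices and replaces the per-entry row indices with an
# index pointer: indptr[i] is the position of the first entry of row i.
indices = cols
indptr = jnp.searchsorted(rows, jnp.arange(num_rows + 1))
print(indptr)  # [0 2 3 6]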
44 changes: 25 additions & 19 deletions jaxfg2/_solvers.py
@@ -8,13 +8,14 @@
import jax.flatten_util
import jax_dataclasses as jdc
import scipy
import scipy.sparse
import sksparse.cholmod
from jax import numpy as jnp

from jaxfg2.utils import jax_log

from ._factor_graph import StackedFactorGraph
from ._sparse_matrices import SparseCooMatrix
from ._sparse_matrices import SparseCooMatrix, SparseCsrMatrix
from ._variables import VarTypeOrdering, VarValues

# Linear solvers.
@@ -25,13 +26,8 @@
@jdc.pytree_dataclass
class CholmodSolver:
def solve(
self,
A: SparseCooMatrix,
ATb: jax.Array,
lambd: float | jax.Array,
iterations: int | jax.Array,
self, A: SparseCsrMatrix, ATb: jax.Array, lambd: float | jax.Array
) -> jax.Array:
del iterations
return jax.pure_callback(
self._solve_on_host,
ATb, # Result shape/dtype.
@@ -43,14 +39,16 @@ def solve(

def _solve_on_host(
self,
A: SparseCooMatrix,
A: SparseCsrMatrix,
ATb: jax.Array,
lambd: float | jax.Array,
) -> jax.Array:
A_T = A.T
A_T_scipy = A_T.as_scipy_coo_matrix().tocsc(copy=False)
# Matrix is transposed when we convert CSR to CSC.
A_T_scipy = scipy.sparse.csc_matrix(
(A.values, A.coords.indices, A.coords.indptr), shape=A.coords.shape[::-1]
)

# Cache sparsity pattern analysis
# Cache sparsity pattern analysis.
self_hash = object.__hash__(self)
if self_hash not in _cholmod_analyze_cache:
_cholmod_analyze_cache[self_hash] = sksparse.cholmod.analyze_AAt(A_T_scipy)
@@ -107,7 +105,8 @@ def jacobi_preconditioner(x):
# https://en.wikipedia.org/wiki/Conjugate_gradient_method#Convergence_properties
maxiter=len(initial_x),
tol=cast(
float, jnp.maximum(self.tolerance, self.inexact_step_eta / (iterations + 1))
float,
jnp.maximum(self.tolerance, self.inexact_step_eta / (iterations + 1)),
)
if self.inexact_step_eta is not None
else self.tolerance,
@@ -256,14 +255,21 @@ def _initialize_state(
def _step(
self, graph: StackedFactorGraph, state: NonlinearSolverState
) -> NonlinearSolverState:
A = graph._compute_jacobian_wrt_tangent(state.vals)
ATb = -(A.T @ state.residual_vector)
jac_values = graph._compute_jacobian_values(state.vals)
A_coo = SparseCooMatrix(jac_values, graph._jacobian_coords_coo)
ATb = -(A_coo.T @ state.residual_vector)

tangent = self.linear_solver.solve(
A, ATb, lambd=0.0, iterations=state.iterations
)
vals = state.vals._retract(tangent, graph.tangent_ordering)
if isinstance(self.linear_solver, ConjugateGradientSolver):
tangent = self.linear_solver.solve(
A_coo, ATb, lambd=0.0, iterations=state.iterations
)
elif isinstance(self.linear_solver, CholmodSolver):
A_csr = SparseCsrMatrix(jac_values, graph._jacobian_coords_csr)
tangent = self.linear_solver.solve(A_csr, ATb, lambd=0.0)
else:
assert False

vals = state.vals._retract(tangent, graph._tangent_ordering)
if self.verbose:
jax_log(
"Gauss-Newton step #{i}: cost={cost:.4f}",
@@ -279,7 +285,7 @@ def _step(
state,
cost_updated=state_next.cost,
tangent=tangent,
tangent_ordering=graph.tangent_ordering,
tangent_ordering=graph._tangent_ordering,
ATb=ATb,
)
return state_next
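The CholmodSolver change above leans on a standard transpose trick: reinterpreting a CSR matrix's (values, indices, indptr) buffers as CSC with the shape reversed gives the transposed matrix without copying, which is the form passed to sksparse.cholmod.analyze_AAt in _solve_on_host. A small sketch of just that reinterpretation, using plain scipy matrices rather than the jaxfg2 dataclasses:

import numpy as np
import scipy.sparse

A = scipy.sparse.csr_matrix(np.array([[1.0, 0.0, 2.0], [0.0, 3.0, 0.0]]))

# Same (data, indices, indptr) buffers, read as CSC with rows/cols swapped: A^T.
A_T = scipy.sparse.csc_matrix((A.data, A.indices, A.indptr), shape=A.shape[::-1])
assert (A_T.toarray() == A.toarray().T).all()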
35 changes: 12 additions & 23 deletions jaxfg2/_sparse_matrices.py
@@ -1,27 +1,24 @@
import jax
import jax_dataclasses as jdc
import scipy.sparse
from jax import numpy as jnp


@jdc.pytree_dataclass
class SparseCsrCoordinates:
row_starts: jax.Array
"""Index into `cols` for the start of each row."""
cols: jax.Array
indices: jax.Array
"""Column indices of non-zero entries. Shape should be `(*, N)`."""
indptr: jax.Array
"""Index of start to each row. Shape should be `(*, num_rows)`."""
shape: jdc.Static[tuple[int, int]]


@jdc.pytree_dataclass
class SparseCsrMatrix:
"""Sparse matrix in COO form."""
"""Data structure for sparse CSR matrices."""

values: jax.Array
"""Non-zero matrix values. Shape should be `(*, N)`."""
coords: SparseCsrCoordinates
"""Row and column indices of non-zero entries. Shapes should be `(*, N)`."""
shape: jdc.Static[tuple[int, int]]
"""Shape of matrix."""


@jdc.pytree_dataclass
@@ -30,6 +27,8 @@ class SparseCooCoordinates:
"""Row indices of non-zero entries. Shape should be `(*, N)`."""
cols: jax.Array
"""Column indices of non-zero entries. Shape should be `(*, N)`."""
shape: jdc.Static[tuple[int, int]]
"""Shape of matrix."""


@jdc.pytree_dataclass
@@ -40,45 +39,35 @@
"""Non-zero matrix values. Shape should be `(*, N)`."""
coords: SparseCooCoordinates
"""Row and column indices of non-zero entries. Shapes should be `(*, N)`."""
shape: jdc.Static[tuple[int, int]]
"""Shape of matrix."""

def __matmul__(self, other: jax.Array):
"""Compute `Ax`, where `x` is a 1D vector."""
assert other.shape == (
self.shape[1],
self.coords.shape[1],
), "Inner product only supported for 1D vectors!"
return (
jnp.zeros(self.shape[0], dtype=other.dtype)
jnp.zeros(self.coords.shape[0], dtype=other.dtype)
.at[self.coords.rows]
.add(self.values * other[self.coords.cols])
)

def as_dense(self) -> jnp.ndarray:
"""Convert to a dense JAX array."""
# TODO: untested
return (
jnp.zeros(self.shape)
jnp.zeros(self.coords.shape)
.at[self.coords.rows, self.coords.cols]
.set(self.values)
)

@property
def T(self):
"""Return transpose of our sparse matrix."""
h, w = self.shape
h, w = self.coords.shape
return SparseCooMatrix(
values=self.values,
coords=SparseCooCoordinates(
rows=self.coords.cols,
cols=self.coords.rows,
shape=(w, h),
),
shape=(w, h),
)

def as_scipy_coo_matrix(self) -> scipy.sparse.coo_matrix:
"""Convert to a sparse scipy matrix."""
assert len(self.values.shape) == 1
return scipy.sparse.coo_matrix(
(self.values, (self.coords.rows, self.coords.cols)), shape=self.shape
)
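With the shape now stored on the coordinates objects rather than on the matrix, the COO matrix-vector product (a scatter-add over rows) and the transpose work as before. A minimal usage sketch, assuming jaxfg2._sparse_matrices is importable as below:

import jax.numpy as jnp

from jaxfg2._sparse_matrices import SparseCooCoordinates, SparseCooMatrix

# 2x3 matrix [[1, 0, 2], [0, 3, 0]] in COO form.
coords = SparseCooCoordinates(
    rows=jnp.array([0, 0, 1]),
    cols=jnp.array([0, 2, 1]),
    shape=(2, 3),
)
A = SparseCooMatrix(values=jnp.array([1.0, 2.0, 3.0]), coords=coords)
x = jnp.ones(3)

print(A @ x)          # [3. 3.]
print(A.T @ (A @ x))  # [3. 9. 6.]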