diff --git a/src/jaxls/_factor_graph.py b/src/jaxls/_factor_graph.py
index 6121a3c..13534ab 100644
--- a/src/jaxls/_factor_graph.py
+++ b/src/jaxls/_factor_graph.py
@@ -116,8 +116,18 @@ def make(
         else:
             from jax import numpy as jnp
 
-        factors = tuple(factors)
         variables = tuple(variables)
+        compute_residual_from_hash = dict[Hashable, Callable]()
+        factors = tuple(
+            jdc.replace(
+                factor,
+                compute_residual=compute_residual_from_hash.setdefault(
+                    factor._get_function_signature(factor.compute_residual),
+                    factor.compute_residual,
+                ),
+            )
+            for factor in factors
+        )
 
         # We're assuming no more than 1 batch axis.
         num_factors = 0
@@ -289,10 +299,6 @@ def _sort_key(x: Any) -> str:
     )
 
 
-_residual_dim_cache = dict[Hashable, int]()
-_function_cache = dict[Hashable, Callable]()
-
-
 @jdc.pytree_dataclass
 class Factor[*Args]:
     """A single cost in our factor graph."""
@@ -304,29 +310,12 @@ class Factor[*Args]:
     residual_dim: jdc.Static[int]
     jac_mode: jdc.Static[Literal["auto", "forward", "reverse"]]
 
-    @staticmethod
-    def make[*Args_](
-        compute_residual: Callable[[VarValues, *Args_], jax.Array],
-        args: tuple[*Args_],
-        jac_mode: Literal["auto", "forward", "reverse"] = "auto",
-    ) -> Factor[*Args_]:
-        """Construct a factor for our factor graph."""
-        # If we see two functions with the same signature, we always use the
-        # first one. This helps with vectorization.
-        compute_residual = cast(
-            Callable[[VarValues, *Args_], jax.Array],
-            _function_cache.setdefault(
-                Factor._get_function_signature(compute_residual), compute_residual
-            ),
-        )
-        return Factor._make_impl(compute_residual, args, jac_mode)
-
     @staticmethod
     @jdc.jit
-    def _make_impl[*Args_](
+    def make[*Args_](
        compute_residual: jdc.Static[Callable[[VarValues, *Args_], jax.Array]],
         args: tuple[*Args_],
-        jac_mode: jdc.Static[Literal["auto", "forward", "reverse"]],
+        jac_mode: jdc.Static[Literal["auto", "forward", "reverse"]] = "reverse",
     ) -> Factor[*Args_]:
         """Construct a factor for our factor graph."""
 
@@ -354,37 +343,22 @@ def traverse_args(current: Any, variables: list[Var]) -> list[Var]:
                 () if isinstance(var.id, int) else var.id.shape
             ) == batch_axes, "Batch axes of variables do not match."
         if len(batch_axes) == 1:
-            return jax.vmap(Factor._make_impl, in_axes=(None, 0, None))(
+            return jax.vmap(Factor.make, in_axes=(None, 0, None))(
                 compute_residual, args, jac_mode
             )
 
         # Cache the residual dimension for this factor.
-        residual_dim_cache_key = (
-            compute_residual,
-            jax.tree.structure(args),
-            tuple(
-                x.shape if hasattr(x, "shape") else None for x in jax.tree.leaves(args)
-            ),
-        )
-        if residual_dim_cache_key not in _residual_dim_cache:
-            dummy = VarValues.make(variables)
-            residual_shape = jax.eval_shape(compute_residual, dummy, *args).shape
-            assert len(residual_shape) == 1, "Residual must be a 1D array."
-            _residual_dim_cache[residual_dim_cache_key] = residual_shape[0]
-
-        # Let's not leak too much memory...
-        MAX_CACHE_SIZE = 512
-        if len(_function_cache) > MAX_CACHE_SIZE:
-            _function_cache.pop(next(iter(_function_cache.keys())))
-        if len(_residual_dim_cache) > MAX_CACHE_SIZE:
-            _residual_dim_cache.pop(next(iter(_residual_dim_cache.keys())))
+        dummy_vals = jax.eval_shape(VarValues.make, variables)
+        residual_shape = jax.eval_shape(compute_residual, dummy_vals, *args).shape
+        assert len(residual_shape) == 1, "Residual must be a 1D array."
+        (residual_dim,) = residual_shape
 
         return Factor(
             compute_residual,
             args=args,
             num_variables=len(variables),
             sorted_ids_from_var_type=sort_and_stack_vars(variables),
-            residual_dim=_residual_dim_cache[residual_dim_cache_key],
+            residual_dim=residual_dim,
             jac_mode=jac_mode,
         )
 
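Note (not part of the patch): a minimal sketch of the deduplication idea that now lives in FactorGraph.make. The assumption, supported by the removed comment "This helps with vectorization", is that factors built from the same residual function should share a single compute_residual object so they can later be grouped and evaluated with one vmapped call. The helper names below (get_signature, dedup, canonical) are illustrative stand-ins, not jaxls API.

from typing import Callable, Hashable

def get_signature(fn: Callable) -> Hashable:
    # Stand-in for Factor._get_function_signature: functions defined at the
    # same source location map to the same key.
    return (fn.__code__.co_filename, fn.__code__.co_firstlineno)

canonical: dict[Hashable, Callable] = {}

def dedup(fn: Callable) -> Callable:
    # First function seen for a signature wins; later duplicates are replaced
    # by it, mirroring the setdefault() call in the patched FactorGraph.make.
    return canonical.setdefault(get_signature(fn), fn)

# Two residuals created in a loop are distinct Python objects...
residuals = [lambda values, target: values - target for _ in range(2)]
assert residuals[0] is not residuals[1]

# ...but deduplicate to a single callable, so factors using them can be
# stacked and evaluated together.
assert dedup(residuals[0]) is dedup(residuals[1])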