
Commit

Fix tracer leak bugs
brentyi committed Oct 1, 2024
1 parent a2d60e6 commit 7ded980
Showing 1 changed file with 19 additions and 45 deletions.
64 changes: 19 additions & 45 deletions src/jaxls/_factor_graph.py
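Note on the bug class being fixed: a JAX "tracer leak" occurs when a value that is being traced (for example under jax.jit) escapes its trace, typically by being stored in module-level state, and is then touched after tracing has finished. The diff below deletes the module-level _function_cache and _residual_dim_cache globals, which are the kind of state where such leaks can originate. A minimal, hypothetical sketch of the general failure mode (toy code, not jaxls itself):

import jax
import jax.numpy as jnp

_leaky_cache: dict[str, jax.Array] = {}  # module-level state, analogous to the removed caches

@jax.jit
def double(x: jax.Array) -> jax.Array:
    # During tracing, `x` is a tracer; stashing it in a global leaks it.
    _leaky_cache["last_input"] = x
    return x * 2

double(jnp.ones(3))
# Using the leaked tracer afterwards raises jax.errors.UnexpectedTracerError:
# _leaky_cache["last_input"] + 1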
@@ -116,8 +116,18 @@ def make(
         else:
             from jax import numpy as jnp

-        factors = tuple(factors)
         variables = tuple(variables)
+        compute_residual_from_hash = dict[Hashable, Callable]()
+        factors = tuple(
+            jdc.replace(
+                factor,
+                compute_residual=compute_residual_from_hash.setdefault(
+                    factor._get_function_signature(factor.compute_residual),
+                    factor.compute_residual,
+                ),
+            )
+            for factor in factors
+        )

         # We're assuming no more than 1 batch axis.
         num_factors = 0
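The added block above replaces the old module-level _function_cache: residual callables are now deduplicated at graph-construction time, keyed by a signature derived from each function, so factors that share the same residual can later be vectorized together. dict.setdefault keeps the first callable seen per key and returns that same object for duplicates. A standalone sketch of the pattern (the key function here is hypothetical; _get_function_signature's exact key is not shown in this diff):

from typing import Callable, Hashable

def dedup_callables(fns: list[Callable], key: Callable[[Callable], Hashable]) -> list[Callable]:
    first_seen: dict[Hashable, Callable] = {}
    # setdefault stores the first callable for each key and returns that same
    # object for later duplicates, so equivalent residuals share one function.
    return [first_seen.setdefault(key(fn), fn) for fn in fns]

def residual_a(x):
    return x * x

def residual_b(x):  # structurally identical to residual_a
    return x * x

out = dedup_callables([residual_a, residual_b], key=lambda fn: fn.__code__.co_code)
assert out[0] is out[1] is residual_a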
@@ -289,10 +299,6 @@ def _sort_key(x: Any) -> str:
     )


-_residual_dim_cache = dict[Hashable, int]()
-_function_cache = dict[Hashable, Callable]()
-
-
 @jdc.pytree_dataclass
 class Factor[*Args]:
     """A single cost in our factor graph."""
@@ -304,29 +310,12 @@ class Factor[*Args]:
     residual_dim: jdc.Static[int]
     jac_mode: jdc.Static[Literal["auto", "forward", "reverse"]]

-    @staticmethod
-    def make[*Args_](
-        compute_residual: Callable[[VarValues, *Args_], jax.Array],
-        args: tuple[*Args_],
-        jac_mode: Literal["auto", "forward", "reverse"] = "auto",
-    ) -> Factor[*Args_]:
-        """Construct a factor for our factor graph."""
-        # If we see two functions with the same signature, we always use the
-        # first one. This helps with vectorization.
-        compute_residual = cast(
-            Callable[[VarValues, *Args_], jax.Array],
-            _function_cache.setdefault(
-                Factor._get_function_signature(compute_residual), compute_residual
-            ),
-        )
-        return Factor._make_impl(compute_residual, args, jac_mode)
-
     @staticmethod
     @jdc.jit
-    def _make_impl[*Args_](
+    def make[*Args_](
         compute_residual: jdc.Static[Callable[[VarValues, *Args_], jax.Array]],
         args: tuple[*Args_],
-        jac_mode: jdc.Static[Literal["auto", "forward", "reverse"]],
+        jac_mode: jdc.Static[Literal["auto", "forward", "reverse"]] = "reverse",
     ) -> Factor[*Args_]:
         """Construct a factor for our factor graph."""
@@ -354,37 +343,22 @@ def traverse_args(current: Any, variables: list[Var]) -> list[Var]:
                     () if isinstance(var.id, int) else var.id.shape
                 ) == batch_axes, "Batch axes of variables do not match."
             if len(batch_axes) == 1:
-                return jax.vmap(Factor._make_impl, in_axes=(None, 0, None))(
+                return jax.vmap(Factor.make, in_axes=(None, 0, None))(
                     compute_residual, args, jac_mode
                 )
-        # Cache the residual dimension for this factor.
-        residual_dim_cache_key = (
-            compute_residual,
-            jax.tree.structure(args),
-            tuple(
-                x.shape if hasattr(x, "shape") else None for x in jax.tree.leaves(args)
-            ),
-        )
-        if residual_dim_cache_key not in _residual_dim_cache:
-            dummy = VarValues.make(variables)
-            residual_shape = jax.eval_shape(compute_residual, dummy, *args).shape
-            assert len(residual_shape) == 1, "Residual must be a 1D array."
-            _residual_dim_cache[residual_dim_cache_key] = residual_shape[0]
-        # Let's not leak too much memory...
-        MAX_CACHE_SIZE = 512
-        if len(_function_cache) > MAX_CACHE_SIZE:
-            _function_cache.pop(next(iter(_function_cache.keys())))
-        if len(_residual_dim_cache) > MAX_CACHE_SIZE:
-            _residual_dim_cache.pop(next(iter(_residual_dim_cache.keys())))
+        dummy_vals = jax.eval_shape(VarValues.make, variables)
+        residual_shape = jax.eval_shape(compute_residual, dummy_vals, *args).shape
+        assert len(residual_shape) == 1, "Residual must be a 1D array."
+        (residual_dim,) = residual_shape
         return Factor(
             compute_residual,
             args=args,
             num_variables=len(variables),
             sorted_ids_from_var_type=sort_and_stack_vars(variables),
-            residual_dim=_residual_dim_cache[residual_dim_cache_key],
+            residual_dim=residual_dim,
             jac_mode=jac_mode,
         )
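The residual dimension is now derived with jax.eval_shape, which evaluates the residual function abstractly and returns only shapes and dtypes: no dummy arrays are materialized, and nothing is cached across calls. A standalone sketch of the idea (toy residual, not jaxls code):

import jax
import jax.numpy as jnp

def toy_residual(params: jax.Array, target: jax.Array) -> jax.Array:
    return params[:2] - target

params = jnp.zeros(5)
target = jnp.zeros(2)

# eval_shape traces on abstract values only: no FLOPs, no array allocation.
out_struct = jax.eval_shape(toy_residual, params, target)  # ShapeDtypeStruct, shape (2,)
assert len(out_struct.shape) == 1, "Residual must be a 1D array."
(residual_dim,) = out_struct.shape
print(residual_dim)  # 2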
