From f751dc76ed029ab3a7dffd3035c502e5df736e68 Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Fri, 27 Sep 2024 09:40:10 +0000 Subject: [PATCH] Improve batch axis support --- README.md | 9 ++++++--- src/jaxls/_factor_graph.py | 32 ++++++++++++++++++++++++++++++-- 2 files changed, 36 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 2fdaf9b..4a55e5f 100644 --- a/README.md +++ b/README.md @@ -40,18 +40,21 @@ faster and easier to use. For additional references, see inspirations like For Cholesky factorization via CHOLMOD, `scikit-sparse` requires SuiteSparse: ```bash -# Via conda. +# Option 1: via conda. conda install conda-forge::suitesparse -# Via apt. +# Option 2: via apt. sudo apt update sudo apt install -y libsuitesparse-dev -# Via brew. +# Option 3: via brew. brew install suite-sparse ``` Then, from your environment of choice: ```bash +# Option 1: from git. +pip install git+ssh://git@github.com/brentyi/jaxls.git +# Option 2: editable. git clone https://github.com/brentyi/jaxls.git cd jaxls pip install -e . diff --git a/src/jaxls/_factor_graph.py b/src/jaxls/_factor_graph.py index 211f248..899ddab 100644 --- a/src/jaxls/_factor_graph.py +++ b/src/jaxls/_factor_graph.py @@ -118,10 +118,25 @@ def make( factors = tuple(factors) variables = tuple(variables) + + # We're assuming no more than 1 batch axis. + num_factors = 0 + for f in factors: + assert len(f._get_batch_axes()) in (0, 1) + num_factors += ( + 1 if len(f._get_batch_axes()) == 0 else f._get_batch_axes()[0] + ) + + num_variables = 0 + for v in variables: + assert isinstance(v.id, int) or len(v.id.shape) in (0, 1) + num_variables += ( + 1 if isinstance(v.id, int) or v.id.shape == () else v.id.shape[0] + ) logger.info( "Building graph with {} factors and {} variables.", - len(factors), - len(variables), + num_factors, + num_variables, ) # Start by grouping our factors and grabbing a list of (ordered!) variables @@ -323,6 +338,19 @@ def traverse_args(current: Any, variables: list[Var]) -> list[Var]: variables = tuple(traverse_args(args, [])) assert len(variables) > 0 + # Support batch axis. + if not isinstance(variables[0].id, int): + batch_axes = variables[0].id.shape + assert len(batch_axes) in (0, 1) + for var in variables[1:]: + assert ( + () if isinstance(var.id, int) else var.id.shape + ) == batch_axes, "Batch axes of variables do not match." + if len(batch_axes) == 1: + return jax.vmap(Factor._make_impl, in_axes=(None, 0, None))( + compute_residual, args, jac_mode + ) + # Cache the residual dimension for this factor. residual_dim_cache_key = ( compute_residual,