diff --git a/pyproject.toml b/pyproject.toml
index c51941f..3692a89 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,11 +23,11 @@ dependencies = [
     "scikit-sparse",
     "loguru",
     "termcolor",
+    "tqdm",
+    "matplotlib",
 ]
 
 [project.optional-dependencies]
 dev = [
     "pyright>=1.1.308",
-    "tqdm",
-    "matplotlib",
 ]
diff --git a/src/jaxls/_factor_graph.py b/src/jaxls/_factor_graph.py
index 5d8f83b..83792c3 100644
--- a/src/jaxls/_factor_graph.py
+++ b/src/jaxls/_factor_graph.py
@@ -64,7 +64,7 @@ def compute_residual_vector(self, vals: VarValues) -> jax.Array:
     def _compute_jac_values(self, vals: VarValues) -> jax.Array:
         jac_vals = []
         for factor in self.stacked_factors:
-            # Shape should be: (num_variables, len(group), single_residual_dim, var.tangent_dim).
+            # Shape should be: (num_variables, length_from_group[group_key], single_residual_dim, var.tangent_dim).
             def compute_jac_with_perturb(factor: Factor) -> jax.Array:
                 val_subset = vals._get_subset(
                     {
@@ -125,6 +125,7 @@ def make(
 
         # Start by grouping our factors and grabbing a list of (ordered!) variables
        factors_from_group = dict[Any, list[Factor]]()
+        length_from_group = dict[Any, int]()
         for factor in factors:
             # Each factor is ultimately just a pytree node; in order for a set of
             # factors to be batchable, they must share the same:
@@ -137,9 +138,19 @@ def make(
                     for leaf in jax.tree_leaves(factor)
                 ),
             )
+            factors_from_group.setdefault(group_key, [])
+            length_from_group.setdefault(group_key, 0)
+
+            ids = next(iter(factor.sorted_ids_from_var_type.values()))
+            if len(ids.shape) == 1:
+                factor = jax.tree.map(lambda x: x[None], factor)
+                length_from_group[group_key] += 1
+            else:
+                assert len(ids.shape) == 2
+                length_from_group[group_key] += ids.shape[0]
 
             # Record factor and variables.
-            factors_from_group.setdefault(group_key, []).append(factor)
+            factors_from_group[group_key].append(factor)
 
         # Fields we want to populate.
         stacked_factors: list[Factor] = []
@@ -157,7 +168,13 @@ def _sort_key(x: Any) -> str:
         # Count variables of each type.
         count_from_var_type = dict[type[Var[Any]], int]()
         for var in variables:
-            count_from_var_type[type(var)] = count_from_var_type.get(type(var), 0) + 1
+            if isinstance(var.id, int) or var.id.shape == ():
+                increment = 1
+            else:
+                (increment,) = var.id.shape
+            count_from_var_type[type(var)] = (
+                count_from_var_type.get(type(var), 0) + increment
+            )
         tangent_dim_sum = 0
         for var_type in sorted(count_from_var_type.keys(), key=_sort_key):
             tangent_start_from_var_type[var_type] = tangent_dim_sum
@@ -181,14 +198,14 @@ def _sort_key(x: Any) -> str:
             group = factors_from_group[group_key]
             logger.info(
                 "Group with factors={}, variables={}: {}",
-                len(group),
+                length_from_group[group_key],
                 group[0].num_variables,
                 group[0].compute_residual.__name__,
             )
 
             # Stack factor parameters.
             stacked_factor: Factor = jax.tree.map(
-                lambda *args: jnp.stack(args, axis=0), *group
+                lambda *args: jnp.concatenate(args, axis=0), *group
             )
             stacked_factors.append(stacked_factor)
 
@@ -207,14 +224,21 @@ def _sort_key(x: Any) -> str:
             assert (
                 rows.shape
                 == cols.shape
-                == (len(group), stacked_factor.residual_dim, rows.shape[-1])
+                == (
+                    length_from_group[group_key],
+                    stacked_factor.residual_dim,
+                    rows.shape[-1],
+                )
             )
             rows = rows + (
-                jnp.arange(len(group))[:, None, None] * stacked_factor.residual_dim
+                jnp.arange(length_from_group[group_key])[:, None, None]
+                * stacked_factor.residual_dim
             )
             rows = rows + residual_dim_sum
             jac_coords.append((rows.flatten(), cols.flatten()))
-            residual_dim_sum += stacked_factor.residual_dim * len(group)
+            residual_dim_sum += (
+                stacked_factor.residual_dim * length_from_group[group_key]
+            )
 
         jac_coords_coo: SparseCooCoordinates = SparseCooCoordinates(
             *jax.tree_map(lambda *arrays: jnp.concatenate(arrays, axis=0), *jac_coords),
@@ -293,11 +317,15 @@ def traverse_args(current: Any, variables: list[Var]) -> list[Var]:
             return variables
 
         variables = tuple(traverse_args(args, []))
+        assert len(variables) > 0
 
         # Cache the residual dimension for this factor.
         residual_dim_cache_key = (
             compute_residual,
             jax.tree.structure(args),
+            tuple(
+                x.shape if hasattr(x, "shape") else None for x in jax.tree.leaves(args)
+            ),
         )
         if residual_dim_cache_key not in _residual_dim_cache:
             dummy = VarValues.make(variables)
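For intuition, the jnp.stack -> jnp.concatenate change works because every factor in a group now carries a leading batch axis before stacking: unbatched factors are promoted with jax.tree.map(lambda x: x[None], factor), and batched factors keep their existing leading axis. A minimal standalone sketch of that behavior, using a stand-in dict pytree rather than the actual Factor type (names like "target" are made up for illustration):

import jax
import jax.numpy as jnp

# Stand-in "factors": plain dict pytrees, not the real jaxls Factor type.
single = {"target": jnp.zeros(3)}        # one factor, no batch axis
batched = {"target": jnp.ones((4, 3))}   # four factors, already batched

# Promote the unbatched factor to a leading axis of length 1, mirroring
# `jax.tree.map(lambda x: x[None], factor)` in the patch.
single = jax.tree.map(lambda x: x[None], single)

group = [single, batched]

# Concatenating along axis 0 generalizes the old `jnp.stack`: the group's total
# length is 1 + 4 = 5, which is what `length_from_group` tracks per group key.
stacked = jax.tree.map(lambda *args: jnp.concatenate(args, axis=0), *group)
print(stacked["target"].shape)  # (5, 3)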