Skip to content

Commit

Permalink
Batching fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Jul 18, 2024
1 parent 4597ea3 commit 8dbbe6e
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 10 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ dependencies = [
"scikit-sparse",
"loguru",
"termcolor",
"tqdm",
"matplotlib",
]

[project.optional-dependencies]
dev = [
"pyright>=1.1.308",
"tqdm",
"matplotlib",
]
44 changes: 36 additions & 8 deletions src/jaxls/_factor_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def compute_residual_vector(self, vals: VarValues) -> jax.Array:
def _compute_jac_values(self, vals: VarValues) -> jax.Array:
jac_vals = []
for factor in self.stacked_factors:
# Shape should be: (num_variables, len(group), single_residual_dim, var.tangent_dim).
# Shape should be: (num_variables, length_from_group[group_key], single_residual_dim, var.tangent_dim).
def compute_jac_with_perturb(factor: Factor) -> jax.Array:
val_subset = vals._get_subset(
{
Expand Down Expand Up @@ -125,6 +125,7 @@ def make(

# Start by grouping our factors and grabbing a list of (ordered!) variables
factors_from_group = dict[Any, list[Factor]]()
length_from_group = dict[Any, int]()
for factor in factors:
# Each factor is ultimately just a pytree node; in order for a set of
# factors to be batchable, they must share the same:
Expand All @@ -137,9 +138,19 @@ def make(
for leaf in jax.tree_leaves(factor)
),
)
factors_from_group.setdefault(group_key, [])
length_from_group.setdefault(group_key, 0)

ids = next(iter(factor.sorted_ids_from_var_type.values()))
if len(ids.shape) == 1:
factor = jax.tree.map(lambda x: x[None], factor)
length_from_group[group_key] += 1
else:
assert len(ids.shape) == 2
length_from_group[group_key] += ids.shape[0]

# Record factor and variables.
factors_from_group.setdefault(group_key, []).append(factor)
factors_from_group[group_key].append(factor)

# Fields we want to populate.
stacked_factors: list[Factor] = []
Expand All @@ -157,7 +168,13 @@ def _sort_key(x: Any) -> str:
# Count variables of each type.
count_from_var_type = dict[type[Var[Any]], int]()
for var in variables:
count_from_var_type[type(var)] = count_from_var_type.get(type(var), 0) + 1
if isinstance(var.id, int) or var.id.shape == ():
increment = 1
else:
(increment,) = var.id.shape
count_from_var_type[type(var)] = (
count_from_var_type.get(type(var), 0) + increment
)
tangent_dim_sum = 0
for var_type in sorted(count_from_var_type.keys(), key=_sort_key):
tangent_start_from_var_type[var_type] = tangent_dim_sum
Expand All @@ -181,14 +198,14 @@ def _sort_key(x: Any) -> str:
group = factors_from_group[group_key]
logger.info(
"Group with factors={}, variables={}: {}",
len(group),
length_from_group[group_key],
group[0].num_variables,
group[0].compute_residual.__name__,
)

# Stack factor parameters.
stacked_factor: Factor = jax.tree.map(
lambda *args: jnp.stack(args, axis=0), *group
lambda *args: jnp.concatenate(args, axis=0), *group
)
stacked_factors.append(stacked_factor)

Expand All @@ -207,14 +224,21 @@ def _sort_key(x: Any) -> str:
assert (
rows.shape
== cols.shape
== (len(group), stacked_factor.residual_dim, rows.shape[-1])
== (
length_from_group[group_key],
stacked_factor.residual_dim,
rows.shape[-1],
)
)
rows = rows + (
jnp.arange(len(group))[:, None, None] * stacked_factor.residual_dim
jnp.arange(length_from_group[group_key])[:, None, None]
* stacked_factor.residual_dim
)
rows = rows + residual_dim_sum
jac_coords.append((rows.flatten(), cols.flatten()))
residual_dim_sum += stacked_factor.residual_dim * len(group)
residual_dim_sum += (
stacked_factor.residual_dim * length_from_group[group_key]
)

jac_coords_coo: SparseCooCoordinates = SparseCooCoordinates(
*jax.tree_map(lambda *arrays: jnp.concatenate(arrays, axis=0), *jac_coords),
Expand Down Expand Up @@ -293,11 +317,15 @@ def traverse_args(current: Any, variables: list[Var]) -> list[Var]:
return variables
variables = tuple(traverse_args(args, []))
assert len(variables) > 0
# Cache the residual dimension for this factor.
residual_dim_cache_key = (
compute_residual,
jax.tree.structure(args),
tuple(
x.shape if hasattr(x, "shape") else None for x in jax.tree.leaves(args)
),
)
if residual_dim_cache_key not in _residual_dim_cache:
dummy = VarValues.make(variables)
Expand Down

0 comments on commit 8dbbe6e

Please sign in to comment.