Commit 2749541: Tuning, more detailed prints

brentyi committed Jul 18, 2024
1 parent 8dbbe6e

Showing 2 changed files with 38 additions and 16 deletions.
src/jaxls/_factor_graph.py (26 changes: 15 additions & 11 deletions)
@@ -25,6 +25,7 @@
 @jdc.pytree_dataclass
 class FactorGraph:
     stacked_factors: tuple[Factor, ...]
+    factor_counts: jdc.Static[tuple[int, ...]]
     sorted_ids_from_var_type: dict[type[Var], jax.Array]
     jac_coords_coo: SparseCooCoordinates
     jac_coords_csr: SparseCsrCoordinates
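
A note on the new field: `jdc.Static[...]` marks `factor_counts` as static pytree metadata in `jax_dataclasses`, so the counts stay plain Python ints under `jit` rather than becoming traced arrays; this is what lets the per-factor logging added in `_solvers.py` below slice the residual vector at concrete offsets. A minimal sketch of the distinction, using a hypothetical `Example` class (not from jaxls):

```python
import jax
import jax.numpy as jnp
import jax_dataclasses as jdc

@jdc.pytree_dataclass
class Example:
    data: jax.Array     # Pytree leaf: traced under jit.
    n: jdc.Static[int]  # Static metadata: stays a plain Python int.

@jax.jit
def head(e: Example) -> jax.Array:
    # Slicing with a traced length would fail under jit; because `n` is
    # static, this slice has a concrete, compile-time size.
    return e.data[: e.n]

print(head(Example(data=jnp.arange(5.0), n=3)))  # [0. 1. 2.]
```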
@@ -64,7 +65,7 @@ def compute_residual_vector(self, vals: VarValues) -> jax.Array:
     def _compute_jac_values(self, vals: VarValues) -> jax.Array:
         jac_vals = []
         for factor in self.stacked_factors:
-            # Shape should be: (num_variables, length_from_group[group_key], single_residual_dim, var.tangent_dim).
+            # Shape should be: (num_variables, count_from_group[group_key], single_residual_dim, var.tangent_dim).
             def compute_jac_with_perturb(factor: Factor) -> jax.Array:
                 val_subset = vals._get_subset(
                     {
Expand Down Expand Up @@ -125,7 +126,7 @@ def make(

# Start by grouping our factors and grabbing a list of (ordered!) variables
factors_from_group = dict[Any, list[Factor]]()
length_from_group = dict[Any, int]()
count_from_group = dict[Any, int]()
for factor in factors:
# Each factor is ultimately just a pytree node; in order for a set of
# factors to be batchable, they must share the same:
Expand All @@ -139,22 +140,23 @@ def make(
),
)
factors_from_group.setdefault(group_key, [])
length_from_group.setdefault(group_key, 0)
count_from_group.setdefault(group_key, 0)

ids = next(iter(factor.sorted_ids_from_var_type.values()))
if len(ids.shape) == 1:
factor = jax.tree.map(lambda x: x[None], factor)
length_from_group[group_key] += 1
count_from_group[group_key] += 1
else:
assert len(ids.shape) == 2
length_from_group[group_key] += ids.shape[0]
count_from_group[group_key] += ids.shape[0]

# Record factor and variables.
factors_from_group[group_key].append(factor)

# Fields we want to populate.
stacked_factors: list[Factor] = []
jac_coords: list[tuple[jax.Array, jax.Array]] = []
stacked_factors = list[Factor]()
factor_counts = list[int]()
jac_coords = list[tuple[jax.Array, jax.Array]]()

# Create storage layout: this describes which parts of our tangent
# vector is allocated to each variable.
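
The grouping loop above batches factors that share a pytree structure and leaf shapes, with `count_from_group` tracking how many stacked entries each group holds. A rough sketch of that batchability test (not jaxls's exact key construction):

```python
import jax
import jax.numpy as jnp

def group_key(tree):
    # Factors can be stacked together when they share both their pytree
    # structure (treedef) and the shapes of their array leaves.
    leaves, treedef = jax.tree.flatten(tree)
    return treedef, tuple(jnp.shape(leaf) for leaf in leaves)

a = {"residual": jnp.zeros(3), "weight": 1.0}
b = {"residual": jnp.ones(3), "weight": 2.0}
c = {"residual": jnp.zeros(4), "weight": 1.0}

assert group_key(a) == group_key(b)  # Same structure + shapes: batchable.
assert group_key(a) != group_key(c)  # Different leaf shape: separate group.
```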
@@ -198,7 +200,7 @@ def _sort_key(x: Any) -> str:
             group = factors_from_group[group_key]
             logger.info(
                 "Group with factors={}, variables={}: {}",
-                length_from_group[group_key],
+                count_from_group[group_key],
                 group[0].num_variables,
                 group[0].compute_residual.__name__,
             )
Expand All @@ -208,6 +210,7 @@ def _sort_key(x: Any) -> str:
lambda *args: jnp.concatenate(args, axis=0), *group
)
stacked_factors.append(stacked_factor)
factor_counts.append(count_from_group[group_key])

# Compute Jacobian coordinates.
#
Expand All @@ -225,19 +228,19 @@ def _sort_key(x: Any) -> str:
rows.shape
== cols.shape
== (
length_from_group[group_key],
count_from_group[group_key],
stacked_factor.residual_dim,
rows.shape[-1],
)
)
rows = rows + (
jnp.arange(length_from_group[group_key])[:, None, None]
jnp.arange(count_from_group[group_key])[:, None, None]
* stacked_factor.residual_dim
)
rows = rows + residual_dim_sum
jac_coords.append((rows.flatten(), cols.flatten()))
residual_dim_sum += (
stacked_factor.residual_dim * length_from_group[group_key]
stacked_factor.residual_dim * count_from_group[group_key]
)

jac_coords_coo: SparseCooCoordinates = SparseCooCoordinates(
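
Row-coordinate bookkeeping in this hunk: within a group, each stacked factor's Jacobian block is offset by `arange(count) * residual_dim`, and the whole group is then shifted by the running `residual_dim_sum` of earlier groups. A toy check with invented sizes:

```python
import jax.numpy as jnp

count, residual_dim = 3, 2  # Three stacked factors, two residuals each.
residual_dim_sum = 10       # Rows already consumed by earlier groups.

# Per-factor residual indices, shaped (1, residual_dim, 1).
rows = jnp.arange(residual_dim)[None, :, None]
# Offset each stacked copy by its position within the group...
rows = rows + jnp.arange(count)[:, None, None] * residual_dim
# ...then shift the whole group past previously processed groups.
rows = rows + residual_dim_sum

print(rows.squeeze(-1))
# [[10 11]
#  [12 13]
#  [14 15]]
```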
Expand All @@ -255,6 +258,7 @@ def _sort_key(x: Any) -> str:
logger.info("Done!")
return FactorGraph(
stacked_factors=tuple(stacked_factors),
factor_counts=tuple(factor_counts),
sorted_ids_from_var_type=sorted_ids_from_var_type,
jac_coords_coo=jac_coords_coo,
jac_coords_csr=jac_coords_csr,
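
Since `factor_counts` is stored on the returned graph as plain integers, the rows owned by each stacked factor group follow from ordinary Python arithmetic. A sketch with hypothetical counts and residual dimensions:

```python
def residual_segments(counts, residual_dims):
    """Yield the (start, stop) rows owned by each stacked factor group."""
    start = 0
    for count, dim in zip(counts, residual_dims):
        stop = start + count * dim
        yield start, stop
        start = stop

# Hypothetical groups: 100 factors of residual dim 3, 50 of dim 2, 1 of dim 6.
print(list(residual_segments((100, 50, 1), (3, 2, 6))))
# [(0, 300), (300, 400), (400, 406)]
```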
src/jaxls/_solvers.py (28 changes: 23 additions & 5 deletions)
@@ -79,7 +79,7 @@ class ConjugateGradientLinearSolver:
     """Iterative solver for sparse linear systems. Can run on CPU or GPU."""
 
     tolerance: float = 1e-5
-    inexact_step_eta: float | None = 1e-2
+    inexact_step_eta: float | None = None
     """Forcing sequence parameter for inexact Newton steps. CG tolerance is set to
     `eta / iteration #`.
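
The docstring's `eta / iteration #` rule is a forcing sequence in the inexact-Newton sense: early outer iterations solve the normal equations loosely, and later ones tighten toward the fixed `tolerance`. A sketch of the schedule using the old default `eta = 1e-2` (with the new default of `None`, the solver presumably falls back to the fixed tolerance instead):

```python
tolerance, eta = 1e-5, 1e-2  # Old default eta; the new default is None.

def cg_tol(outer_iteration: int) -> float:
    # Loose linear solves early, tightening as the outer loop converges,
    # floored at the fixed CG tolerance.
    return max(tolerance, eta / (outer_iteration + 1))

print([cg_tol(k) for k in range(4)])
# [0.01, 0.005, 0.0033333333333333335, 0.0025]
```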
@@ -115,7 +115,7 @@ def ATA_function(x: jax.Array):
             b=ATb,
             x0=initial_x,
             # https://en.wikipedia.org/wiki/Conjugate_gradient_method#Convergence_properties
-            maxiter=len(initial_x),
+            # maxiter=len(initial_x),
             tol=cast(
                 float,
                 jnp.maximum(self.tolerance, self.inexact_step_eta / (iterations + 1)),
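
For context, `jax.scipy.sparse.linalg.cg` is matrix-free: it only needs a callable that applies `A^T A`, so the normal equations are solved without ever materializing `A^T A`. The newly commented-out `maxiter=len(initial_x)` corresponds to the linked convergence property that CG terminates in at most `n` iterations in exact arithmetic. A self-contained sketch with a toy matrix (not the solver's actual setup):

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

A = jnp.array([[2.0, 0.0], [1.0, 1.0], [0.0, 3.0]])  # Tall Jacobian.
b = A @ jnp.array([1.0, 1.0])  # b lies in A's column space.

def ATA_function(x: jnp.ndarray) -> jnp.ndarray:
    # Matrix-free normal-equations operator: applies A^T (A x).
    return A.T @ (A @ x)

x, _ = cg(ATA_function, A.T @ b, tol=1e-10)
print(x)  # ~[1. 1.], the exact least-squares solution here.
```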
@@ -202,7 +202,25 @@ def step(
                 i=state.iterations,
                 cost=state.cost,
                 lambd=state.lambd,
+                ordered=True,
             )
+            residual_index = 0
+            for f, count in zip(graph.stacked_factors, graph.factor_counts):
+                stacked_dim = count * f.residual_dim
+                partial_cost = jnp.sum(
+                    state.residual_vector[residual_index : residual_index + stacked_dim]
+                    ** 2
+                )
+                residual_index += stacked_dim
+                jax_log(
+                    " - "
+                    + f"{f.compute_residual.__name__}({count}):".ljust(15)
+                    + " {:.5f} (avg {:.5f})",
+                    partial_cost,
+                    partial_cost / stacked_dim,
+                    ordered=True,
+                )
+
         with jdc.copy_and_mutate(state) as state_next:
             proposed_residual_vector = graph.compute_residual_vector(vals)
             proposed_cost = jnp.sum(proposed_residual_vector**2)
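
The loop added here prints one cost line per stacked factor group by slicing the flat residual vector in `count * residual_dim` chunks. A toy check that these slices partition the total cost (invented numbers):

```python
import jax.numpy as jnp

residual_vector = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
counts, residual_dims = (2, 1), (2, 2)  # Two groups: 2 factors of dim 2, 1 of dim 2.

index, partial_costs = 0, []
for count, dim in zip(counts, residual_dims):
    stacked_dim = count * dim
    partial_costs.append(jnp.sum(residual_vector[index : index + stacked_dim] ** 2))
    index += stacked_dim

print(partial_costs)  # [30.0, 61.0]
# The per-group costs partition the total cost.
assert jnp.isclose(sum(partial_costs), jnp.sum(residual_vector**2))
```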
@@ -273,14 +291,14 @@ class TrustRegionConfig:
 class TerminationConfig:
     # Termination criteria.
     max_iterations: int = 100
-    cost_tolerance: float = 1e-4
+    cost_tolerance: float = 1e-6
     """We terminate if `|cost change| / cost < cost_tolerance`."""
-    gradient_tolerance: float = 1e-6
+    gradient_tolerance: float = 1e-8
     """We terminate if `norm_inf(x - rplus(x, linear delta)) < gradient_tolerance`."""
     gradient_tolerance_start_step: int = 10
     """When to start checking the gradient tolerance condition. Helps solve precision
     issues caused by inexact Newton steps."""
-    parameter_tolerance: float = 1e-5
+    parameter_tolerance: float = 1e-7
     """We terminate if `norm_2(linear delta) < (norm_2(x) + parameter_tolerance) * parameter_tolerance`."""
 
     def _check_convergence(
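
The three tightened tolerances correspond to three independent stopping tests, each documented above. A sketch of how such checks might combine (not jaxls's actual `_check_convergence` body; flat vectors stand in for manifold retractions, so `rplus(x, delta) = x + delta` and the gradient test reduces to `norm_inf(delta)`):

```python
import jax.numpy as jnp

def check_convergence(
    cost_prev,
    cost,
    x,
    local_delta,
    cost_tolerance=1e-6,
    gradient_tolerance=1e-8,
    parameter_tolerance=1e-7,
):
    # |cost change| / cost has stalled.
    converged_cost = jnp.abs(cost_prev - cost) / cost < cost_tolerance
    # For flat parameters, norm_inf(x - (x + delta)) is just norm_inf(delta).
    converged_grad = jnp.max(jnp.abs(local_delta)) < gradient_tolerance
    # Step is tiny relative to the parameter vector's magnitude.
    converged_param = jnp.linalg.norm(local_delta) < (
        jnp.linalg.norm(x) + parameter_tolerance
    ) * parameter_tolerance
    return converged_cost | converged_grad | converged_param

print(
    check_convergence(
        cost_prev=jnp.asarray(1.0),
        cost=jnp.asarray(0.9999999),
        x=jnp.ones(3),
        local_delta=jnp.full(3, 1e-9),
    )
)  # True: all three tests fire.
```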