Skip to content

Commit

Permalink
(breaking) Revisit API for VarValues.make()
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Sep 24, 2024
1 parent bb32b1c commit b7f439b
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 29 deletions.
7 changes: 6 additions & 1 deletion examples/pose_graph_g2o.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,12 @@ def main(
jax.block_until_ready(graph)

with jaxls.utils.stopwatch("Making solver"):
initial_vals = jaxls.VarValues.make(g2o.pose_vars, g2o.initial_poses)
initial_vals = jaxls.VarValues.make(
(
pose_var.with_value(pose)
for pose_var, pose in zip(g2o.pose_vars, g2o.initial_poses)
)
)

with jaxls.utils.stopwatch("Running solve"):
solution_vals = graph.solve(initial_vals, trust_region=None)
Expand Down
2 changes: 1 addition & 1 deletion src/jaxls/_factor_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def make(

ids = next(iter(factor.sorted_ids_from_var_type.values()))
if len(ids.shape) == 1:
factor = jax.tree.map(lambda x: x[None], factor)
factor = jax.tree.map(lambda x: jnp.asarray(x)[None], factor)
count_from_group[group_key] += 1
else:
assert len(ids.shape) == 2
Expand Down
78 changes: 51 additions & 27 deletions src/jaxls/_variables.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

from collections.abc import Iterable
from dataclasses import dataclass
from functools import total_ordering
from typing import Any, Callable, ClassVar, Iterable, Literal, cast, overload
from typing import Any, Callable, ClassVar, cast, overload

import jax
import jax_dataclasses as jdc
Expand Down Expand Up @@ -52,11 +53,20 @@ def ordered_dict_items[T](
)


@jdc.pytree_dataclass
class VarWithValue[T]:
"""Structure containing a single variable with a value, or multiple if a
leading batch axis is present. Returned by `Var.with_value()`."""

variable: Var[T]
value: T


@jdc.pytree_dataclass
class Var[T](metaclass=_HashableSortableMeta):
"""A symbolic representation of an optimization variable."""

id: int | jax.Array
id: jax.Array | int

# We would ideally annotate these as `ClassVar[T]`, but we can't.
#
Expand All @@ -68,6 +78,11 @@ class Var[T](metaclass=_HashableSortableMeta):
retract_fn: ClassVar[Callable[[Any, jax.Array], Any]]
"""Retraction function for the manifold. None for Euclidean space."""

def with_value(self, value: T) -> VarWithValue[T]:
"""Assign a value to this variable. Returned value can be used as input
for `VarValues.make()`."""
return VarWithValue(self, value)

@overload
def __init_subclass__[T_](
cls,
Expand Down Expand Up @@ -176,33 +191,42 @@ def __repr__(self) -> str:
return f"VarValues(\n{'\n'.join(out_lines)}\n)"
@staticmethod
def make[T](
variables: Iterable[Var[T]],
values: Iterable[T] | Literal["default"] = "default",
) -> VarValues:
"""Create a `VarValues` object. Entries in `vars` and entries in
`values` have a 1:1 correspondence.

We don't use a {var: value} dictionary because variables are not
hashable.
def make(variables: Iterable[Var[Any] | VarWithValue[Any]]) -> VarValues:
"""Create a VarValues object from a list of variables with or without
values assigned to them. In the latter case, value are set to the
default value of the variable type.

Example:
>>> v1 = SomeVar(1)
>>> v2 = AnotherVar(2)
>>> values = VarValues.make(v1, v2.with_value(custom_value))
"""
variables = tuple(variables)
if values == "default":
ids_from_type = sort_and_stack_vars(variables)
vals_from_type = {
# This should be faster than jnp.stack().
var_type: jax.tree_map(
lambda x: jnp.broadcast_to(x[None, ...], (len(ids), *x.shape)),
var_type.default,
vars = list[Var[Any]]()
vals = list[Any]()
for v in variables:
if isinstance(v, Var):
# Default value.
ids = v.id
assert isinstance(ids, int) or len(ids.shape) in (0, 1)
vars.append(v)
vals.append(
v.default
if isinstance(ids, int) or len(ids.shape) == 0
else jax.tree_map(
lambda x: jnp.broadcast_to(
x[None, ...],
(len(cast(jax.Array, ids).shape), *x.shape),
),
v.default,
)
)
for var_type, ids in ids_from_type.items()
}
return VarValues(vals_from_type=vals_from_type, ids_from_type=ids_from_type)
else:
values = tuple(values)
assert len(variables) == len(values)
ids_from_type, vals_from_type = sort_and_stack_vars(variables, values)
return VarValues(vals_from_type=vals_from_type, ids_from_type=ids_from_type)
else:
# Assigned value.
vars.append(v.variable)
vals.append(v.value)
ids_from_type, vals_from_type = sort_and_stack_vars(tuple(vars), tuple(vals))
return VarValues(vals_from_type=vals_from_type, ids_from_type=ids_from_type)
def _get_subset(
self,
Expand Down

0 comments on commit b7f439b

Please sign in to comment.