
Commit

- Adding use_exact_inverses argument to optimizer
- Disabling caching of inverses when inverse_update_period=1, which should save memory

PiperOrigin-RevId: 476480160
james-martens authored and KfacJaxDev committed Sep 29, 2022
1 parent e3da816 commit cb057bd
Showing 2 changed files with 29 additions and 8 deletions.
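As a usage sketch (not part of this commit), the new argument is simply passed to the optimizer constructor. value_and_grad_func and the arguments other than the last two below are assumed from the existing kfac_jax API and are shown only for context:

import kfac_jax

# Illustrative sketch: value_and_grad_func and the arguments other than the
# last two are assumed from the existing kfac_jax API, not from this commit.
optimizer = kfac_jax.Optimizer(
    value_and_grad_func=value_and_grad_func,
    l2_reg=1e-5,
    use_adaptive_damping=True,
    initial_damping=1.0,
    # New in this commit: compute preconditioner inverses "exactly" via
    # eigendecompositions instead of the pi-adjusted factored damping.
    use_exact_inverses=True,
    # With an update period of 1 the inverses are recomputed every step, so
    # after this commit they are no longer cached, which saves memory.
    inverse_update_period=1,
)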
4 changes: 2 additions & 2 deletions kfac_jax/_src/curvature_estimator.py
@@ -1302,8 +1302,8 @@ def filter_outputs(thunk, vals):
# We must precompute the matches outside of the thunk itself, as the
# thunk will be traced separately from the current compiled context
# (since it's called within a lax.switch statement).
-matches = jax.tree_util.tree_map(lambda o, v: o is v,
-                                 thunk(), vals)
+matches = jax.tree_util.tree_map(lambda o, v: o is v, thunk(), vals)

def new_thunk():
  return jax.tree_util.tree_map(lambda o, m: None if m else o,
                                thunk(), matches)
33 changes: 27 additions & 6 deletions kfac_jax/_src/optimizer.py
@@ -98,6 +98,7 @@ def __init__(
estimation_mode: str = "fisher_gradients",
curvature_ema: chex.Numeric = 0.95,
inverse_update_period: int = 5,
use_exact_inverses: bool = False,
batch_process_func: Optional[Callable[[utils.Batch], utils.Batch]] = None,
register_only_generic: bool = False,
patterns_to_skip: Sequence[str] = (),
@@ -234,6 +235,10 @@ def __init__(
estimate moving averages. (Default: ``0.95``)
inverse_update_period: Int. The number of steps in between updating the
computation of the inverse curvature approximation. (Default: ``5``)
use_exact_inverses: Bool. If ``True``, preconditioner inverses are
computed "exactly" without the pi-adjusted factored damping approach.
Note that this involves the use of eigendecompositions, which can
sometimes be much more expensive. (Default: ``False``)
batch_process_func: Callable. A function to be called on each batch
before it is fed to KFAC on device. This could be useful for specific
device input optimizations. (Default: ``None``)
@@ -364,6 +369,19 @@ def schedule_with_first_step_zero(global_step: chex.Array) -> chex.Array:
self._include_per_param_norms_in_stats = include_per_param_norms_in_stats
self._batch_size_extractor = batch_size_extractor

self._use_cached_inverses = (self._inverse_update_period != 1)
self._use_exact_inverses = use_exact_inverses

if self._use_exact_inverses and self._use_cached_inverses:
  self._exact_powers_to_cache = -1
else:
  self._exact_powers_to_cache = None

if not self._use_exact_inverses and self._use_cached_inverses:
  self._approx_powers_to_cache = -1
else:
  self._approx_powers_to_cache = None

# Curvature estimator
self._estimator = curvature_estimator.BlockDiagonalCurvature(
func=self._value_func,
@@ -589,8 +607,8 @@ def _maybe_update_inverse_cache(
functools.partial(
self.estimator.update_cache,
identity_weight=self.l2_reg + state.damping,
-exact_powers=None,
-approx_powers=-1,
+exact_powers=self._exact_powers_to_cache,
+approx_powers=self._approx_powers_to_cache,
eigenvalues=False,
pmap_axis_name=self.pmap_axis_name,
),
@@ -612,8 +630,8 @@ def _compute_preconditioned_gradient(
state=state.estimator_state,
parameter_structured_vector=grads,
identity_weight=self.l2_reg + state.damping,
-exact_power=False,
-use_cached=True,
+exact_power=self._use_exact_inverses,
+use_cached=self._use_cached_inverses,
pmap_axis_name=self.pmap_axis_name,
)

@@ -726,6 +744,7 @@ def _init(
func_state: Optional[utils.FuncState] = None,
) -> "Optimizer.State":
"""A staged function to initialize the optimizer state ."""

return Optimizer.State(
velocities=jax.tree_util.tree_map(jnp.zeros_like, params),
estimator_state=self.estimator.init(
@@ -738,8 +757,8 @@ def _init(
has_state=self._value_func_has_state,
has_rng=self._value_func_has_rng,
),
-exact_powers_to_cache=None,
-approx_powers_to_cache=-1,
+exact_powers_to_cache=self._exact_powers_to_cache,
+approx_powers_to_cache=self._approx_powers_to_cache,
cache_eigenvalues=False
),
damping=(jnp.array(self._initial_damping)
@@ -765,8 +784,10 @@ def init(
func_state: Optional[utils.FuncState] = None,
) -> "Optimizer.State":
"""Initializes the optimizer and returns the appropriate optimizer state."""

if not self.finalized:
  self.finalize(params, rng, batch, func_state)

return self._init(params, rng, batch, func_state)

@functools.partial(utils.staged, donate_argnums=[1, 3, 5])
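For readers unfamiliar with the two inverse modes the new flag switches between, here is a small self-contained NumPy sketch (illustrative only, not code from this repository) of the difference for a single Kronecker-factored block A kron G with damping lam. The "exact" route inverts the damped block via eigendecompositions of the factors, while the pi-adjusted factored damping route (shown here in the textbook form from the K-FAC paper; kfac_jax's implementation may differ in detail) splits the damping between the factors and inverts them separately, which is cheaper but only approximate:

import numpy as np

rng = np.random.default_rng(0)
d_a, d_g, lam = 3, 4, 0.1

def random_spd(d):
  m = rng.normal(size=(d, d))
  return m @ m.T + np.eye(d)

A, G = random_spd(d_a), random_spd(d_g)  # stand-ins for the two Kronecker factors
v = rng.normal(size=d_a * d_g)           # stand-in for a flattened gradient block

# "Exact" damped inverse (A kron G + lam*I)^-1 v via eigendecompositions of the
# factors; this is the more expensive route that use_exact_inverses=True selects.
wa, qa = np.linalg.eigh(A)
wg, qg = np.linalg.eigh(G)
q = np.kron(qa, qg)
exact = q @ ((q.T @ v) / (np.kron(wa, wg) + lam))

# Pi-adjusted factored damping: add a pi-weighted share of sqrt(lam) to each
# factor and invert the factors separately (the cheaper, approximate default).
pi = np.sqrt((np.trace(A) / d_a) / (np.trace(G) / d_g))
approx = np.kron(
    np.linalg.inv(A + pi * np.sqrt(lam) * np.eye(d_a)),
    np.linalg.inv(G + (1.0 / pi) * np.sqrt(lam) * np.eye(d_g)),
) @ v

reference = np.linalg.solve(np.kron(A, G) + lam * np.eye(d_a * d_g), v)
print(np.linalg.norm(exact - reference))   # ~1e-14: exact up to round-off
print(np.linalg.norm(approx - reference))  # larger: only an approximation

The cost gap comes from the eigendecompositions of the factors, which is why the new docstring warns that use_exact_inverses=True can be much more expensive. Independently of that choice, the new _use_cached_inverses flag simply records whether inverse_update_period != 1; when the period is 1 the optimizer passes use_cached=False and caches neither exact nor approximate powers, which is the memory saving mentioned in the commit message.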
