
Commit 67cf795

Remove the TwoKroneckerFactored class and use the `KroneckerFactored` class instead.

This is a no-op change. It just simplifies the code and makes it easier to add support for more than two parameters.
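
In brief, the pattern is (a minimal sketch; `MyBlock` is a hypothetical subclass used for illustration, not code from this commit):

  # Before: the base class is hard-wired to exactly two Kronecker factors.
  class MyBlock(kfac_jax.TwoKroneckerFactored):
    ...

  # After: the same block derives from the general base class, which leaves
  # room for layers whose curvature has more than two parameter factors.
  class MyBlock(kfac_jax.KroneckerFactored):
    ...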

PiperOrigin-RevId: 664755711
Change-Id: Ia231530d7e714a0a44b22babdc1a008de6fe9340
FermiNet Contributor authored and dpfau committed Aug 22, 2024
1 parent 4514bf5 commit 67cf795
Showing 1 changed file with 10 additions and 10 deletions.

ferminet/curvature_tags_and_blocks.py
@@ -54,13 +54,13 @@ def fixed_scale(self) -> Numeric:
 
   def update_curvature_matrix_estimate(
       self,
-      state: kfac_jax.TwoKroneckerFactored.State,
+      state: kfac_jax.KroneckerFactored.State,
       estimation_data: kfac_jax.LayerVjpData[Array],
       ema_old: Numeric,
       ema_new: Numeric,
       identity_weight: Numeric,
       batch_size: int,
-  ) -> kfac_jax.TwoKroneckerFactored.State:
+  ) -> kfac_jax.KroneckerFactored.State:
     [x] = estimation_data.primals.inputs
     [dy] = estimation_data.tangents.outputs
     assert x.shape[0] == batch_size
@@ -88,7 +88,7 @@ def update_curvature_matrix_estimate(
     )
 
 
-class QmcBlockedDense(kfac_jax.TwoKroneckerFactored):
+class QmcBlockedDense(kfac_jax.KroneckerFactored):
   """A factor that is the Kronecker product of two matrices."""
 
   def input_size(self) -> int:
@@ -102,13 +102,13 @@ def fixed_scale(self) -> Numeric:
 
   def update_curvature_matrix_estimate(
       self,
-      state: kfac_jax.TwoKroneckerFactored.State,
+      state: kfac_jax.KroneckerFactored.State,
       estimation_data: kfac_jax.LayerVjpData[Array],
       ema_old: Numeric,
       ema_new: Numeric,
       identity_weight: Numeric,
       batch_size: int,
-  ) -> kfac_jax.TwoKroneckerFactored.State:
+  ) -> kfac_jax.KroneckerFactored.State:
     del identity_weight
 
     [x] = estimation_data.primals.inputs
@@ -132,7 +132,7 @@ def _init(
       exact_powers_to_cache: Set[Scalar],
       approx_powers_to_cache: Set[Scalar],
       cache_eigenvalues: bool,
-  ) -> kfac_jax.TwoKroneckerFactored.State:
+  ) -> kfac_jax.KroneckerFactored.State:
     del rng, cache_eigenvalues
     k, m, j, n = self.parameters_shapes[0]
     cache = dict()
@@ -147,7 +147,7 @@
           inputs_factor=jnp.zeros([j, k, k]),
           outputs_factor=jnp.zeros([j, m * n, m * n]),
       )
-    return kfac_jax.TwoKroneckerFactored.State(
+    return kfac_jax.KroneckerFactored.State(
         cache=cache,
         inputs_factor=
            kfac_jax.utils.WeightedMovingAverage.zeros_array((j, k, k)),
@@ -157,12 +157,12 @@
 
   def _update_cache(
       self,
-      state: kfac_jax.TwoKroneckerFactored.State,
+      state: kfac_jax.KroneckerFactored.State,
       identity_weight: kfac_jax.utils.Numeric,
       exact_powers: set[kfac_jax.utils.Scalar],
       approx_powers: set[kfac_jax.utils.Scalar],
       eigenvalues: bool,
-  ) -> kfac_jax.TwoKroneckerFactored.State:
+  ) -> kfac_jax.KroneckerFactored.State:
     del eigenvalues
 
     if exact_powers:
@@ -186,7 +186,7 @@
 
   def multiply_matpower(
       self,
-      state: kfac_jax.TwoKroneckerFactored.State,
+      state: kfac_jax.KroneckerFactored.State,
       vector: Sequence[Array],
       identity_weight: Numeric,
       power: Scalar,
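
For background, the structure these blocks maintain can be sketched as follows. This is a generic illustration of the Kronecker-factored (K-FAC) curvature idea, not the kfac_jax API; the helper names, shapes, and the row-major flattening convention are assumptions, and damping via identity_weight is omitted:

import jax.numpy as jnp

def kronecker_factors(x, dy):
  # x: [batch, in_dim] layer inputs; dy: [batch, out_dim] output tangents.
  # These play the role of the inputs_factor / outputs_factor statistics
  # accumulated by update_curvature_matrix_estimate in the diff above.
  batch_size = x.shape[0]
  a = x.T @ x / batch_size    # A: [in_dim, in_dim] input second moment
  g = dy.T @ dy / batch_size  # G: [out_dim, out_dim] tangent second moment
  return a, g

def multiply_matpower_kron(a, g, v, power):
  # Applies (A ⊗ G)^power to a parameter block v of shape [in_dim, out_dim],
  # using the identity (A ⊗ G)^p = A^p ⊗ G^p; under a row-major flattening
  # of v this is A^p @ v @ G^p. K-FAC typically uses power = -1.
  a_p = jnp.linalg.matrix_power(a, power)
  g_p = jnp.linalg.matrix_power(g, power)
  return a_p @ v @ g_p

Because powers act factor-wise, a block never needs to materialize the full (in_dim * out_dim)-squared curvature matrix, which is the main computational payoff of the Kronecker factorization.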
