diff --git a/examples/training.py b/examples/training.py index 56e12ed..ca098cd 100644 --- a/examples/training.py +++ b/examples/training.py @@ -211,7 +211,7 @@ def __init__( ), ) - self.params_init = jax.pmap(init_parameters_func) + self.params_init = jax.pmap(init_parameters_func, axis_name="kfac_axis") self.model_loss_func = model_loss_func self.model_func_for_estimator = model_func_for_estimator @@ -223,10 +223,10 @@ def __init__( ) self.train_batch_pmap = jax.pmap( - self._train_batch, axis_name="train_axis" + self._train_batch, axis_name="kfac_axis" ) self.eval_batch_pmap = jax.pmap( - self._eval_batch, axis_name="eval_axis" + self._eval_batch, axis_name="kfac_axis" ) # Log some useful information