Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Making pmap axis names consistent in examples code to support things like cross-replica batch norm layers. #301

Merged
merged 1 commit into from
Nov 27, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions examples/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def __init__(
),
)

self.params_init = jax.pmap(init_parameters_func)
self.params_init = jax.pmap(init_parameters_func, axis_name="kfac_axis")
self.model_loss_func = model_loss_func
self.model_func_for_estimator = model_func_for_estimator

Expand All @@ -223,10 +223,10 @@ def __init__(
)

self.train_batch_pmap = jax.pmap(
self._train_batch, axis_name="train_axis"
self._train_batch, axis_name="kfac_axis"
)
self.eval_batch_pmap = jax.pmap(
self._eval_batch, axis_name="eval_axis"
self._eval_batch, axis_name="kfac_axis"
)

# Log some useful information
Expand Down