Skip to content

Commit

Permalink
Adding out_sharding to test model to fix recent test failure.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 718429490
  • Loading branch information
james-martens authored and KfacJaxDev committed Jan 22, 2025
1 parent 9182734 commit d4388bc
Showing 1 changed file with 1 addition and 0 deletions.
1 change: 1 addition & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def __call__(self, inputs: LayerInputs, *_) -> LayerInputs: # pytype: disable=s
precision=(jax.lax.Precision.HIGHEST, jax.lax.Precision.HIGHEST),
preferred_element_type=preferred_element_type,
out_type=None,
out_sharding=None,
)
layer_values.append((x, y))

Expand Down

0 comments on commit d4388bc

Please sign in to comment.