Skip to content

Commit

Permalink
Fix FlaxLayer docstring. (keras-team#19400)
Browse files Browse the repository at this point in the history
The code samples had typos.

Also removed in incorrect comment in tests, the export is fully tested.
  • Loading branch information
hertschuh authored Mar 29, 2024
1 parent 4adb561 commit a063684
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 19 deletions.
4 changes: 2 additions & 2 deletions keras/utils/jax_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ def __call__(self, inputs):
```python
class MyFlaxModule(flax.linen.Module):
@flax.linen.compact
def forward(self, input1, input1, deterministic):
def forward(self, input1, input2, deterministic):
...
return outputs
Expand All @@ -497,7 +497,7 @@ def my_flax_module_wrapper(module, inputs, training):
return module.forward(input1, input2, not training)
flax_module = MyFlaxModule()
keras_layer = FlaxLayer(flax_module)
keras_layer = FlaxLayer(
module=flax_module,
method=my_flax_module_wrapper,
)
Expand Down
17 changes: 0 additions & 17 deletions keras/utils/jax_layer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,24 +321,7 @@ def verify_identical_model(model):
verify_identical_model(model3)

# export, load back and compare results
# TODO: fix and reenable this.
path = os.path.join(self.get_temp_dir(), "jax_layer_export")
# export_archive = export_lib.ExportArchive()
# export_archive.track(model2)
# export_archive.add_endpoint(
# "call",
# model2.call,
# input_signature=[
# tf.TensorSpec(
# shape=(None,) + input_shape,
# dtype=tf.float32,
# )
# ],
# )
# export_archive.write_out(
# path,
# tf.saved_model.SaveOptions(experimental_custom_gradients=False),
# )
export_lib.export_model(model2, path)
model4 = tf.saved_model.load(path)
output4 = model4.serve(x_test)
Expand Down

0 comments on commit a063684

Please sign in to comment.