diff --git a/keras/utils/jax_layer.py b/keras/utils/jax_layer.py index c245600d951..746588e4865 100644 --- a/keras/utils/jax_layer.py +++ b/keras/utils/jax_layer.py @@ -488,7 +488,7 @@ def __call__(self, inputs): ```python class MyFlaxModule(flax.linen.Module): @flax.linen.compact - def forward(self, input1, input1, deterministic): + def forward(self, input1, input2, deterministic): ... return outputs @@ -497,7 +497,7 @@ def my_flax_module_wrapper(module, inputs, training): return module.forward(input1, input2, not training) flax_module = MyFlaxModule() - keras_layer = FlaxLayer(flax_module) + keras_layer = FlaxLayer( module=flax_module, method=my_flax_module_wrapper, ) diff --git a/keras/utils/jax_layer_test.py b/keras/utils/jax_layer_test.py index 5b83a9750ad..404e843afb5 100644 --- a/keras/utils/jax_layer_test.py +++ b/keras/utils/jax_layer_test.py @@ -321,24 +321,7 @@ def verify_identical_model(model): verify_identical_model(model3) # export, load back and compare results - # TODO: fix and reenable this. path = os.path.join(self.get_temp_dir(), "jax_layer_export") - # export_archive = export_lib.ExportArchive() - # export_archive.track(model2) - # export_archive.add_endpoint( - # "call", - # model2.call, - # input_signature=[ - # tf.TensorSpec( - # shape=(None,) + input_shape, - # dtype=tf.float32, - # ) - # ], - # ) - # export_archive.write_out( - # path, - # tf.saved_model.SaveOptions(experimental_custom_gradients=False), - # ) export_lib.export_model(model2, path) model4 = tf.saved_model.load(path) output4 = model4.serve(x_test)