From 4adb561da53d97fc994d9fdc719e7b953ad303ce Mon Sep 17 00:00:00 2001 From: hertschuh <1091026+hertschuh@users.noreply.github.com> Date: Thu, 28 Mar 2024 18:23:01 -0700 Subject: [PATCH] Fix variables binding in JAX export. (#19399) In the current JAX export, while the variables are exported as `tf.Variable`s, these `tf.Variable`s are not at all connected to the graph. Instead, all the weights of the variables are inlined in the graph as constants. As a consequence: - The export is almost twice the size it should be - One easily runs into the 2GB limit for the graph size with models with a lot of parameters - Models that use RNGs in inference mode fail to export with a "leak" error. This PR fixes the issues above by creating a stateless function, converting it with jax2tf, and binding the variables to the converted function in a `tf.function` wrapper. Also: - Fixed an issue where the JAX export lost the Python signature of the exported function. --- keras/export/export_lib.py | 72 ++++++++++++++++++++++++++++----- keras/export/export_lib_test.py | 26 +++++++++--- 2 files changed, 82 insertions(+), 16 deletions(-) diff --git a/keras/export/export_lib.py b/keras/export/export_lib.py index efe8b05342a..e0ca00e7d85 100644 --- a/keras/export/export_lib.py +++ b/keras/export/export_lib.py @@ -1,9 +1,12 @@ """Library for exporting inference-only Keras models/layers.""" +import inspect + from absl import logging from keras import backend from keras.api_export import keras_export +from keras.backend.common.stateless_scope import StatelessScope from keras.layers import Layer from keras.models import Functional from keras.models import Sequential @@ -89,6 +92,11 @@ def __init__(self): self._tf_trackable.trainable_variables = [] self._tf_trackable.non_trainable_variables = [] + if backend.backend() == "jax": + self._backend_variables = [] + self._backend_trainable_variables = [] + self._backend_non_trainable_variables = [] + if backend.backend() not in ("tensorflow", "jax"): raise
NotImplementedError( "The export API is only compatible with JAX and TF backends." @@ -144,16 +152,24 @@ def track(self, resource): # Variables in the lists below are actually part of the trackables # that get saved, because the lists are created in __init__. if backend.backend() == "jax": - self._tf_trackable.trainable_variables += tree.flatten( - tree.map_structure( - tf.Variable, resource.trainable_variables - ) + + trainable_variables = tree.flatten(resource.trainable_variables) + non_trainable_variables = tree.flatten( + resource.non_trainable_variables ) - self._tf_trackable.non_trainable_variables += tree.flatten( - tree.map_structure( - tf.Variable, resource.non_trainable_variables - ) + self._backend_trainable_variables += trainable_variables + self._backend_non_trainable_variables += non_trainable_variables + self._backend_variables = ( + self._backend_trainable_variables + + self._backend_non_trainable_variables ) + + self._tf_trackable.trainable_variables += [ + tf.Variable(v) for v in trainable_variables + ] + self._tf_trackable.non_trainable_variables += [ + tf.Variable(v) for v in non_trainable_variables + ] self._tf_trackable.variables = ( self._tf_trackable.trainable_variables + self._tf_trackable.non_trainable_variables @@ -281,9 +297,45 @@ def serving_fn(x): if backend.backend() == "tensorflow": decorated_fn = tf.function(fn, input_signature=input_signature) else: # JAX backend - fn = self._convert_jax2tf_function(fn, input_signature) + + # 1. Create a stateless wrapper for `fn` + # 2. jax2tf the stateless wrapper + # 3. Create a stateful function that binds the variables with + # the jax2tf converted stateless wrapper + # 4. Make the signature of the stateful function the same as the + # original function + # 5. 
Wrap in a `tf.function` + def stateless_fn(variables, *args, **kwargs): + state_mapping = zip(self._backend_variables, variables) + with StatelessScope(state_mapping=state_mapping): + return fn(*args, **kwargs) + + variables_spec = [ + _make_tensor_spec(v) for v in self._backend_variables + ] + + jax2tf_stateless_fn = self._convert_jax2tf_function( + stateless_fn, [variables_spec] + input_signature + ) + + def stateful_fn(*args, **kwargs): + return jax2tf_stateless_fn( + self._tf_trackable.variables, *args, **kwargs + ) + + # Note: we truncate the number of parameters to what is + # specified by `input_signature`. + fn_signature = inspect.signature(fn) + fn_parameters = list(fn_signature.parameters.values()) + stateful_fn.__signature__ = inspect.Signature( + parameters=fn_parameters[0 : len(input_signature)], + return_annotation=fn_signature.return_annotation, + ) + decorated_fn = tf.function( - fn, input_signature=input_signature, autograph=False + stateful_fn, + input_signature=input_signature, + autograph=False, ) self._endpoint_signatures[name] = input_signature else: diff --git a/keras/export/export_lib_test.py b/keras/export/export_lib_test.py index bec3cab8c19..5cb7134ef93 100644 --- a/keras/export/export_lib_test.py +++ b/keras/export/export_lib_test.py @@ -9,12 +9,22 @@ from keras import backend from keras import layers from keras import models +from keras import random from keras import testing from keras import utils from keras.export import export_lib from keras.saving import saving_lib +class RandomLayer(layers.Layer): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.seed_generator = backend.random.SeedGenerator() + + def call(self, inputs): + return inputs + random.uniform(inputs.shape, seed=self.seed_generator) + + def get_model(): layer_list = [ layers.Dense(10, activation="relu"), @@ -42,6 +52,16 @@ def test_standard_model_export(self): ref_output, revived_model.serve(ref_input), atol=1e-6 ) + def 
test_model_with_rng_export(self): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + model = models.Sequential([RandomLayer()]) + ref_input = tf.random.normal((3, 10)) + ref_output = model(ref_input) + + export_lib.export_model(model, temp_filepath) + revived_model = tf.saved_model.load(temp_filepath) + self.assertEqual(ref_output.shape, revived_model.serve(ref_input).shape) + def test_low_level_model_export(self): temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") @@ -452,12 +472,6 @@ def test_non_standard_layer_signature(self): atol=1e-6, ) - # TODO(nkovela): Remove test when argument name preservation - # workaround is created for JAX backend. - @pytest.mark.skipif( - backend.backend() != "tensorflow", - reason="JAX2TF has issues with argument name preservation.", - ) def test_non_standard_layer_signature_with_kwargs(self): temp_filepath = os.path.join(self.get_temp_dir(), "exported_layer")