Skip to content

Commit

Permalink
Fix variables binding in JAX export. (keras-team#19399)
Browse files Browse the repository at this point in the history
In the current JAX export, while the variables are exported as `tf.Variable`s, these `tf.Variable`s are not at all connected to the graph. Instead, all the weights of the variables are inlined in the graph as constants.

As a consequence:
- The export is almost twice the size it should be
- One easily runs into the 2GB limit for the graph size with models with a lot of parameters
- Models that use RNGs in inference mode fail to export with a "leak" error.

This PR fixes the issues above by creating a stateless function, jax2tf converting it and binding the variables to the function in a `tf.function` wrapper.

Also:
- Fixed an issue where with the JAX export, the Python signature of the exported function was lost
  • Loading branch information
hertschuh authored Mar 29, 2024
1 parent c24b98a commit 4adb561
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 16 deletions.
72 changes: 62 additions & 10 deletions keras/export/export_lib.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
"""Library for exporting inference-only Keras models/layers."""

import inspect

from absl import logging

from keras import backend
from keras.api_export import keras_export
from keras.backend.common.stateless_scope import StatelessScope
from keras.layers import Layer
from keras.models import Functional
from keras.models import Sequential
Expand Down Expand Up @@ -89,6 +92,11 @@ def __init__(self):
self._tf_trackable.trainable_variables = []
self._tf_trackable.non_trainable_variables = []

if backend.backend() == "jax":
self._backend_variables = []
self._backend_trainable_variables = []
self._backend_non_trainable_variables = []

if backend.backend() not in ("tensorflow", "jax"):
raise NotImplementedError(
"The export API is only compatible with JAX and TF backends."
Expand Down Expand Up @@ -144,16 +152,24 @@ def track(self, resource):
# Variables in the lists below are actually part of the trackables
# that get saved, because the lists are created in __init__.
if backend.backend() == "jax":
self._tf_trackable.trainable_variables += tree.flatten(
tree.map_structure(
tf.Variable, resource.trainable_variables
)

trainable_variables = tree.flatten(resource.trainable_variables)
non_trainable_variables = tree.flatten(
resource.non_trainable_variables
)
self._tf_trackable.non_trainable_variables += tree.flatten(
tree.map_structure(
tf.Variable, resource.non_trainable_variables
)
self._backend_trainable_variables += trainable_variables
self._backend_non_trainable_variables += non_trainable_variables
self._backend_variables = (
self._backend_trainable_variables
+ self._backend_non_trainable_variables
)

self._tf_trackable.trainable_variables += [
tf.Variable(v) for v in trainable_variables
]
self._tf_trackable.non_trainable_variables += [
tf.Variable(v) for v in non_trainable_variables
]
self._tf_trackable.variables = (
self._tf_trackable.trainable_variables
+ self._tf_trackable.non_trainable_variables
Expand Down Expand Up @@ -281,9 +297,45 @@ def serving_fn(x):
if backend.backend() == "tensorflow":
decorated_fn = tf.function(fn, input_signature=input_signature)
else: # JAX backend
fn = self._convert_jax2tf_function(fn, input_signature)

# 1. Create a stateless wrapper for `fn`
# 2. jax2tf the stateless wrapper
# 3. Create a stateful function that binds the variables with
# the jax2tf converted stateless wrapper
# 4. Make the signature of the stateful function the same as the
# original function
# 5. Wrap in a `tf.function`
def stateless_fn(variables, *args, **kwargs):
state_mapping = zip(self._backend_variables, variables)
with StatelessScope(state_mapping=state_mapping):
return fn(*args, **kwargs)

variables_spec = [
_make_tensor_spec(v) for v in self._backend_variables
]

jax2tf_stateless_fn = self._convert_jax2tf_function(
stateless_fn, [variables_spec] + input_signature
)

def stateful_fn(*args, **kwargs):
return jax2tf_stateless_fn(
self._tf_trackable.variables, *args, **kwargs
)

# Note: we truncate the number of parameters to what is
# specified by `input_signature`.
fn_signature = inspect.signature(fn)
fn_parameters = list(fn_signature.parameters.values())
stateful_fn.__signature__ = inspect.Signature(
parameters=fn_parameters[0 : len(input_signature)],
return_annotation=fn_signature.return_annotation,
)

decorated_fn = tf.function(
fn, input_signature=input_signature, autograph=False
stateful_fn,
input_signature=input_signature,
autograph=False,
)
self._endpoint_signatures[name] = input_signature
else:
Expand Down
26 changes: 20 additions & 6 deletions keras/export/export_lib_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,22 @@
from keras import backend
from keras import layers
from keras import models
from keras import random
from keras import testing
from keras import utils
from keras.export import export_lib
from keras.saving import saving_lib


class RandomLayer(layers.Layer):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.seed_generator = backend.random.SeedGenerator()

def call(self, inputs):
return inputs + random.uniform(inputs.shape, seed=self.seed_generator)


def get_model():
layer_list = [
layers.Dense(10, activation="relu"),
Expand Down Expand Up @@ -42,6 +52,16 @@ def test_standard_model_export(self):
ref_output, revived_model.serve(ref_input), atol=1e-6
)

def test_model_with_rng_export(self):
temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
model = models.Sequential([RandomLayer()])
ref_input = tf.random.normal((3, 10))
ref_output = model(ref_input)

export_lib.export_model(model, temp_filepath)
revived_model = tf.saved_model.load(temp_filepath)
self.assertEqual(ref_output.shape, revived_model.serve(ref_input).shape)

def test_low_level_model_export(self):
temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")

Expand Down Expand Up @@ -452,12 +472,6 @@ def test_non_standard_layer_signature(self):
atol=1e-6,
)

# TODO(nkovela): Remove test when argument name preservation
# workaround is created for JAX backend.
@pytest.mark.skipif(
backend.backend() != "tensorflow",
reason="JAX2TF has issues with argument name preservation.",
)
def test_non_standard_layer_signature_with_kwargs(self):
temp_filepath = os.path.join(self.get_temp_dir(), "exported_layer")

Expand Down

0 comments on commit 4adb561

Please sign in to comment.