Skip to content

Commit

Permalink
Add support for exporting savedmodel in the torch backend
Browse files Browse the repository at this point in the history
  • Loading branch information
james77777778 committed Dec 24, 2024
1 parent 3dd958b commit aa2c079
Show file tree
Hide file tree
Showing 8 changed files with 310 additions and 75 deletions.
165 changes: 141 additions & 24 deletions keras/src/backend/torch/export.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,152 @@
from keras.src import layers
import copy
import re
import warnings

import torch

from keras.src import backend
from keras.src import ops
from keras.src import tree
from keras.src.utils.module_utils import tensorflow as tf
from keras.src.utils.module_utils import torch_xla


class TorchExportArchive:
def track(self, resource):
if not isinstance(resource, layers.Layer):
raise ValueError(
"Invalid resource type. Expected an instance of a "
"JAX-based Keras `Layer` or `Model`. "
f"Received instead an object of type '{type(resource)}'. "
f"Object received: {resource}"
)
raise NotImplementedError(
"`track` is not implemented in the torch backend. Use"
"`track_and_add_endpoint` instead."
)

def add_endpoint(self, name, fn, input_signature, **kwargs):
raise NotImplementedError(
"`add_endpoint` is not implemented in the torch backend. Use"
"`track_and_add_endpoint` instead."
)

def track_and_add_endpoint(self, name, resource, input_signature, **kwargs):
# Disable false alarms related to lifting parameters.
warnings.filterwarnings("ignore", message=".*created when tracing.*")
warnings.filterwarnings(
"ignore", message=".*Unable to find the path of the module.*"
)

if isinstance(resource, layers.Layer):
# Variables in the lists below are actually part of the trackables
# that get saved, because the lists are created in __init__.
variables = resource.variables
trainable_variables = resource.trainable_variables
non_trainable_variables = resource.non_trainable_variables
self._tf_trackable.variables += tree.map_structure(
self._convert_to_tf_variable, variables
if not isinstance(resource, torch.nn.Module):
raise TypeError(
"`resource` must be an instance of `torch.nn.Module`. "
f"Received: resource={resource} (of type {type(resource)})"
)
self._tf_trackable.trainable_variables += tree.map_structure(
self._convert_to_tf_variable, trainable_variables

def _check_input_signature(input_spec):
for s in tree.flatten(input_spec.shape):
if s is None:
raise ValueError(
"The shape in the `input_spec` must be fully "
f"specified. Received: input_spec={input_spec}"
)

def _to_torch_tensor(x, replace_none_number=1):
shape = backend.standardize_shape(x.shape)
shape = tuple(
s if s is not None else replace_none_number for s in shape
)
self._tf_trackable.non_trainable_variables += tree.map_structure(
self._convert_to_tf_variable, non_trainable_variables
return ops.ones(shape, x.dtype)

tree.map_structure(_check_input_signature, input_signature)
sample_inputs = tree.map_structure(_to_torch_tensor, input_signature)
sample_inputs = tuple(sample_inputs)

# Ref: torch_xla.tf_saved_model_integration
# TODO: Utilize `dynamic_shapes`
exported = torch.export.export(
resource, sample_inputs, dynamic_shapes=None, strict=False
)
options = torch_xla.stablehlo.StableHLOExportOptions(
override_tracing_arguments=sample_inputs
)
stablehlo_model = torch_xla.stablehlo.exported_program_to_stablehlo(
exported, options
)
state_dict_keys = list(stablehlo_model._bundle.state_dict.keys())

# Remove unused variables.
for k in state_dict_keys:
if "lifted" not in k:
stablehlo_model._bundle.state_dict.pop(k)

def _mangle_tf_root_scope_name(name):
# TF has more restricted constrain on the variable names.
if name[0] in "._\\-/":
name = "k" + name
name = re.sub(r"[^^\w\-/\\]+", "_", name)
return name

bundle = copy.deepcopy(stablehlo_model._bundle)
bundle.state_dict = {
k: tf.Variable(
v, trainable=False, name=_mangle_tf_root_scope_name(k)
)
for k, v in bundle.state_dict.items()
}
bundle.additional_constants = [
tf.Variable(v, trainable=False) for v in bundle.additional_constants
]

def add_endpoint(self, name, fn, input_signature=None, **kwargs):
# TODO: torch-xla?
raise NotImplementedError(
"`add_endpoint` is not implemented in the torch backend."
# Track variables in `bundle` for `write_out`.
self._tf_trackable.variables += (
list(bundle.state_dict.values()) + bundle.additional_constants
)

# Ref: torch_xla.tf_saved_model_integration.save_stablehlo_graph_as_tf
def make_tf_function(func, bundle):
from tensorflow.compiler.tf2xla.python import xla as tfxla

def _get_shape_with_dynamic(signature):
shape = copy.copy(signature.shape)
for i in signature.dynamic_dims:
shape[i] = None
return shape

def _extract_call_parameters(args, meta, bundle):
call_args = []
if meta.input_pytree_spec is not None:
args = tree.flatten(args)
for loc in meta.input_locations:
if loc.type_ == torch_xla.stablehlo.VariableType.PARAMETER:
call_args.append(bundle.state_dict[loc.name])
elif loc.type_ == torch_xla.stablehlo.VariableType.CONSTANT:
call_args.append(
bundle.additional_constants[loc.position]
)
else:
call_args.append(args[loc.position])
return call_args

def inner(*args):
Touts = [sig.dtype for sig in func.meta.output_signature]
Souts = [
_get_shape_with_dynamic(sig)
for sig in func.meta.output_signature
]
call_args = _extract_call_parameters(args, func.meta, bundle)
results = tfxla.call_module(
tuple(call_args),
version=5,
Tout=Touts, # dtype information
Sout=Souts, # Shape information
function_list=[],
module=func.bytecode,
)
if len(Souts) == 1:
results = results[0]
return results

return inner

decorated_fn = tf.function(
make_tf_function(
stablehlo_model._bundle.stablehlo_funcs[0], bundle
),
input_signature=input_signature,
)
return decorated_fn
127 changes: 100 additions & 27 deletions keras/src/export/export_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class ExportArchive(BackendExportArchive):
**Note on resource tracking:**
`ExportArchive` is able to automatically track all `tf.Variables` used
`ExportArchive` is able to automatically track all `keras.Variables` used
by its endpoints, so most of the time calling `.track(model)`
is not strictly required. However, if your model uses lookup layers such
as `IntegerLookup`, `StringLookup`, or `TextVectorization`,
Expand All @@ -104,9 +104,10 @@ class ExportArchive(BackendExportArchive):

def __init__(self):
super().__init__()
if backend.backend() not in ("tensorflow", "jax"):
if backend.backend() not in ("tensorflow", "jax", "torch"):
raise NotImplementedError(
"The export API is only compatible with JAX and TF backends."
"`ExportArchive` is only compatible with TensorFlow, JAX and "
"Torch backends."
)

self._endpoint_names = []
Expand Down Expand Up @@ -141,8 +142,8 @@ def track(self, resource):
(`TextVectorization`, `IntegerLookup`, `StringLookup`)
are automatically tracked in `add_endpoint()`.
Arguments:
resource: A trackable TensorFlow resource.
Args:
resource: A trackable Keras resource, such as a layer or model.
"""
if isinstance(resource, layers.Layer) and not resource.built:
raise ValueError(
Expand Down Expand Up @@ -334,12 +335,78 @@ def serving_fn(x):
self._endpoint_names.append(name)
return decorated_fn

def track_and_add_endpoint(self, name, resource, input_signature, **kwargs):
"""Track the variables and register a new serving endpoint.
This function combines the functionality of `track` and `add_endpoint`.
It tracks the variables of the `resource` (either a layer or a model)
and registers a serving endpoint using `resource.__call__`.
Args:
name: `str`. The name of the endpoint.
resource: A trackable Keras resource, such as a layer or model.
input_signature: Optional. Specifies the shape and dtype of `fn`.
Can be a structure of `keras.InputSpec`, `tf.TensorSpec`,
`backend.KerasTensor`, or backend tensor (see below for an
example showing a `Functional` model with 2 input arguments). If
not provided, `fn` must be a `tf.function` that has been called
at least once. Defaults to `None`.
**kwargs: Additional keyword arguments:
- Specific to the JAX backend:
- `is_static`: Optional `bool`. Indicates whether `fn` is
static. Set to `False` if `fn` involves state updates
(e.g., RNG seeds).
- `jax2tf_kwargs`: Optional `dict`. Arguments for
`jax2tf.convert`. See [`jax2tf.convert`](
https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md).
If `native_serialization` and `polymorphic_shapes` are
not provided, they are automatically computed.
"""
if name in self._endpoint_names:
raise ValueError(f"Endpoint name '{name}' is already taken.")
if not isinstance(resource, layers.Layer):
raise ValueError(
"Invalid resource type. Expected an instance of a Keras "
"`Layer` or `Model`. "
f"Received: resource={resource} (of type {type(resource)})"
)
if not resource.built:
raise ValueError(
"The layer provided has not yet been built. "
"It must be built before export."
)
if backend.backend() != "jax":
if "jax2tf_kwargs" in kwargs or "is_static" in kwargs:
raise ValueError(
"'jax2tf_kwargs' and 'is_static' are only supported with "
f"the jax backend. Current backend: {backend.backend()}"
)

input_signature = tree.map_structure(_make_tensor_spec, input_signature)

if not hasattr(BackendExportArchive, "track_and_add_endpoint"):
# Default behavior.
self.track(resource)
return self.add_endpoint(
name, resource.__call__, input_signature, **kwargs
)
else:
# Special case for the torch backend.
decorated_fn = BackendExportArchive.track_and_add_endpoint(
self, name, resource, input_signature, **kwargs
)
self._endpoint_signatures[name] = input_signature
setattr(self._tf_trackable, name, decorated_fn)
self._endpoint_names.append(name)
return decorated_fn

def add_variable_collection(self, name, variables):
"""Register a set of variables to be retrieved after reloading.
Arguments:
name: The string name for the collection.
variables: A tuple/list/set of `tf.Variable` instances.
variables: A tuple/list/set of `keras.Variable` instances.
Example:
Expand Down Expand Up @@ -496,9 +563,6 @@ def export_saved_model(
):
"""Export the model as a TensorFlow SavedModel artifact for inference.
**Note:** This feature is currently supported only with TensorFlow and
JAX backends.
This method lets you export a model to a lightweight SavedModel artifact
that contains the model's forward pass only (its `call()` method)
and can be served via e.g. TensorFlow Serving. The forward pass is
Expand Down Expand Up @@ -527,6 +591,14 @@ def export_saved_model(
If `native_serialization` and `polymorphic_shapes` are not
provided, they are automatically computed.
**Note:** This feature is currently supported only with TensorFlow, JAX and
Torch backends. Support for the Torch backend is experimental.
**Note:** The dynamic shape feature is not yet supported with Torch
backend. As a result, you must fully define the shapes of the inputs using
`input_signature`. If `input_signature` is not provided, all instances of
`None` (such as the batch size) will be replaced with `1`.
Example:
```python
Expand All @@ -543,28 +615,29 @@ def export_saved_model(
`export()` method relies on `ExportArchive` internally.
"""
export_archive = ExportArchive()
export_archive.track(model)
if isinstance(model, (Functional, Sequential)):
if input_signature is None:
if input_signature is None:
if not model.built:
raise ValueError(
"The layer provided has not yet been built. "
"It must be built before export."
)
if isinstance(model, (Functional, Sequential)):
input_signature = tree.map_structure(
_make_tensor_spec, model.inputs
)
if isinstance(input_signature, list) and len(input_signature) > 1:
input_signature = [input_signature]
export_archive.add_endpoint(
"serve", model.__call__, input_signature, **kwargs
)
else:
if input_signature is None:
if isinstance(input_signature, list) and len(input_signature) > 1:
input_signature = [input_signature]
else:
input_signature = _get_input_signature(model)
if not input_signature or not model._called:
raise ValueError(
"The model provided has never called. "
"It must be called at least once before export."
)
export_archive.add_endpoint(
"serve", model.__call__, input_signature, **kwargs
)
if not input_signature or not model._called:
raise ValueError(
"The model provided has never called. "
"It must be called at least once before export."
)

export_archive.track_and_add_endpoint(
"serve", model, input_signature, **kwargs
)
export_archive.write_out(filepath, verbose=verbose)


Expand Down
Loading

0 comments on commit aa2c079

Please sign in to comment.