Skip to content

Commit

Permalink
Mangle root scope TF variable name during tf.saved_model export (#5738)
Browse files Browse the repository at this point in the history
* mangle tf root scope name
  • Loading branch information
lsy323 authored and bhavya01 committed Apr 22, 2024
1 parent f228a65 commit 48c853c
Showing 1 changed file with 14 additions and 1 deletion.
15 changes: 14 additions & 1 deletion torch_xla/tf_saved_model_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,19 @@ def _make_input_signatures(
shape=spec.shape, dtype=getattr(tf, spec.dtype), name=f'args_{i}')


def _mangle_tf_root_scope_name(name):
# TF has more restricted constrain on the variable names at root scope.
# Root scope name constrain: [A-Za-z0-9.][A-Za-z0-9_.\\-/]*
# Non-root scope name constrain: [A-Za-z0-9_.\\-/]*
# https://github.com/tensorflow/tensorflow/blob/51b601fa6bb7e801c0b6ae73c25580e40a8b5745/tensorflow/python/framework/ops.py#L3301-L3302
# The state_dict key doesn't have such constrain,
# the name need to be mangled when a root-scoped TF variable is created.
if name[0] in "._\\-/":
return 'k' + name
else:
return name


def save_stablehlo_graph_as_tf(
stablehlo_program: stablehlo.StableHLOGraphModule,
path: os.PathLike,
Expand All @@ -81,7 +94,7 @@ def save_stablehlo_graph_as_tf(
bundle = copy.deepcopy(stablehlo_program._bundle)
tfm = tf.Module()
bundle.state_dict = {
k: tf.Variable(v, trainable=False, name=k)
k: tf.Variable(v, trainable=False, name=_mangle_tf_root_scope_name(k))
for k, v in bundle.state_dict.items()
}
bundle.additional_constants = [
Expand Down

0 comments on commit 48c853c

Please sign in to comment.