diff --git a/torch_xla/tf_saved_model_integration.py b/torch_xla/tf_saved_model_integration.py index 188653abc3f..511b9e02e9b 100644 --- a/torch_xla/tf_saved_model_integration.py +++ b/torch_xla/tf_saved_model_integration.py @@ -58,6 +58,19 @@ def _make_input_signatures( shape=spec.shape, dtype=getattr(tf, spec.dtype), name=f'args_{i}') +def _mangle_tf_root_scope_name(name): + # TF has more restricted constrain on the variable names at root scope. + # Root scope name constrain: [A-Za-z0-9.][A-Za-z0-9_.\\-/]* + # Non-root scope name constrain: [A-Za-z0-9_.\\-/]* + # https://github.com/tensorflow/tensorflow/blob/51b601fa6bb7e801c0b6ae73c25580e40a8b5745/tensorflow/python/framework/ops.py#L3301-L3302 + # The state_dict key doesn't have such constrain, + # the name need to be mangled when a root-scoped TF variable is created. + if name[0] in "._\\-/": + return 'k' + name + else: + return name + + def save_stablehlo_graph_as_tf( stablehlo_program: stablehlo.StableHLOGraphModule, path: os.PathLike, @@ -81,7 +94,7 @@ def save_stablehlo_graph_as_tf( bundle = copy.deepcopy(stablehlo_program._bundle) tfm = tf.Module() bundle.state_dict = { - k: tf.Variable(v, trainable=False, name=k) + k: tf.Variable(v, trainable=False, name=_mangle_tf_root_scope_name(k)) for k, v in bundle.state_dict.items() } bundle.additional_constants = [