diff --git a/torch_xla/tf_saved_model_integration.py b/torch_xla/tf_saved_model_integration.py index 1f65e4108bcc..188653abc3f0 100644 --- a/torch_xla/tf_saved_model_integration.py +++ b/torch_xla/tf_saved_model_integration.py @@ -81,7 +81,8 @@ def save_stablehlo_graph_as_tf( bundle = copy.deepcopy(stablehlo_program._bundle) tfm = tf.Module() bundle.state_dict = { - k: tf.Variable(v, trainable=False) for k, v in bundle.state_dict.items() + k: tf.Variable(v, trainable=False, name=k) + for k, v in bundle.state_dict.items() } bundle.additional_constants = [ tf.Variable(v, trainable=False) for v in bundle.additional_constants