From e7af313ddd8f9630ddff8dc69589ba13735c5f61 Mon Sep 17 00:00:00 2001 From: Haoliang Zhang Date: Sat, 21 Oct 2023 17:50:07 -0700 Subject: [PATCH] When exporting StableHLO to SavedModel, also include the original var (#5711) * Pass bundle object to make_tf_function, so that it will include tf.variables as args to the XlaCallModule op. * fix format * When exporting StableHLO to SavedModel, also include the original variable tensor name from pytorch. --- torch_xla/tf_saved_model_integration.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_xla/tf_saved_model_integration.py b/torch_xla/tf_saved_model_integration.py index 1f65e4108bc..188653abc3f 100644 --- a/torch_xla/tf_saved_model_integration.py +++ b/torch_xla/tf_saved_model_integration.py @@ -81,7 +81,8 @@ def save_stablehlo_graph_as_tf( bundle = copy.deepcopy(stablehlo_program._bundle) tfm = tf.Module() bundle.state_dict = { - k: tf.Variable(v, trainable=False) for k, v in bundle.state_dict.items() + k: tf.Variable(v, trainable=False, name=k) + for k, v in bundle.state_dict.items() } bundle.additional_constants = [ tf.Variable(v, trainable=False) for v in bundle.additional_constants