From 2df15aff3aeb51bde0fd07c5a7d36ea45ba7ab45 Mon Sep 17 00:00:00 2001 From: Haoliang Zhang Date: Mon, 16 Oct 2023 12:04:17 -0700 Subject: [PATCH 1/3] Pass bundle object to make_tf_function, so that it will include tf.variables as args to the XlaCallModule op. --- torch_xla/tf_saved_model_integration.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/torch_xla/tf_saved_model_integration.py b/torch_xla/tf_saved_model_integration.py index b66c0f3aafa..e642aa9e664 100644 --- a/torch_xla/tf_saved_model_integration.py +++ b/torch_xla/tf_saved_model_integration.py @@ -36,9 +36,11 @@ def inner(*args): return inner -def make_tf_function(stablehlo_program: stablehlo.StableHLOGraphModule): - return _wrap_as_tf_func(stablehlo_program._bundle.stablehlo_funcs[0], - stablehlo_program._bundle) +def make_tf_function(stablehlo_program: stablehlo.StableHLOGraphModule, bundle=None): + if bundle is None: + return _wrap_as_tf_func(stablehlo_program._bundle.stablehlo_funcs[0], + stablehlo_program._bundle) + return _wrap_as_tf_func(stablehlo_program._bundle.stablehlo_funcs[0], bundle) def _make_input_signatures( @@ -86,7 +88,7 @@ def save_stablehlo_graph_as_tf( input_signatures = list( _make_input_signatures(bundle.stablehlo_funcs[0].meta)) tfm.f = tf.function( - make_tf_function(stablehlo_program), input_signature=input_signatures) + make_tf_function(stablehlo_program, bundle), input_signature=input_signatures) tfm._variables = ( list(bundle.state_dict.values()) + bundle.additional_constants) signatures = {serving_key: tfm.f.get_concrete_function(*input_signatures)} From 2053b11e906148c44173e863035ce5f9c35b823f Mon Sep 17 00:00:00 2001 From: Haoliang Zhang Date: Mon, 16 Oct 2023 14:29:24 -0700 Subject: [PATCH 2/3] fix format --- torch_xla/tf_saved_model_integration.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torch_xla/tf_saved_model_integration.py b/torch_xla/tf_saved_model_integration.py index e642aa9e664..1f65e4108bc 100644 --- a/torch_xla/tf_saved_model_integration.py +++ b/torch_xla/tf_saved_model_integration.py @@ -36,7 +36,8 @@ def inner(*args): return inner -def make_tf_function(stablehlo_program: stablehlo.StableHLOGraphModule, bundle=None): +def make_tf_function(stablehlo_program: stablehlo.StableHLOGraphModule, + bundle=None): if bundle is None: return _wrap_as_tf_func(stablehlo_program._bundle.stablehlo_funcs[0], stablehlo_program._bundle) @@ -88,7 +89,8 @@ def save_stablehlo_graph_as_tf( input_signatures = list( _make_input_signatures(bundle.stablehlo_funcs[0].meta)) tfm.f = tf.function( - make_tf_function(stablehlo_program, bundle), input_signature=input_signatures) + make_tf_function(stablehlo_program, bundle), + input_signature=input_signatures) tfm._variables = ( list(bundle.state_dict.values()) + bundle.additional_constants) signatures = {serving_key: tfm.f.get_concrete_function(*input_signatures)} From 9586351a157a15c1a73022fe54d03365ecbe7b2a Mon Sep 17 00:00:00 2001 From: Haoliang Zhang Date: Wed, 18 Oct 2023 14:59:03 -0700 Subject: [PATCH 3/3] When exporting StableHLO to SavedModel, also include the original variable tensor name from pytorch. --- torch_xla/tf_saved_model_integration.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_xla/tf_saved_model_integration.py b/torch_xla/tf_saved_model_integration.py index 1f65e4108bc..188653abc3f 100644 --- a/torch_xla/tf_saved_model_integration.py +++ b/torch_xla/tf_saved_model_integration.py @@ -81,7 +81,8 @@ def save_stablehlo_graph_as_tf( bundle = copy.deepcopy(stablehlo_program._bundle) tfm = tf.Module() bundle.state_dict = { - k: tf.Variable(v, trainable=False) for k, v in bundle.state_dict.items() + k: tf.Variable(v, trainable=False, name=k) + for k, v in bundle.state_dict.items() } bundle.additional_constants = [ tf.Variable(v, trainable=False) for v in bundle.additional_constants