From 2df15aff3aeb51bde0fd07c5a7d36ea45ba7ab45 Mon Sep 17 00:00:00 2001 From: Haoliang Zhang Date: Mon, 16 Oct 2023 12:04:17 -0700 Subject: [PATCH 1/2] Pass bundle object to make_tf_function, so that it will include tf.variables as args to the XlaCallModule op. --- torch_xla/tf_saved_model_integration.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/torch_xla/tf_saved_model_integration.py b/torch_xla/tf_saved_model_integration.py index b66c0f3aafa..e642aa9e664 100644 --- a/torch_xla/tf_saved_model_integration.py +++ b/torch_xla/tf_saved_model_integration.py @@ -36,9 +36,11 @@ def inner(*args): return inner -def make_tf_function(stablehlo_program: stablehlo.StableHLOGraphModule): - return _wrap_as_tf_func(stablehlo_program._bundle.stablehlo_funcs[0], - stablehlo_program._bundle) +def make_tf_function(stablehlo_program: stablehlo.StableHLOGraphModule, bundle=None): + if bundle is None: + return _wrap_as_tf_func(stablehlo_program._bundle.stablehlo_funcs[0], + stablehlo_program._bundle) + return _wrap_as_tf_func(stablehlo_program._bundle.stablehlo_funcs[0], bundle) def _make_input_signatures( @@ -86,7 +88,7 @@ def save_stablehlo_graph_as_tf( input_signatures = list( _make_input_signatures(bundle.stablehlo_funcs[0].meta)) tfm.f = tf.function( - make_tf_function(stablehlo_program), input_signature=input_signatures) + make_tf_function(stablehlo_program, bundle), input_signature=input_signatures) tfm._variables = ( list(bundle.state_dict.values()) + bundle.additional_constants) signatures = {serving_key: tfm.f.get_concrete_function(*input_signatures)} From 2053b11e906148c44173e863035ce5f9c35b823f Mon Sep 17 00:00:00 2001 From: Haoliang Zhang Date: Mon, 16 Oct 2023 14:29:24 -0700 Subject: [PATCH 2/2] fix format --- torch_xla/tf_saved_model_integration.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torch_xla/tf_saved_model_integration.py b/torch_xla/tf_saved_model_integration.py index e642aa9e664..1f65e4108bc 100644 --- a/torch_xla/tf_saved_model_integration.py +++ b/torch_xla/tf_saved_model_integration.py @@ -36,7 +36,8 @@ def inner(*args): return inner -def make_tf_function(stablehlo_program: stablehlo.StableHLOGraphModule, bundle=None): +def make_tf_function(stablehlo_program: stablehlo.StableHLOGraphModule, + bundle=None): if bundle is None: return _wrap_as_tf_func(stablehlo_program._bundle.stablehlo_funcs[0], stablehlo_program._bundle) @@ -88,7 +89,8 @@ def save_stablehlo_graph_as_tf( input_signatures = list( _make_input_signatures(bundle.stablehlo_funcs[0].meta)) tfm.f = tf.function( - make_tf_function(stablehlo_program, bundle), input_signature=input_signatures) + make_tf_function(stablehlo_program, bundle), + input_signature=input_signatures) tfm._variables = ( list(bundle.state_dict.values()) + bundle.additional_constants) signatures = {serving_key: tfm.f.get_concrete_function(*input_signatures)}