From 8ffd7bd1eb3a19f8ce70670addb50d30db8feeea Mon Sep 17 00:00:00 2001 From: Haoliang Zhang Date: Mon, 16 Oct 2023 14:34:00 -0700 Subject: [PATCH] pass bundle object to make_tf_function (#5708) * Pass bundle object to make_tf_function, so that it will include tf.variables as args to the XlaCallModule op. * fix format --- torch_xla/tf_saved_model_integration.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/torch_xla/tf_saved_model_integration.py b/torch_xla/tf_saved_model_integration.py index b66c0f3aafa..1f65e4108bc 100644 --- a/torch_xla/tf_saved_model_integration.py +++ b/torch_xla/tf_saved_model_integration.py @@ -36,9 +36,12 @@ def inner(*args): return inner -def make_tf_function(stablehlo_program: stablehlo.StableHLOGraphModule): - return _wrap_as_tf_func(stablehlo_program._bundle.stablehlo_funcs[0], - stablehlo_program._bundle) +def make_tf_function(stablehlo_program: stablehlo.StableHLOGraphModule, + bundle=None): + if bundle is None: + return _wrap_as_tf_func(stablehlo_program._bundle.stablehlo_funcs[0], + stablehlo_program._bundle) + return _wrap_as_tf_func(stablehlo_program._bundle.stablehlo_funcs[0], bundle) def _make_input_signatures( @@ -86,7 +89,8 @@ def save_stablehlo_graph_as_tf( input_signatures = list( _make_input_signatures(bundle.stablehlo_funcs[0].meta)) tfm.f = tf.function( - make_tf_function(stablehlo_program), input_signature=input_signatures) + make_tf_function(stablehlo_program, bundle), + input_signature=input_signatures) tfm._variables = ( list(bundle.state_dict.values()) + bundle.additional_constants) signatures = {serving_key: tfm.f.get_concrete_function(*input_signatures)}