diff --git a/torch_xla/tf_saved_model_integration.py b/torch_xla/tf_saved_model_integration.py index b66c0f3aafa..1f65e4108bc 100644 --- a/torch_xla/tf_saved_model_integration.py +++ b/torch_xla/tf_saved_model_integration.py @@ -36,9 +36,12 @@ def inner(*args): return inner -def make_tf_function(stablehlo_program: stablehlo.StableHLOGraphModule): - return _wrap_as_tf_func(stablehlo_program._bundle.stablehlo_funcs[0], - stablehlo_program._bundle) +def make_tf_function(stablehlo_program: stablehlo.StableHLOGraphModule, + bundle=None): + if bundle is None: + return _wrap_as_tf_func(stablehlo_program._bundle.stablehlo_funcs[0], + stablehlo_program._bundle) + return _wrap_as_tf_func(stablehlo_program._bundle.stablehlo_funcs[0], bundle) def _make_input_signatures( @@ -86,7 +89,8 @@ def save_stablehlo_graph_as_tf( input_signatures = list( _make_input_signatures(bundle.stablehlo_funcs[0].meta)) tfm.f = tf.function( - make_tf_function(stablehlo_program), input_signature=input_signatures) + make_tf_function(stablehlo_program, bundle), + input_signature=input_signatures) tfm._variables = ( list(bundle.state_dict.values()) + bundle.additional_constants) signatures = {serving_key: tfm.f.get_concrete_function(*input_signatures)}