Skip to content

Commit

Permalink
pass bundle object to make_tf_function (pytorch#5708)
Browse files Browse the repository at this point in the history
* Pass bundle object to make_tf_function, so that it will include tf.variables as args to the XlaCallModule op.

* fix format
  • Loading branch information
haozha111 authored and mbzomowski committed Nov 16, 2023
1 parent 919b348 commit 8ffd7bd
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions torch_xla/tf_saved_model_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,12 @@ def inner(*args):
return inner


def make_tf_function(stablehlo_program: stablehlo.StableHLOGraphModule):
return _wrap_as_tf_func(stablehlo_program._bundle.stablehlo_funcs[0],
stablehlo_program._bundle)
def make_tf_function(stablehlo_program: stablehlo.StableHLOGraphModule,
bundle=None):
if bundle is None:
return _wrap_as_tf_func(stablehlo_program._bundle.stablehlo_funcs[0],
stablehlo_program._bundle)
return _wrap_as_tf_func(stablehlo_program._bundle.stablehlo_funcs[0], bundle)


def _make_input_signatures(
Expand Down Expand Up @@ -86,7 +89,8 @@ def save_stablehlo_graph_as_tf(
input_signatures = list(
_make_input_signatures(bundle.stablehlo_funcs[0].meta))
tfm.f = tf.function(
make_tf_function(stablehlo_program), input_signature=input_signatures)
make_tf_function(stablehlo_program, bundle),
input_signature=input_signatures)
tfm._variables = (
list(bundle.state_dict.values()) + bundle.additional_constants)
signatures = {serving_key: tfm.f.get_concrete_function(*input_signatures)}
Expand Down

0 comments on commit 8ffd7bd

Please sign in to comment.