Skip to content

Commit

Permalink
When exporting StableHLO to SavedModel, also include the original var (
Browse files Browse the repository at this point in the history
…#5711)

* Pass bundle object to make_tf_function, so that it will include tf.variables as args to the XlaCallModule op.

* fix format

* When exporting StableHLO to SavedModel, also include the original variable tensor name from pytorch.
  • Loading branch information
haozha111 authored and golechwierowicz committed Jan 12, 2024
1 parent de69894 commit 1095724
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion torch_xla/tf_saved_model_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ def save_stablehlo_graph_as_tf(
bundle = copy.deepcopy(stablehlo_program._bundle)
tfm = tf.Module()
bundle.state_dict = {
k: tf.Variable(v, trainable=False) for k, v in bundle.state_dict.items()
k: tf.Variable(v, trainable=False, name=k)
for k, v in bundle.state_dict.items()
}
bundle.additional_constants = [
tf.Variable(v, trainable=False) for v in bundle.additional_constants
Expand Down

0 comments on commit 1095724

Please sign in to comment.