Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support export to tf saved_model for those models with unused params #5694

Merged
merged 1 commit into from
Oct 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions test/stablehlo/test_saved_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,24 @@ def test_resnet18_save_load(self):
output2 = torch.tensor(res.numpy())
self.assertTrue(torch.allclose(output, output2, atol=1e-5))

def test_unused_param(self):

class M(torch.nn.Module):

def forward(self, a, b):
return torch.sin(b)

model = M()
data = (torch.randn(4, 3, 224, 224), torch.randn(1, 100))
output = model(*data)

with tempfile.TemporaryDirectory() as tempdir:
save_torch_module_as_tf_saved_model(model, data, tempdir)
loaded_m = tf.saved_model.load(tempdir)
res = loaded_m.f(data[0].detach().numpy(), data[1].detach().numpy())[0]
output2 = torch.tensor(res.numpy())
self.assertTrue(torch.allclose(output, output2, atol=1e-5))


if __name__ == '__main__':
test = unittest.main()
Expand Down
26 changes: 25 additions & 1 deletion torch_xla/stablehlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,8 @@ class StableHLOFunctionMeta:
# the arguments the user supplied, OR a parameter, OR a constant
input_locations: List[InputLocation]

unused_inputs: List[Tuple[InputLocation, VariableSignature]]

# input_pytree_spec
input_pytree_spec: Optional[str] = None
output_pytree_spec: Optional[str] = None
Expand Down Expand Up @@ -299,10 +301,16 @@ def _exported_program_to_stablehlo_bundle(exported_model,
if isinstance(tensor, torch.Tensor)
}

# there might be inputs that is part of input but not consumed by HLO graph
unused_input_positions = set(range(len(input_args)))

for hlo_input_pos, (tensor_id, tensor_value) in enumerate(
zip(graph_input_tensor_ids, graph_input_xla_values)):
if tensor_id in input_ids: # this is input
location = InputLocation.input_arg(position=input_ids[tensor_id])
pos_id = input_ids[tensor_id]
location = InputLocation.input_arg(position=pos_id)
if pos_id in unused_input_positions:
unused_input_positions.remove(pos_id)
elif tensor_id in tensor_id_to_state_name:
location = InputLocation.parameter(
name=tensor_id_to_state_name[tensor_id])
Expand All @@ -315,6 +323,21 @@ def _exported_program_to_stablehlo_bundle(exported_model,
shape=list(tensor_value.shape),
dtype=str(tensor_value.dtype).replace('torch.', '')))

unused_inputs = []
for i in unused_input_positions:
pos = InputLocation.input_arg(position=i)
arg = input_args[i]
if isinstance(arg, torch.Tensor):
signature = VariableSignature(
shape=list(arg.shape), dtype=str(arg.dtype).replace('torch.', ''))
else:
signature = VariableSignature(
shape=[],
dtype=str(type(arg)),
)

unused_inputs.append((pos, signature))

output_signature = [
VariableSignature(
shape=list(tensor.shape),
Expand All @@ -330,6 +353,7 @@ def _exported_program_to_stablehlo_bundle(exported_model,
input_signature=input_signatures,
output_signature=output_signature,
input_locations=input_locations,
unused_inputs=unused_inputs,
input_pytree_spec=pytree.treespec_dumps(exported_model.call_spec.in_spec),
output_pytree_spec=pytree.treespec_dumps(
exported_model.call_spec.out_spec),
Expand Down
4 changes: 3 additions & 1 deletion torch_xla/tf_saved_model_integration.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import itertools
import sys
import os
from typing import List, Tuple, Any
Expand Down Expand Up @@ -44,7 +45,8 @@ def _make_input_signatures(
meta: stablehlo.StableHLOFunctionMeta) -> List[tf.TensorSpec]:
input_pos_to_spec = {
loc.position: spec
for loc, spec in zip(meta.input_locations, meta.input_signature)
for loc, spec in itertools.chain(
zip(meta.input_locations, meta.input_signature), meta.unused_inputs)
if loc.type_ == stablehlo.VariableType.INPUT_ARG
}
for i in range(len(input_pos_to_spec)):
Expand Down