From ece0f819bf34f048031bb78e4c9cfac2bdebd4ca Mon Sep 17 00:00:00 2001
From: Han Qi
Date: Tue, 10 Oct 2023 04:28:53 +0000
Subject: [PATCH] Add support for unused params

---
 test/stablehlo/test_saved_model.py      | 18 ++++++++++++++++++
 torch_xla/stablehlo.py                  | 26 +++++++++++++++++++++++++-
 torch_xla/tf_saved_model_integration.py |  4 +++-
 3 files changed, 46 insertions(+), 2 deletions(-)

diff --git a/test/stablehlo/test_saved_model.py b/test/stablehlo/test_saved_model.py
index 75c10c299b6..7c510ed3637 100644
--- a/test/stablehlo/test_saved_model.py
+++ b/test/stablehlo/test_saved_model.py
@@ -39,6 +39,24 @@ def test_resnet18_save_load(self):
       output2 = torch.tensor(res.numpy())
       self.assertTrue(torch.allclose(output, output2, atol=1e-5))
 
+  def test_unused_param(self):
+
+    class M(torch.nn.Module):
+
+      def forward(self, a, b):
+        return torch.sin(b)
+
+    model = M()
+    data = (torch.randn(4, 3, 224, 224), torch.randn(1, 100))
+    output = model(*data)
+
+    with tempfile.TemporaryDirectory() as tempdir:
+      save_torch_module_as_tf_saved_model(model, data, tempdir)
+      loaded_m = tf.saved_model.load(tempdir)
+      res = loaded_m.f(data[0].detach().numpy(), data[1].detach().numpy())[0]
+      output2 = torch.tensor(res.numpy())
+      self.assertTrue(torch.allclose(output, output2, atol=1e-5))
+
 
 if __name__ == '__main__':
   test = unittest.main()
diff --git a/torch_xla/stablehlo.py b/torch_xla/stablehlo.py
index 35b6b6cdf51..557f1271611 100644
--- a/torch_xla/stablehlo.py
+++ b/torch_xla/stablehlo.py
@@ -135,6 +135,8 @@ class StableHLOFunctionMeta:
   # the arguments the user supplied, OR a parameter, OR a constant
   input_locations: List[InputLocation]
 
+  unused_inputs: List[Tuple[InputLocation, VariableSignature]]
+
   # input_pytree_spec
   input_pytree_spec: Optional[str] = None
   output_pytree_spec: Optional[str] = None
@@ -299,10 +301,16 @@ def _exported_program_to_stablehlo_bundle(exported_model,
       if isinstance(tensor, torch.Tensor)
   }
 
+  # some inputs may be part of the input args but not consumed by the HLO graph
+  unused_input_positions = set(range(len(input_args)))
+
   for hlo_input_pos, (tensor_id, tensor_value) in enumerate(
       zip(graph_input_tensor_ids, graph_input_xla_values)):
     if tensor_id in input_ids:  # this is input
-      location = InputLocation.input_arg(position=input_ids[tensor_id])
+      pos_id = input_ids[tensor_id]
+      location = InputLocation.input_arg(position=pos_id)
+      if pos_id in unused_input_positions:
+        unused_input_positions.remove(pos_id)
     elif tensor_id in tensor_id_to_state_name:
       location = InputLocation.parameter(
           name=tensor_id_to_state_name[tensor_id])
@@ -315,6 +323,21 @@ def _exported_program_to_stablehlo_bundle(exported_model,
             shape=list(tensor_value.shape),
             dtype=str(tensor_value.dtype).replace('torch.', '')))
 
+  unused_inputs = []
+  for i in unused_input_positions:
+    pos = InputLocation.input_arg(position=i)
+    arg = input_args[i]
+    if isinstance(arg, torch.Tensor):
+      signature = VariableSignature(
+          shape=list(arg.shape), dtype=str(arg.dtype).replace('torch.', ''))
+    else:
+      signature = VariableSignature(
+          shape=[],
+          dtype=str(type(arg)),
+      )
+
+    unused_inputs.append((pos, signature))
+
   output_signature = [
       VariableSignature(
           shape=list(tensor.shape),
@@ -330,6 +353,7 @@ def _exported_program_to_stablehlo_bundle(exported_model,
       input_signature=input_signatures,
       output_signature=output_signature,
       input_locations=input_locations,
+      unused_inputs=unused_inputs,
       input_pytree_spec=pytree.treespec_dumps(exported_model.call_spec.in_spec),
       output_pytree_spec=pytree.treespec_dumps(
           exported_model.call_spec.out_spec),
diff --git a/torch_xla/tf_saved_model_integration.py b/torch_xla/tf_saved_model_integration.py
index 2c40a21b7b5..b66c0f3aafa 100644
--- a/torch_xla/tf_saved_model_integration.py
+++ b/torch_xla/tf_saved_model_integration.py
@@ -1,3 +1,4 @@
+import itertools
 import sys
 import os
 from typing import List, Tuple, Any
@@ -44,7 +45,8 @@ def _make_input_signatures(
     meta: stablehlo.StableHLOFunctionMeta) -> List[tf.TensorSpec]:
   input_pos_to_spec = {
       loc.position: spec
-      for loc, spec in zip(meta.input_locations, meta.input_signature)
+      for loc, spec in itertools.chain(
+          zip(meta.input_locations, meta.input_signature), meta.unused_inputs)
       if loc.type_ == stablehlo.VariableType.INPUT_ARG
   }
   for i in range(len(input_pos_to_spec)):
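
For reference, a minimal standalone sketch of the behavior this patch enables, mirroring
the new test_unused_param test. The module, shapes, and variable names below are
illustrative only and not part of the patch; the exported SavedModel still accepts the
unused argument because its position and signature are recorded in `unused_inputs` and
merged back in `_make_input_signatures`.

# Illustrative sketch, assuming torch, torch_xla, and tensorflow are installed.
import tempfile

import tensorflow as tf
import torch

from torch_xla.tf_saved_model_integration import save_torch_module_as_tf_saved_model


class SinOnly(torch.nn.Module):

  def forward(self, a, b):
    # `a` never enters the computation, so it does not appear in the HLO graph.
    return torch.sin(b)


model = SinOnly()
data = (torch.randn(4, 3, 224, 224), torch.randn(1, 100))

with tempfile.TemporaryDirectory() as tempdir:
  save_torch_module_as_tf_saved_model(model, data, tempdir)
  loaded = tf.saved_model.load(tempdir)
  # Both arguments remain part of the serving signature; the unused one is
  # simply ignored by the compiled computation.
  out = loaded.f(data[0].numpy(), data[1].numpy())[0]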