Skip to content

Commit

Permalink
Add support for unused params (#5694)
Browse files Browse the repository at this point in the history
  • Loading branch information
qihqi authored and bhavya01 committed Apr 22, 2024
1 parent 155e5b2 commit 0502f06
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 2 deletions.
18 changes: 18 additions & 0 deletions test/stablehlo/test_saved_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,24 @@ def test_resnet18_save_load(self):
output2 = torch.tensor(res.numpy())
self.assertTrue(torch.allclose(output, output2, atol=1e-5))

def test_unused_param(self):

class M(torch.nn.Module):

def forward(self, a, b):
return torch.sin(b)

model = M()
data = (torch.randn(4, 3, 224, 224), torch.randn(1, 100))
output = model(*data)

with tempfile.TemporaryDirectory() as tempdir:
save_torch_module_as_tf_saved_model(model, data, tempdir)
loaded_m = tf.saved_model.load(tempdir)
res = loaded_m.f(data[0].detach().numpy(), data[1].detach().numpy())[0]
output2 = torch.tensor(res.numpy())
self.assertTrue(torch.allclose(output, output2, atol=1e-5))


if __name__ == '__main__':
test = unittest.main()
Expand Down
26 changes: 25 additions & 1 deletion torch_xla/stablehlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,8 @@ class StableHLOFunctionMeta:
# the arguments the user supplied, OR a parameter, OR a constant
input_locations: List[InputLocation]

unused_inputs: List[Tuple[InputLocation, VariableSignature]]

# input_pytree_spec
input_pytree_spec: Optional[str] = None
output_pytree_spec: Optional[str] = None
Expand Down Expand Up @@ -299,10 +301,16 @@ def _exported_program_to_stablehlo_bundle(exported_model,
if isinstance(tensor, torch.Tensor)
}

# there might be inputs that is part of input but not consumed by HLO graph
unused_input_positions = set(range(len(input_args)))

for hlo_input_pos, (tensor_id, tensor_value) in enumerate(
zip(graph_input_tensor_ids, graph_input_xla_values)):
if tensor_id in input_ids: # this is input
location = InputLocation.input_arg(position=input_ids[tensor_id])
pos_id = input_ids[tensor_id]
location = InputLocation.input_arg(position=pos_id)
if pos_id in unused_input_positions:
unused_input_positions.remove(pos_id)
elif tensor_id in tensor_id_to_state_name:
location = InputLocation.parameter(
name=tensor_id_to_state_name[tensor_id])
Expand All @@ -315,6 +323,21 @@ def _exported_program_to_stablehlo_bundle(exported_model,
shape=list(tensor_value.shape),
dtype=str(tensor_value.dtype).replace('torch.', '')))

unused_inputs = []
for i in unused_input_positions:
pos = InputLocation.input_arg(position=i)
arg = input_args[i]
if isinstance(arg, torch.Tensor):
signature = VariableSignature(
shape=list(arg.shape), dtype=str(arg.dtype).replace('torch.', ''))
else:
signature = VariableSignature(
shape=[],
dtype=str(type(arg)),
)

unused_inputs.append((pos, signature))

output_signature = [
VariableSignature(
shape=list(tensor.shape),
Expand All @@ -330,6 +353,7 @@ def _exported_program_to_stablehlo_bundle(exported_model,
input_signature=input_signatures,
output_signature=output_signature,
input_locations=input_locations,
unused_inputs=unused_inputs,
input_pytree_spec=pytree.treespec_dumps(exported_model.call_spec.in_spec),
output_pytree_spec=pytree.treespec_dumps(
exported_model.call_spec.out_spec),
Expand Down
4 changes: 3 additions & 1 deletion torch_xla/tf_saved_model_integration.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import itertools
import sys
import os
from typing import List, Tuple, Any
Expand Down Expand Up @@ -44,7 +45,8 @@ def _make_input_signatures(
meta: stablehlo.StableHLOFunctionMeta) -> List[tf.TensorSpec]:
input_pos_to_spec = {
loc.position: spec
for loc, spec in zip(meta.input_locations, meta.input_signature)
for loc, spec in itertools.chain(
zip(meta.input_locations, meta.input_signature), meta.unused_inputs)
if loc.type_ == stablehlo.VariableType.INPUT_ARG
}
for i in range(len(input_pos_to_spec)):
Expand Down

0 comments on commit 0502f06

Please sign in to comment.