From efd5e58141b254d10d025e07eaf42d97e0f510b9 Mon Sep 17 00:00:00 2001
From: qihqi
Date: Thu, 14 Dec 2023 02:24:43 -0800
Subject: [PATCH] Support forwarding dynamic shape to tf saved model (#6149)

---
 test/stablehlo/test_saved_model.py      | 53 +++++++++++++++----------
 torch_xla/stablehlo.py                  | 22 +++++++---
 torch_xla/tf_saved_model_integration.py | 12 +++++-
 3 files changed, 59 insertions(+), 28 deletions(-)

diff --git a/test/stablehlo/test_saved_model.py b/test/stablehlo/test_saved_model.py
index 7c510ed3637b..caa0b47a7ebe 100644
--- a/test/stablehlo/test_saved_model.py
+++ b/test/stablehlo/test_saved_model.py
@@ -1,10 +1,13 @@
+import numpy as np
 import torch_xla
 import torch_xla.core.xla_model as xm
 from torch_xla.stablehlo import StableHLOExportOptions, exported_program_to_stablehlo
-from torch_xla.tf_saved_model_integration import make_tf_function, save_torch_module_as_tf_saved_model
+from torch_xla.tf_saved_model_integration import (
+    make_tf_function, save_torch_module_as_tf_saved_model,
+    save_stablehlo_graph_as_tf)
 from torch.utils import _pytree as pytree
+from torch.export import export, dynamic_dim
 import torch
-import torchvision
 import tempfile
 import unittest
 
@@ -13,31 +16,39 @@ class StableHLOInferenceTest(unittest.TestCase):
 
-  def test_resnet18_inference(self):
-    resnet18 = torchvision.models.resnet18().eval()
-    data = torch.randn(4, 3, 224, 224)
-    output = resnet18(data)
+  def test_dynamic_shapes(self):
 
-    exported = torch.export.export(resnet18, (data,))
-    options = StableHLOExportOptions(override_tracing_arguments=(data,))
-    stablehlo_program = exported_program_to_stablehlo(exported, options)
-    tf_func = make_tf_function(stablehlo_program)
+    class MyModule(torch.nn.Module):
 
-    output_tf = tf_func(*options.override_tracing_arguments)
-    output2 = torch.tensor(output_tf[0].numpy())
-    self.assertTrue(torch.allclose(output, output2, atol=1e-5))
+      def forward(self, a, b):
+        return a * b
 
-  def test_resnet18_save_load(self):
-    resnet18 = torchvision.models.resnet18().eval()
-    data = torch.randn(4, 3, 224, 224)
-    output = resnet18(data)
+    model = MyModule()
+    a = torch.randn(3, 10)
+    b = torch.randn(3, 10)
+    constraints = [
+        dynamic_dim(a, 0),
+        dynamic_dim(b, 0),
+        dynamic_dim(a, 0) == dynamic_dim(b, 0)
+    ]
+    exported = torch.export.export(
+        model, (
+            a,
+            b,
+        ), constraints=constraints)
+    shlo = exported_program_to_stablehlo(exported)
+    print(shlo.get_stablehlo_text())
 
     with tempfile.TemporaryDirectory() as tempdir:
-      save_torch_module_as_tf_saved_model(resnet18, (data,), tempdir)
+      save_stablehlo_graph_as_tf(
+          shlo,
+          tempdir,
+          serving_key=tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
+          function_alias='')
       loaded_m = tf.saved_model.load(tempdir)
-      res = loaded_m.f(data.detach().numpy())[0]
-      output2 = torch.tensor(res.numpy())
-      self.assertTrue(torch.allclose(output, output2, atol=1e-5))
+      data2 = (np.random.randn(2, 10), np.random.randn(2, 10))
+      res = loaded_m.f(*data2)
+      self.assertTrue(np.allclose(res, data2[0] * data2[1]))
 
   def test_unused_param(self):
 
diff --git a/torch_xla/stablehlo.py b/torch_xla/stablehlo.py
index 52cab416d234..7eb655e828e3 100644
--- a/torch_xla/stablehlo.py
+++ b/torch_xla/stablehlo.py
@@ -116,6 +116,7 @@ class VariableType(enum.Enum):
 class VariableSignature:  # either argument or parameters
   shape: List[int]
   dtype: str
+  dynamic_dims: List[int] = dataclasses.field(default_factory=list)
 
 
 @dataclass
@@ -215,6 +216,13 @@ class XLAExportInterpreter(torch.fx.Interpreter):
   def __init__(self, module, device):
     self._device = device
     super().__init__(module)
+    self.tensor_id_to_dynamic_dims = {}
+
+  def _mark_dynamic(self, tensor, dynamic_dims):
+    tid = torch_xla._XLAC._xla_get_tensor_id(tensor)
+    self.tensor_id_to_dynamic_dims[tid] = dynamic_dims
+    for i in dynamic_dims:
+      torch_xla._XLAC._xla_mark_dynamic(tensor, i)
 
   def call_function(self, target, args: Tuple, kwargs: Dict) -> Any:
     # NOTE(qihqi): We need to do this because there are some operators
@@ -232,9 +240,10 @@ def run_node(self, n) -> Any:
       fake_t = n.meta['val']
       res = super().run_node(n)
       if hasattr(fake_t, 'shape'):
-        for i, x in enumerate(fake_t.shape):
-          if not isinstance(x, int):
-            torch_xla._XLAC._xla_mark_dynamic(res, i)
+        dynamic_dims = [
+            i for i, x in enumerate(fake_t.shape) if not isinstance(x, int)
+        ]
+        self._mark_dynamic(res, dynamic_dims)
       return res
     return super().run_node(n)
 
@@ -297,8 +306,9 @@ def _exported_program_to_stablehlo_bundle(exported_model,
   device = xm.xla_device()
 
   # Run the fx graph tracing using lazy tensor
+  xla_interpreter = XLAExportInterpreter(exported_model.graph_module, device)
   with torch.no_grad():
-    res = XLAExportInterpreter(exported_model.graph_module, device).run(
+    res = xla_interpreter.run(
         *param_buffer_values,
         *input_args,
         *ordered_tensor_constants,
@@ -357,10 +367,12 @@ def _exported_program_to_stablehlo_bundle(exported_model,
       location = InputLocation.constant(position=len(additional_constants))
       additional_constants.append(tensor_value.detach().cpu().numpy())
     input_locations.append(location)
+    dynamic_dims = xla_interpreter.tensor_id_to_dynamic_dims.get(tensor_id, [])
     input_signatures.append(
         VariableSignature(
             shape=list(tensor_value.shape),
-            dtype=str(tensor_value.dtype).replace('torch.', '')))
+            dtype=str(tensor_value.dtype).replace('torch.', ''),
+            dynamic_dims=dynamic_dims))
 
   unused_inputs = []
   for i in unused_input_positions:
diff --git a/torch_xla/tf_saved_model_integration.py b/torch_xla/tf_saved_model_integration.py
index 511b9e02e9b2..4737c7cf9ff2 100644
--- a/torch_xla/tf_saved_model_integration.py
+++ b/torch_xla/tf_saved_model_integration.py
@@ -17,12 +17,19 @@
   raise
 
 
+def _get_shape_with_dynamic(signature: stablehlo.VariableSignature):
+  shape = copy.copy(signature.shape)
+  for i in signature.dynamic_dims:
+    shape[i] = None
+  return shape
+
+
 def _wrap_as_tf_func(func, bundle):
 
   def inner(*args):
     output_sig = func.meta.output_signature[0]
     Touts = [sig.dtype for sig in func.meta.output_signature]
-    Souts = [sig.shape for sig in func.meta.output_signature]
+    Souts = [_get_shape_with_dynamic(sig) for sig in func.meta.output_signature]
     call_args = stablehlo._extract_call_parameters(args, func.meta, bundle)
     return tfxla.call_module(
         tuple(call_args),
@@ -54,8 +61,9 @@ def _make_input_signatures(
   }
   for i in range(len(input_pos_to_spec)):
     spec = input_pos_to_spec[i]
+    shape = _get_shape_with_dynamic(spec)
     yield tf.TensorSpec(
-        shape=spec.shape, dtype=getattr(tf, spec.dtype), name=f'args_{i}')
+        shape=shape, dtype=getattr(tf, spec.dtype), name=f'args_{i}')
 
 
 def _mangle_tf_root_scope_name(name):
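
For reference, a minimal sketch (not part of the patch) of the shape-forwarding idea: a dimension recorded in VariableSignature.dynamic_dims is replaced by None before the shape is handed to tf.TensorSpec, so the saved model accepts any size along that axis. The FakeVariableSignature dataclass and get_shape_with_dynamic helper below are simplified, illustrative stand-ins; the real definitions are the ones added above in torch_xla/stablehlo.py and torch_xla/tf_saved_model_integration.py.

import copy
import dataclasses
from typing import List


@dataclasses.dataclass
class FakeVariableSignature:
  # Simplified stand-in for stablehlo.VariableSignature; illustrative only.
  shape: List[int]
  dtype: str
  dynamic_dims: List[int] = dataclasses.field(default_factory=list)


def get_shape_with_dynamic(sig):
  # Dynamic dims become None so a tf.TensorSpec built from this shape
  # accepts inputs of any size along those axes.
  shape = copy.copy(sig.shape)
  for i in sig.dynamic_dims:
    shape[i] = None
  return shape


sig = FakeVariableSignature(shape=[3, 10], dtype='float32', dynamic_dims=[0])
assert get_shape_with_dynamic(sig) == [None, 10]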