Support forwarding dynamic shape to tf saved model (#6149)
qihqi authored and golechwierowicz committed Jan 12, 2024
1 parent 533271f commit efd5e58
Showing 3 changed files with 59 additions and 28 deletions.
test/stablehlo/test_saved_model.py: 32 additions & 21 deletions
@@ -1,10 +1,13 @@
+import numpy as np
 import torch_xla
 import torch_xla.core.xla_model as xm
 from torch_xla.stablehlo import StableHLOExportOptions, exported_program_to_stablehlo
-from torch_xla.tf_saved_model_integration import make_tf_function, save_torch_module_as_tf_saved_model
+from torch_xla.tf_saved_model_integration import (
+    make_tf_function, save_torch_module_as_tf_saved_model,
+    save_stablehlo_graph_as_tf)
 from torch.utils import _pytree as pytree
+from torch.export import export, dynamic_dim
 import torch
-import torchvision
 
 import tempfile
 import unittest
@@ -13,31 +16,39 @@
 
 class StableHLOInferenceTest(unittest.TestCase):
 
-  def test_resnet18_inference(self):
-    resnet18 = torchvision.models.resnet18().eval()
-    data = torch.randn(4, 3, 224, 224)
-    output = resnet18(data)
+  def test_dynamic_shapes(self):
 
-    exported = torch.export.export(resnet18, (data,))
-    options = StableHLOExportOptions(override_tracing_arguments=(data,))
-    stablehlo_program = exported_program_to_stablehlo(exported, options)
-    tf_func = make_tf_function(stablehlo_program)
+    class MyModule(torch.nn.Module):
 
-    output_tf = tf_func(*options.override_tracing_arguments)
-    output2 = torch.tensor(output_tf[0].numpy())
-    self.assertTrue(torch.allclose(output, output2, atol=1e-5))
+      def forward(self, a, b):
+        return a * b
 
-  def test_resnet18_save_load(self):
-    resnet18 = torchvision.models.resnet18().eval()
-    data = torch.randn(4, 3, 224, 224)
-    output = resnet18(data)
+    model = MyModule()
+    a = torch.randn(3, 10)
+    b = torch.randn(3, 10)
+    constraints = [
+        dynamic_dim(a, 0),
+        dynamic_dim(b, 0),
+        dynamic_dim(a, 0) == dynamic_dim(b, 0)
+    ]
 
+    exported = torch.export.export(
+        model, (
+            a,
+            b,
+        ), constraints=constraints)
+    shlo = exported_program_to_stablehlo(exported)
+    print(shlo.get_stablehlo_text())
     with tempfile.TemporaryDirectory() as tempdir:
-      save_torch_module_as_tf_saved_model(resnet18, (data,), tempdir)
+      save_stablehlo_graph_as_tf(
+          shlo,
+          tempdir,
+          serving_key=tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
+          function_alias='')
       loaded_m = tf.saved_model.load(tempdir)
-      res = loaded_m.f(data.detach().numpy())[0]
-      output2 = torch.tensor(res.numpy())
-      self.assertTrue(torch.allclose(output, output2, atol=1e-5))
+      data2 = (np.random.randn(2, 10), np.random.randn(2, 10))
+      res = loaded_m.f(*data2)
+      self.assertTrue(np.allclose(res, data2[0] * data2[1]))
 
   def test_unused_param(self):
 
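Note on the export API: the `constraints=[dynamic_dim(...)]` form used in `test_dynamic_shapes` is the torch.export interface current as of this commit (January 2024); later PyTorch releases deprecate it in favor of `torch.export.Dim` with a `dynamic_shapes=` argument. A minimal sketch of the equivalent export under that newer API (an illustration only, not part of this commit):

import torch
from torch.export import Dim, export


class MyModule(torch.nn.Module):

  def forward(self, a, b):
    return a * b


# A single shared Dim on dimension 0 of both inputs encodes the same
# equality constraint as dynamic_dim(a, 0) == dynamic_dim(b, 0) above.
batch = Dim('batch')
exported = export(
    MyModule(), (torch.randn(3, 10), torch.randn(3, 10)),
    dynamic_shapes=({0: batch}, {0: batch}))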
torch_xla/stablehlo.py: 17 additions & 5 deletions
@@ -116,6 +116,7 @@ class VariableType(enum.Enum):
 class VariableSignature:  # either argument or parameters
   shape: List[int]
   dtype: str
+  dynamic_dims: List[int] = dataclasses.field(default_factory=list)
 
 
 @dataclass
@@ -215,6 +216,13 @@ class XLAExportInterpreter(torch.fx.Interpreter):
   def __init__(self, module, device):
     self._device = device
     super().__init__(module)
+    self.tensor_id_to_dynamic_dims = {}
+
+  def _mark_dynamic(self, tensor, dynamic_dims):
+    tid = torch_xla._XLAC._xla_get_tensor_id(tensor)
+    self.tensor_id_to_dynamic_dims[tid] = dynamic_dims
+    for i in dynamic_dims:
+      torch_xla._XLAC._xla_mark_dynamic(tensor, i)
 
   def call_function(self, target, args: Tuple, kwargs: Dict) -> Any:
     # NOTE(qihqi): We need to do this because there are some operators
@@ -232,9 +240,10 @@ def run_node(self, n) -> Any:
       fake_t = n.meta['val']
       res = super().run_node(n)
       if hasattr(fake_t, 'shape'):
-        for i, x in enumerate(fake_t.shape):
-          if not isinstance(x, int):
-            torch_xla._XLAC._xla_mark_dynamic(res, i)
+        dynamic_dims = [
+            i for i, x in enumerate(fake_t.shape) if not isinstance(x, int)
+        ]
+        self._mark_dynamic(res, dynamic_dims)
       return res
     return super().run_node(n)
 
@@ -297,8 +306,9 @@ def _exported_program_to_stablehlo_bundle(exported_model,
   device = xm.xla_device()
 
   # Run the fx graph tracing using lazy tensor
+  xla_interpreter = XLAExportInterpreter(exported_model.graph_module, device)
   with torch.no_grad():
-    res = XLAExportInterpreter(exported_model.graph_module, device).run(
+    res = xla_interpreter.run(
         *param_buffer_values,
         *input_args,
         *ordered_tensor_constants,
@@ -357,10 +367,12 @@
       location = InputLocation.constant(position=len(additional_constants))
       additional_constants.append(tensor_value.detach().cpu().numpy())
     input_locations.append(location)
+    dynamic_dims = xla_interpreter.tensor_id_to_dynamic_dims.get(tensor_id, [])
     input_signatures.append(
        VariableSignature(
            shape=list(tensor_value.shape),
-            dtype=str(tensor_value.dtype).replace('torch.', '')))
+            dtype=str(tensor_value.dtype).replace('torch.', ''),
+            dynamic_dims=dynamic_dims))
 
   unused_inputs = []
   for i in unused_input_positions:
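The net effect in this file: `XLAExportInterpreter` now remembers, per lazy-tensor id, which dimensions it marked dynamic during tracing, and `_exported_program_to_stablehlo_bundle` copies that list into each input's `VariableSignature`. A small illustration of the resulting record (constructing the dataclass by hand here is purely for demonstration):

from torch_xla.stablehlo import VariableSignature

# An input traced with shape (3, 10) and a dynamic batch dimension
# is described as:
sig = VariableSignature(shape=[3, 10], dtype='float32', dynamic_dims=[0])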
torch_xla/tf_saved_model_integration.py: 10 additions & 2 deletions
@@ -17,12 +17,19 @@
   raise
 
 
+def _get_shape_with_dynamic(signature: stablehlo.VariableSignature):
+  shape = copy.copy(signature.shape)
+  for i in signature.dynamic_dims:
+    shape[i] = None
+  return shape
+
+
 def _wrap_as_tf_func(func, bundle):
 
   def inner(*args):
     output_sig = func.meta.output_signature[0]
     Touts = [sig.dtype for sig in func.meta.output_signature]
-    Souts = [sig.shape for sig in func.meta.output_signature]
+    Souts = [_get_shape_with_dynamic(sig) for sig in func.meta.output_signature]
     call_args = stablehlo._extract_call_parameters(args, func.meta, bundle)
     return tfxla.call_module(
         tuple(call_args),
@@ -54,8 +61,9 @@ def _make_input_signatures(
   }
   for i in range(len(input_pos_to_spec)):
     spec = input_pos_to_spec[i]
+    shape = _get_shape_with_dynamic(spec)
     yield tf.TensorSpec(
-        shape=spec.shape, dtype=getattr(tf, spec.dtype), name=f'args_{i}')
+        shape=shape, dtype=getattr(tf, spec.dtype), name=f'args_{i}')
 
 
 def _mangle_tf_root_scope_name(name):
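`_get_shape_with_dynamic` replaces each dynamic dimension with `None`, the TensorFlow convention for an unknown dimension, so the generated `tf.TensorSpec` accepts any size there. A quick sketch of the expected behavior (reusing the `VariableSignature` example above; `_get_shape_with_dynamic` is a private helper, imported here only for illustration):

from torch_xla.stablehlo import VariableSignature
from torch_xla.tf_saved_model_integration import _get_shape_with_dynamic

sig = VariableSignature(shape=[3, 10], dtype='float32', dynamic_dims=[0])
assert _get_shape_with_dynamic(sig) == [None, 10]
# tf.TensorSpec(shape=[None, 10], ...) then admits inputs of any batch
# size, e.g. the (2, 10) arrays fed in test_dynamic_shapes.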
