From 176e51f4540d43bca1c6d96f4d8d9cb4f2452337 Mon Sep 17 00:00:00 2001 From: Han Qi Date: Thu, 19 Oct 2023 20:22:38 +0000 Subject: [PATCH] Run decomp before processing --- test/stablehlo/test_exports.py | 32 ++++++++++++++++++++++++++++++++ torch_xla/stablehlo.py | 2 +- 2 files changed, 33 insertions(+), 1 deletion(-) create mode 100644 test/stablehlo/test_exports.py diff --git a/test/stablehlo/test_exports.py b/test/stablehlo/test_exports.py new file mode 100644 index 00000000000..ba99319c6b2 --- /dev/null +++ b/test/stablehlo/test_exports.py @@ -0,0 +1,32 @@ +import unittest +import torch +import torch.nn.functional as F +from torch_xla.stablehlo import exported_program_to_stablehlo + + +class Interpolate(torch.nn.Module): + + def forward(self, masks: torch.Tensor) -> torch.Tensor: + masks = F.interpolate( + masks, + size=(500, 500), + mode="bilinear", + align_corners=False, + ) + return masks + + +class ExportTest(unittest.TestCase): + + def test_interpolate(self): + + arg = (torch.randn(3, 3, 200, 200),) + model = Interpolate() + + ans = model(*arg) + + with torch.no_grad(): + exported = torch._export.export(model, arg) + shlo = exported_program_to_stablehlo(exported) + ans2 = shlo(*arg).cpu().to(torch.float32) + self.assertTrue(torch.allclose(ans, ans2, atol=1e-5)) diff --git a/torch_xla/stablehlo.py b/torch_xla/stablehlo.py index 557f1271611..e6916e08fb8 100644 --- a/torch_xla/stablehlo.py +++ b/torch_xla/stablehlo.py @@ -240,7 +240,7 @@ def _exported_program_to_stablehlo_bundle(exported_model, options) -> StableHLOModelBundle: if options is None: options = StableHLOExportOptions() - + exported_model = exported_model.run_decompositions() input_args = _extract_input_args(exported_model, options) device = xm.xla_device()