Skip to content

Commit

Permalink
Run decomp before processing (pytorch#5713)
Browse files Browse the repository at this point in the history
  • Loading branch information
qihqi authored and mbzomowski committed Nov 16, 2023
1 parent ce55d10 commit 1a8ba58
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 1 deletion.
32 changes: 32 additions & 0 deletions test/stablehlo/test_exports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import unittest
import torch
import torch.nn.functional as F
from torch_xla.stablehlo import exported_program_to_stablehlo


class Interpolate(torch.nn.Module):

def forward(self, masks: torch.Tensor) -> torch.Tensor:
masks = F.interpolate(
masks,
size=(500, 500),
mode="bilinear",
align_corners=False,
)
return masks


class ExportTest(unittest.TestCase):

def test_interpolate(self):

arg = (torch.randn(3, 3, 200, 200),)
model = Interpolate()

ans = model(*arg)

with torch.no_grad():
exported = torch._export.export(model, arg)
shlo = exported_program_to_stablehlo(exported)
ans2 = shlo(*arg).cpu().to(torch.float32)
self.assertTrue(torch.allclose(ans, ans2, atol=1e-5))
2 changes: 1 addition & 1 deletion torch_xla/stablehlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def _exported_program_to_stablehlo_bundle(exported_model,
options) -> StableHLOModelBundle:
if options is None:
options = StableHLOExportOptions()

exported_model = exported_model.run_decompositions()
input_args = _extract_input_args(exported_model, options)

device = xm.xla_device()
Expand Down

0 comments on commit 1a8ba58

Please sign in to comment.