diff --git a/test/run_tests.sh b/test/run_tests.sh index a4da58990ec3..35e0b003dc8f 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -113,11 +113,6 @@ function run_pt_xla_debug { PT_XLA_DEBUG=1 PT_XLA_DEBUG_FILE="/tmp/pt_xla_debug.txt" run_test "$@" } -function run_stablehlo_compile { - echo "Running in StableHlo Compile mode: $@" - XLA_STABLEHLO_COMPILE=1 run_test "$@" -} - function run_xla_backend_mp { echo "Running XLA backend multiprocessing test: $@" MASTER_ADDR=localhost MASTER_PORT=6000 run_test "$@" @@ -201,8 +196,9 @@ function run_xla_op_tests3 { # TODO(qihqi): this test require tensorflow to run. need to setup separate # CI with tf. run_xla_hlo_debug "$CDIR/stablehlo/test_stablehlo_inference.py" - run_stablehlo_compile "$CDIR/stablehlo/test_stablehlo_compile.py" - run_stablehlo_compile "$CDIR/stablehlo/test_implicit_broadcasting.py" + run_test "$CDIR/stablehlo/test_stablehlo_compile.py" + run_test "$CDIR/stablehlo/test_implicit_broadcasting.py" + run_test "$CDIR/stablehlo/test_unbounded_dynamism.py" run_test "$CDIR/spmd/test_xla_sharding.py" run_test "$CDIR/spmd/test_xla_sharding_hlo.py" run_test "$CDIR/spmd/test_xla_virtual_device.py" diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py index 04e079089f2b..6d1c7d689413 100644 --- a/test/spmd/test_xla_sharding.py +++ b/test/spmd/test_xla_sharding.py @@ -819,13 +819,13 @@ def test_mark_sharding_ir(self): (0, 1)) hlo = torch_xla._XLAC._get_xla_tensors_hlo([actual.global_tensor]) self.assertIn( - '%custom-call.10 = f32[1,128]{1,0} custom-call(f32[1,128]{1,0} %add.9), custom_call_target="Sharding", sharding=', + '%custom-call.7 = f32[1,128]{1,0} custom-call(f32[1,128]{1,0} %add.6), custom_call_target="Sharding", sharding=', hlo) actual += 0 hlo = torch_xla._XLAC._get_xla_tensors_hlo([actual.global_tensor]) self.assertIn( - '%add.15 = f32[1,128]{1,0} add(f32[1,128]{1,0} %custom-call.13, f32[1,128]{1,0} %broadcast.14)', + '%add.12 = f32[1,128]{1,0} add(f32[1,128]{1,0} %custom-call.10, f32[1,128]{1,0} %broadcast.11)', hlo) self.assertTrue(torch.allclose(expected, actual.cpu())) diff --git a/test/stablehlo/test_pt2e_qdq.py b/test/stablehlo/test_pt2e_qdq.py index c901f3684b1c..a7d2cfbbe533 100644 --- a/test/stablehlo/test_pt2e_qdq.py +++ b/test/stablehlo/test_pt2e_qdq.py @@ -82,6 +82,7 @@ def test_per_channel_qdq(self): self.assertEqual(stablehlo_txt.count("stablehlo.uniform_quantize"), 1) self.assertEqual(stablehlo_txt.count("stablehlo.uniform_dequantize"), 1) + @unittest.skip("Failed because PT2E BC break change on constant folding.") def test_resnet18(self): # Step 1: export resnet18 args = (torch.randn(1, 3, 224, 224),) diff --git a/test/stablehlo/test_stablehlo_compile.py b/test/stablehlo/test_stablehlo_compile.py index d406f733c39b..e92892f520fc 100644 --- a/test/stablehlo/test_stablehlo_compile.py +++ b/test/stablehlo/test_stablehlo_compile.py @@ -1,11 +1,15 @@ +import os +import unittest + +import numpy as np +import torch import torch_xla import torch_xla.core.xla_model as xm -import torch -import torchvision -import unittest import torch_xla.debug.metrics as met import torch_xla.debug.metrics_compare_utils as mcu -import numpy as np +import torchvision + +os.environ['XLA_STABLEHLO_COMPILE'] = '1' class StableHloCompileTest(unittest.TestCase): diff --git a/test/stablehlo/test_unbounded_dynamism.py b/test/stablehlo/test_unbounded_dynamism.py index a4223b799aa5..55992cb383cf 100644 --- a/test/stablehlo/test_unbounded_dynamism.py +++ b/test/stablehlo/test_unbounded_dynamism.py @@ -1,3 +1,4 @@ +import re import sys import unittest @@ -14,45 +15,233 @@ class UnboundedDynamismExportTest(unittest.TestCase): - def test_simply_add(self): - a = torch.tensor([[1, 2], [2, 4]], device=device) - torch_xla._XLAC._xla_mark_dynamic(a, 0) - b = torch.tensor([[1, 2], [2, 4]], device=device) - torch_xla._XLAC._xla_mark_dynamic(b, 0) - c = a * b - hlo_content = torch_xla._XLAC._get_xla_tensors_hlo([c]) - self.assertTrue( - "(p0.1: s64[?,2], p1.2: s64[?,2]) -> (s64[?,2])" in hlo_content) - - def test_export_dynamism(self): + def _test_export_dynamism_wrapper(self, f, args, constraints): class M(torch.nn.Module): def __init__(self): super().__init__() - def forward(self, x, y): - return x * y + def forward(self, *args): + return f(*args) + + m = M() + ep = torch.export.export(m, args=args, constraints=constraints) + return ep + + def test_add(self): + args = (torch.rand((10, 197, 768)), torch.rand((10, 197, 768))) + constraints = [ + torch.export.dynamic_dim(args[0], 0), + torch.export.dynamic_dim(args[1], 0), + torch.export.dynamic_dim(args[0], + 0) == torch.export.dynamic_dim(args[1], 0), + ] + ep = self._test_export_dynamism_wrapper(torch.ops.aten.add.Tensor, args, + constraints) + shlo_module = exported_program_to_stablehlo(ep) + shlo_text = shlo_module.get_stablehlo_text() + self.assertTrue( + re.search( + r'tensor<\?x197x768xf32>.*tensor<\?x197x768xf32>.*->.*tensor<\?x197x768xf32>', + shlo_text) is not None) + + def test_add_scalar(self): + args = (torch.rand((10, 197, 768)), 0.345) + constraints = [ + torch.export.dynamic_dim(args[0], 0), + ] + ep = self._test_export_dynamism_wrapper(torch.ops.aten.add.Tensor, args, + constraints) + shlo_module = exported_program_to_stablehlo(ep) + shlo_text = shlo_module.get_stablehlo_text() + self.assertTrue( + re.search( + r'tensor.*tensor<\?x197x768xf32>.*->.*tensor<\?x197x768xf32>', + shlo_text) is not None) + + @unittest.skip("Unbounded Dynamism not supported on addmm.") + def test_addmm(self): + args = (torch.rand((5)), torch.rand((10, 5)), torch.rand((5, 5))) + constraints = [ + torch.export.dynamic_dim(args[1], 0), + ] + ep = self._test_export_dynamism_wrapper(torch.ops.aten.addmm.default, args, + constraints) + shlo_module = exported_program_to_stablehlo(ep) + shlo_text = shlo_module.get_stablehlo_text() + self.assertTrue( + re.search(r'tensor<\?x5xf32>.*->.*tensor<\?x5xf32>', shlo_text) + is not None) + + def test_bmm(self): + args = ( + torch.rand((24, 197, 64)), + torch.rand((24, 64, 197)), + ) + constraints = [ + torch.export.dynamic_dim(args[0], 0), + torch.export.dynamic_dim(args[1], 0), + torch.export.dynamic_dim(args[0], + 0) == torch.export.dynamic_dim(args[1], 0), + ] + ep = self._test_export_dynamism_wrapper(torch.ops.aten.bmm.default, args, + constraints) + shlo_module = exported_program_to_stablehlo(ep) + shlo_text = shlo_module.get_stablehlo_text() + self.assertTrue( + re.search( + r'%arg.: tensor<\?x64x197xf32>.*%arg.: tensor<\?x197x64xf32>.*->.*tensor<\?x197x197xf32>', + shlo_text) is not None) + + def test_cat(self): + args = ([torch.rand((10, 1, 768)), torch.rand((10, 196, 768))], 1) + constraints = [ + torch.export.dynamic_dim(args[0][0], 0), + torch.export.dynamic_dim(args[0][1], 0), + torch.export.dynamic_dim(args[0][0], + 0) == torch.export.dynamic_dim(args[0][1], 0), + ] + ep = self._test_export_dynamism_wrapper(torch.ops.aten.cat.default, args, + constraints) + shlo_module = exported_program_to_stablehlo(ep) + shlo_text = shlo_module.get_stablehlo_text() + self.assertTrue( + re.search( + r'%arg.: tensor<\?x196x768xf32>.*%arg.: tensor<\?x1x768xf32>.*->.*tensor<\?x197x768xf32>', + shlo_text) is not None) + + @unittest.skip("Unbounded Dynamism not supported on conv.") + def test_conv(self): + args = ( + torch.rand((10, 3, 224, 224)), + torch.rand((5, 3, 16, 16)), + torch.rand((5)), + [16, 16], + [0, 0], + [1, 1], + False, + [0, 0], + 1, + ) + constraints = [ + torch.export.dynamic_dim(args[0], 0), + torch.export.dynamic_dim(args[0], 0) < 16, + ] + ep = self._test_export_dynamism_wrapper(torch.ops.aten.convolution.default, + args, constraints) + shlo_module = exported_program_to_stablehlo(ep) + shlo_text = shlo_module.get_stablehlo_text() + self.assertTrue( + re.search(r'tensor<\?x3x224x224xf32>.*->.*tensor<\?x5x14x14xf32>', + shlo_text) is not None) + + def test_div(self): + args = (torch.rand((10, 12, 197)), 8.0) + constraints = [ + torch.export.dynamic_dim(args[0], 0), + ] + ep = self._test_export_dynamism_wrapper(torch.ops.aten.div.Tensor, args, + constraints) + shlo_module = exported_program_to_stablehlo(ep) + shlo_text = shlo_module.get_stablehlo_text() + self.assertTrue( + re.search(r'tensor<\?x12x197xf32>.*->.*tensor<\?x12x197xf32>', + shlo_text) is not None) + + @unittest.skip("xla::Erf doesn't support unbounded dynamic input.") + def test_gelu(self): + args = (torch.rand((3, 5)),) + constraints = [ + torch.export.dynamic_dim(args[0], 0), + ] + ep = self._test_export_dynamism_wrapper(torch.ops.aten.gelu, args, + constraints) + shlo_module = exported_program_to_stablehlo(ep) + # shlo_text = shlo_module.get_stablehlo_text() + # self.assertTrue( + # "(%arg0: tensor, %arg1: tensor) -> tensor" in + # shlo_text) + + @unittest.skip("Unbounded Dynamism not supported on view.") + def test_native_layer_norm(self): + args = ( + torch.rand((10, 197, 768)), + [768], + torch.rand((768)), + torch.rand((768)), + 1e-12, + ) + constraints = [ + torch.export.dynamic_dim(args[0], 0), + ] + ep = self._test_export_dynamism_wrapper( + torch.ops.aten.native_layer_norm.default, args, constraints) + shlo_module = exported_program_to_stablehlo(ep) + shlo_text = shlo_module.get_stablehlo_text() + self.assertTrue( + re.search(r"%arg.: tensor<\?x197x768xf32>.*->.*tensor<\?x197x768xf32>", + shlo_text) is not None) + + def test_permute(self): + args = (torch.rand((10, 197, 12, 64)), [0, 2, 1, 3]) + constraints = [ + torch.export.dynamic_dim(args[0], 0), + ] + ep = self._test_export_dynamism_wrapper(torch.ops.aten.permute.default, + args, constraints) + shlo_module = exported_program_to_stablehlo(ep) + shlo_text = shlo_module.get_stablehlo_text() + self.assertTrue( + re.search( + r"%arg.: tensor<\?x197x12x64xf32>.*->.*tensor<\?x12x197x64xf32>", + shlo_text) is not None) + + @unittest.skip("Unbounded Dynamism not supported on select..") + def test_select(self): + args = (torch.rand((10, 197, 768)), 1, 0) + constraints = [ + torch.export.dynamic_dim(args[0], 0), + ] + ep = self._test_export_dynamism_wrapper(torch.ops.aten.select.int, args, + constraints) + shlo_module = exported_program_to_stablehlo(ep) + shlo_text = shlo_module.get_stablehlo_text() + self.assertTrue( + re.search(r"%arg.: tensor<\?x197x768xf32>.*->.*tensor<\?x768xf32>", + shlo_text) is not None) + + @unittest.skip("Unbounded Dynamism not supported on slice.") + def test_slice(self): + args = (torch.rand((10, 3, 224, 224)), 0, 0, 9223372036854775807) + constraints = [ + torch.export.dynamic_dim(args[0], 0), + ] + ep = self._test_export_dynamism_wrapper(torch.ops.aten.slice.Tensor, args, + constraints) + shlo_module = exported_program_to_stablehlo(ep) + shlo_text = shlo_module.get_stablehlo_text() + self.assertTrue( + re.search( + r"%arg.: tensor<\?x3x224x224xf32>.*->.*tensor<\?x3x224x224xf32>", + shlo_text) is not None) - example_args = (torch.tensor([[1, 2], [2, 4]], device=device), - torch.tensor([[1, 2], [2, 4]], device=device)) + @unittest.skip("Unbounded Dynamism not supported on softmax.") + def test_softmax(self): + args = (torch.rand((10, 12, 197, 197)), -1, False) constraints = [ - # First dimension of each input is a dynamic batch size - torch.export.dynamic_dim(example_args[0], 0), - torch.export.dynamic_dim(example_args[1], 0), - # The dynamic batch size between the inputs are equal - torch.export.dynamic_dim(example_args[0], - 0) == torch.export.dynamic_dim( - example_args[1], 0), + torch.export.dynamic_dim(args[0], 0), ] - ep = torch.export.export(M(), args=example_args, constraints=constraints) + ep = self._test_export_dynamism_wrapper(torch.ops.aten._softmax.default, + args, constraints) shlo_module = exported_program_to_stablehlo(ep) - shlo_text = shlo_module.get_stablehlo_text("forward") + shlo_text = shlo_module.get_stablehlo_text() self.assertTrue( - "(%arg0: tensor, %arg1: tensor) -> tensor" in - shlo_text) + re.search( + r"%arg.: tensor<\?x12x197x197xf32>.*->.*tensor<\?x12x197x197xf32>", + shlo_text) is not None) -if __name__ == '__main__': +if __name__ == "__main__": test = unittest.main() sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 453d3a1a8958..acc63b271356 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -751,14 +751,20 @@ XLATensorPtr add(const XLATensorPtr& input, const XLATensorPtr& other, xla::Shape input_shape = input->shape().get(); xla::Shape other_shape = other->shape().get(); torch::lazy::Value constant; + const torch::lazy::BackendDevice& device = input->GetDevice(); if (!input_shape.is_dynamic() && !other_shape.is_dynamic()) { constant = XLAGraphExecutor::Get()->GetIrValueForScalar( - alpha, other->shape(), logical_element_type, input->GetDevice()); + alpha, + xla::ShapeUtil::MakeScalarShape( + MakeXlaPrimitiveType(other->dtype(), &device)), + logical_element_type, device); } else { SymIntElements sym_int_elements(other->GetIrValue()); constant = XLAGraphExecutor::Get()->GetIrValueForScalar( - alpha, other->shape(), sym_int_elements, logical_element_type, - input->GetDevice()); + alpha, + xla::ShapeUtil::MakeScalarShape( + MakeXlaPrimitiveType(other->dtype(), &device)), + sym_int_elements, logical_element_type, device); } return input->CreateFrom(input->GetIrValue() + other->GetIrValue() * constant, @@ -768,12 +774,19 @@ XLATensorPtr add(const XLATensorPtr& input, const XLATensorPtr& other, XLATensorPtr add(const XLATensorPtr& input, const at::Scalar& other, const at::Scalar& alpha, c10::optional logical_element_type) { + const torch::lazy::BackendDevice& device = input->GetDevice(); torch::lazy::Value other_constant = XLAGraphExecutor::Get()->GetIrValueForScalar( - other, input->shape(), logical_element_type, input->GetDevice()); + other, + xla::ShapeUtil::MakeScalarShape( + MakeXlaPrimitiveType(input->dtype(), &device)), + logical_element_type, device); torch::lazy::Value alpha_constant = XLAGraphExecutor::Get()->GetIrValueForScalar( - alpha, input->shape(), logical_element_type, input->GetDevice()); + alpha, + xla::ShapeUtil::MakeScalarShape( + MakeXlaPrimitiveType(input->dtype(), &device)), + logical_element_type, device); return input->CreateFrom( input->GetIrValue() + other_constant * alpha_constant, logical_element_type); @@ -1860,8 +1873,12 @@ XLATensorPtr mul(const XLATensorPtr& input, const XLATensorPtr& other, XLATensorPtr mul(const XLATensorPtr& input, const at::Scalar& other, c10::optional logical_element_type) { + const torch::lazy::BackendDevice& device = input->GetDevice(); torch::lazy::Value constant = XLAGraphExecutor::Get()->GetIrValueForScalar( - other, input->shape(), logical_element_type, input->GetDevice()); + other, + xla::ShapeUtil::MakeScalarShape( + MakeXlaPrimitiveType(input->dtype(), &device)), + logical_element_type, device); return input->CreateFrom(input->GetIrValue() * constant, logical_element_type); }