From 44d218c8dfd73c6ee75d366b0500a15d533782ca Mon Sep 17 00:00:00 2001 From: Wonjoo Lee Date: Fri, 16 Feb 2024 18:25:23 -0800 Subject: [PATCH] Add fallback check to test_core_aten_ops.py (#6559) --- test/test_core_aten_ops.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/test/test_core_aten_ops.py b/test/test_core_aten_ops.py index 0846d7425ed..2819d9c45f6 100644 --- a/test/test_core_aten_ops.py +++ b/test/test_core_aten_ops.py @@ -1,5 +1,6 @@ import torch_xla import torch_xla.core.xla_model as xm +import torch_xla.debug.metrics as met from torch_xla.stablehlo import exported_program_to_stablehlo from torch.utils import _pytree as pytree import torch @@ -53,6 +54,9 @@ def run_export_and_compare(testcase, kwargs2 = pytree.tree_map_only(torch.Tensor, lambda x: x.to(device=device), kwargs) res_xla = func(*args2, **kwargs2) + with testcase.subTest('torch_xla_metric'): + aten_function_name = f'aten::{str(func).split(".")[-1]}' + testcase.assertNotIn(aten_function_name, met.metrics_report()) with testcase.subTest('torch_xla_diff:' + str(atol)): diff_output( testcase, res, res_xla, atol=atol, rtol=rtol, equal_nan=equal_nan) @@ -71,6 +75,7 @@ class AtenOpTest(unittest.TestCase): def setUp(self): torch.manual_seed(0) + met.clear_all() def test_aten_abs_0(self): args = (torch.randn((10, 10)).to(torch.float32),) @@ -1681,6 +1686,7 @@ def test_aten_glu_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.glu, args, kwargs) + @unittest.skip def test_aten_grid_sampler_2d_0(self): args = ( torch.randn((1, 3, 2, 10)).to(torch.float32), @@ -3063,6 +3069,7 @@ def test_aten_reciprocal_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.reciprocal, args, kwargs) + @unittest.skip def test_aten_reflection_pad1d_0(self): args = ( torch.randn((10, 10)).to(torch.float32), @@ -3074,6 +3081,7 @@ def test_aten_reflection_pad1d_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.reflection_pad1d, args, kwargs) + @unittest.skip def test_aten_reflection_pad1d_1(self): args = ( torch.randint(0, 10, (10, 10)).to(torch.int32), @@ -3111,6 +3119,7 @@ def test_aten_reflection_pad2d_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.reflection_pad2d, args, kwargs) + @unittest.skip def test_aten_reflection_pad3d_0(self): args = ( torch.randn((3, 3, 3, 3, 3)).to(torch.float32), @@ -3126,6 +3135,7 @@ def test_aten_reflection_pad3d_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.reflection_pad3d, args, kwargs) + @unittest.skip def test_aten_reflection_pad3d_1(self): args = ( torch.randn((3, 3, 3, 3, 3)).to(torch.float16), @@ -3141,6 +3151,7 @@ def test_aten_reflection_pad3d_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.reflection_pad3d, args, kwargs) + @unittest.skip def test_aten_reflection_pad3d_2(self): args = ( torch.randint(0, 10, (3, 3, 3, 3, 3)).to(torch.int32),