From 9ea460065b9116feda767b5d9debb349891bb469 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Thu, 29 Feb 2024 10:48:34 -0300 Subject: [PATCH] Fix tests. --- test/test_operations.py | 32 ++++++++++++++++++++++++++------ 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/test/test_operations.py b/test/test_operations.py index b32d3680bd5a..84479255c481 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -67,6 +67,22 @@ def _is_on_eager_debug_mode(): 'skip on eager debug mode') +def _skipIfFunctionalization(disabled=True, reason=""): + verb = "is" if disabled else "is not" + reason = f" Reason: {reason}" if reason else "" + return unittest.skipIf( + XLA_DISABLE_FUNCTIONALIZATION, + f'Works only when functionalization {verb} disabled.{reason}.') + + +def skipIfFunctionalizationEnabled(reason): + return _skipIfFunctionalization(disabled=False, reason=reason) + + +def skipIfFunctionalizationDisabled(reason): + return _skipIfFunctionalization(disabled=True, reason=reason) + + def _gen_tensor(*args, **kwargs): return torch.randn(*args, **kwargs) @@ -978,8 +994,7 @@ def func(a, b): # TODO - upstream behavior has changed and results in expected DestroyXlaTensor # counter as of 11/13/2023. Re-enable after reviewing the change. - @unittest.skipIf(True or XLA_DISABLE_FUNCTIONALIZATION, - 'Metrics differ when functionalization is disabled.') + @skipIfFunctionalizationDisabled("metrics differ") def test_set(self): met.clear_all() @@ -997,8 +1012,7 @@ def test_set(self): # shouldn't crash self.assertTrue(torch.allclose(t2.cpu(), torch.zeros(10))) - @unittest.skipIf(XLA_DISABLE_FUNCTIONALIZATION, - 'Metrics differ when functionalization is disabled.') + @skipIfFunctionalizationDisabled("metrics differ") def test_replace_xla_tensor(self): met.clear_all() @@ -1341,8 +1355,7 @@ def test_fn(t, c): ), dtype=torch.int64) self.runAtenTest([token_type_ids, cat_ids], test_fn) - @unittest.skipIf(not XLA_DISABLE_FUNCTIONALIZATION, - 'When functionalization is enabled, views do not exist.') + @skipIfFunctionalizationEnabled("views do not exist") def test_save_view_alias_check(self): class Nested(object): @@ -1498,6 +1511,7 @@ def test_fn(r): self.runAtenTest([torch.arange(144, dtype=torch.int32)], test_fn) + @skipIfFunctionalizationDisabled("arbitrary as_strided unsupported") def test_as_strided_with_gap(self): def test_fn(r): @@ -1505,6 +1519,7 @@ def test_fn(r): self.runAtenTest([torch.arange(28, dtype=torch.int32)], test_fn) + @skipIfFunctionalizationDisabled("arbitrary as_strided unsupported") def test_as_strided_with_gap_no_unit_stride(self): def test_fn(r): @@ -1512,6 +1527,7 @@ def test_fn(r): self.runAtenTest([torch.arange(31, dtype=torch.int32)], test_fn) + @skipIfFunctionalizationDisabled("arbitrary as_strided unsupported") def test_as_strided_with_overlap(self): def test_fn(r): @@ -1519,6 +1535,7 @@ def test_fn(r): self.runAtenTest([torch.arange(10, dtype=torch.int32)], test_fn) + @skipIfFunctionalizationDisabled("arbitrary as_strided unsupported") def test_as_strided_with_overlap_and_gap(self): def test_fn(r): @@ -1526,6 +1543,7 @@ def test_fn(r): self.runAtenTest([torch.arange(19, dtype=torch.int32)], test_fn) + @skipIfFunctionalizationDisabled("arbitrary as_strided unsupported") def test_as_strided_with_overlap_zero_stride(self): def test_fn(r): @@ -1533,6 +1551,7 @@ def test_fn(r): self.runAtenTest([torch.arange(19, dtype=torch.int32)], test_fn) + @skipIfFunctionalizationDisabled("arbitrary as_strided unsupported") def test_as_strided_with_gap_no_unit_stride(self): def test_fn(r): @@ -1541,6 +1560,7 @@ def test_fn(r): self.runAtenTest([torch.arange(32, dtype=torch.int32)], test_fn) + @skipIfFunctionalizationDisabled("arbitrary as_strided unsupported") def test_as_strided_with_empty_args(self): def test_fn(r):