Skip to content

Commit

Permalink
Fix tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
ysiraichi committed Feb 29, 2024
1 parent 4c3c7b8 commit 9ea4600
Showing 1 changed file with 26 additions and 6 deletions.
32 changes: 26 additions & 6 deletions test/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,22 @@ def _is_on_eager_debug_mode():
'skip on eager debug mode')


def _skipIfFunctionalization(disabled=True, reason=""):
  """Build a unittest skip decorator conditioned on the functionalization flag.

  Args:
    disabled: when True, skip the test if functionalization is disabled
      (i.e. XLA_DISABLE_FUNCTIONALIZATION is truthy); when False, skip
      the test if functionalization is enabled (flag falsy).
    reason: optional extra explanation appended to the skip message.

  Returns:
    A ``unittest.skipIf`` decorator with an explanatory message.
  """
  # BUG FIX: the condition previously tested XLA_DISABLE_FUNCTIONALIZATION
  # unconditionally, ignoring `disabled`, so skipIfFunctionalizationEnabled
  # (disabled=False) skipped under exactly the wrong flag state. Skip when
  # the flag's truthiness matches the requested `disabled` state.
  # The verb describes when the test DOES run: if we skip because
  # functionalization is disabled, the test works only when it is NOT
  # disabled, and vice versa (previous mapping was inverted).
  verb = "is not" if disabled else "is"
  reason = f" Reason: {reason}" if reason else ""
  return unittest.skipIf(
      bool(XLA_DISABLE_FUNCTIONALIZATION) == disabled,
      f'Works only when functionalization {verb} disabled.{reason}.')


def skipIfFunctionalizationEnabled(reason):
  """Skip the decorated test whenever functionalization is enabled."""
  return _skipIfFunctionalization(reason=reason, disabled=False)


def skipIfFunctionalizationDisabled(reason):
  """Skip the decorated test whenever functionalization is disabled."""
  return _skipIfFunctionalization(reason=reason, disabled=True)


def _gen_tensor(*args, **kwargs):
return torch.randn(*args, **kwargs)

Expand Down Expand Up @@ -978,8 +994,7 @@ def func(a, b):

# TODO - upstream behavior has changed and results in expected DestroyXlaTensor
# counter as of 11/13/2023. Re-enable after reviewing the change.
@unittest.skipIf(True or XLA_DISABLE_FUNCTIONALIZATION,
'Metrics differ when functionalization is disabled.')
@skipIfFunctionalizationDisabled("metrics differ")
def test_set(self):
met.clear_all()

Expand All @@ -997,8 +1012,7 @@ def test_set(self):
# shouldn't crash
self.assertTrue(torch.allclose(t2.cpu(), torch.zeros(10)))

@unittest.skipIf(XLA_DISABLE_FUNCTIONALIZATION,
'Metrics differ when functionalization is disabled.')
@skipIfFunctionalizationDisabled("metrics differ")
def test_replace_xla_tensor(self):
met.clear_all()

Expand Down Expand Up @@ -1341,8 +1355,7 @@ def test_fn(t, c):
), dtype=torch.int64)
self.runAtenTest([token_type_ids, cat_ids], test_fn)

@unittest.skipIf(not XLA_DISABLE_FUNCTIONALIZATION,
'When functionalization is enabled, views do not exist.')
@skipIfFunctionalizationEnabled("views do not exist")
def test_save_view_alias_check(self):

class Nested(object):
Expand Down Expand Up @@ -1498,41 +1511,47 @@ def test_fn(r):

self.runAtenTest([torch.arange(144, dtype=torch.int32)], test_fn)

@skipIfFunctionalizationDisabled("arbitrary as_strided unsupported")
def test_as_strided_with_gap(self):
  # View a 28-element buffer as 4x4 with a row stride of 8: each row
  # skips 4 elements of the underlying storage (a "gap").

  def test_fn(r):
    size, stride = (4, 4), (8, 1)
    return r.as_strided(size, stride)

  self.runAtenTest([torch.arange(28, dtype=torch.int32)], test_fn)

@skipIfFunctionalizationDisabled("arbitrary as_strided unsupported")
def test_as_strided_with_gap_no_unit_stride(self):
  # 4x4 view with gaps between rows AND a non-unit stride (2) along
  # the last dimension.
  # NOTE(review): a later definition in this file reuses this exact test
  # name, which shadows this method at class-creation time — consider
  # renaming one of the two so both actually run.

  def test_fn(r):
    size, stride = (4, 4), (8, 2)
    return r.as_strided(size, stride)

  self.runAtenTest([torch.arange(31, dtype=torch.int32)], test_fn)

@skipIfFunctionalizationDisabled("arbitrary as_strided unsupported")
def test_as_strided_with_overlap(self):
  # Overlapping rows: the row stride (2) is smaller than the row
  # length (4), so consecutive rows share storage elements.

  def test_fn(r):
    return r.as_strided((4, 4), (2, 1))

  self.runAtenTest([torch.arange(10, dtype=torch.int32)], test_fn)

@skipIfFunctionalizationDisabled("arbitrary as_strided unsupported")
def test_as_strided_with_overlap_and_gap(self):
  # Mixed case: rows overlap (row stride 4 == row length 4 but element
  # stride 2 spreads each row over 8 slots), leaving gaps as well.

  def test_fn(r):
    size, stride = (4, 4), (4, 2)
    return r.as_strided(size, stride)

  self.runAtenTest([torch.arange(19, dtype=torch.int32)], test_fn)

@skipIfFunctionalizationDisabled("arbitrary as_strided unsupported")
def test_as_strided_with_overlap_zero_stride(self):
  # Degenerate overlap: stride 0 along dim 0 makes every row alias the
  # same four storage elements.

  def test_fn(r):
    return r.as_strided((4, 4), (0, 1))

  self.runAtenTest([torch.arange(19, dtype=torch.int32)], test_fn)

@skipIfFunctionalizationDisabled("arbitrary as_strided unsupported")
def test_as_strided_with_gap_no_unit_stride(self):

def test_fn(r):
Expand All @@ -1541,6 +1560,7 @@ def test_fn(r):

self.runAtenTest([torch.arange(32, dtype=torch.int32)], test_fn)

@skipIfFunctionalizationDisabled("arbitrary as_strided unsupported")
def test_as_strided_with_empty_args(self):

def test_fn(r):
Expand Down

0 comments on commit 9ea4600

Please sign in to comment.