Skip to content

Commit

Permalink
Add more as_strided tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
ysiraichi committed Feb 28, 2024
1 parent 0f84914 commit bc6409c
Showing 1 changed file with 50 additions and 0 deletions.
50 changes: 50 additions & 0 deletions test/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1533,6 +1533,56 @@ def test_fn(r):

self.runAtenTest([torch.arange(144, dtype=torch.int32)], test_fn)

def test_as_strided_with_gap(self):

def test_fn(r):
return torch.as_strided(r, (4, 4), (8, 1))

self.runAtenTest([torch.arange(28, dtype=torch.int32)], test_fn)

def test_as_strided_with_gap_no_unit_stride(self):

def test_fn(r):
return torch.as_strided(r, (4, 4), (8, 2))

self.runAtenTest([torch.arange(31, dtype=torch.int32)], test_fn)

def test_as_strided_with_overlap(self):

def test_fn(r):
return torch.as_strided(r, (4, 4), (2, 1))

self.runAtenTest([torch.arange(10, dtype=torch.int32)], test_fn)

def test_as_strided_with_overlap_and_gap(self):

def test_fn(r):
return torch.as_strided(r, (4, 4), (4, 2))

self.runAtenTest([torch.arange(19, dtype=torch.int32)], test_fn)

def test_as_strided_with_overlap_zero_stride(self):

def test_fn(r):
return torch.as_strided(r, (4, 4), (0, 1))

self.runAtenTest([torch.arange(19, dtype=torch.int32)], test_fn)

def test_as_strided_with_gap_no_unit_stride(self):

def test_fn(r):
x = r.view(8, 4)
return torch.as_strided(r, (4, 4), (6, 2))

self.runAtenTest([torch.arange(32, dtype=torch.int32)], test_fn)

def test_as_strided_with_empty_args(self):

def test_fn(r):
return torch.as_strided(r, tuple(), tuple())

self.runAtenTest([torch.arange(32, dtype=torch.int32)], test_fn)

def test_basic_bfloat16(self):

def test_fn(s):
Expand Down

0 comments on commit bc6409c

Please sign in to comment.