Skip to content

Commit

Permalink
add complex64/128 ut for matrix_power
Browse files Browse the repository at this point in the history
  • Loading branch information
jinyouzhi committed Aug 22, 2023
1 parent 2e5c013 commit 176b2bf
Showing 1 changed file with 76 additions and 0 deletions.
76 changes: 76 additions & 0 deletions test/legacy_test/test_matrix_power_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,82 @@ def config(self):
self.n = -1


class TestMatrixPowerOpCP64(TestMatrixPowerOp):
def config(self):
self.matrix_shape = [10, 10]
self.dtype = "complex64"
self.n = 2

def test_grad(self):
self.check_grad(["X"], "Out", max_relative_error=1e-5)


class TestMatrixPowerOpBatchedCP64(TestMatrixPowerOpCP64):
def config(self):
self.matrix_shape = [2, 8, 4, 4]
self.dtype = "complex64"
self.n = 2


class TestMatrixPowerOpLarge1CP64(TestMatrixPowerOpCP64):
def config(self):
self.matrix_shape = [32, 32]
self.dtype = "complex64"
self.n = 2


class TestMatrixPowerOpLarge2CP64(TestMatrixPowerOpCP64):
def config(self):
self.matrix_shape = [10, 10]
self.dtype = "complex64"
self.n = 32


class TestMatrixPowerOpCP64Minus(TestMatrixPowerOpCP64):
def config(self):
self.matrix_shape = [10, 10]
self.dtype = "complex64"
self.n = -1


class TestMatrixPowerOpCP128(TestMatrixPowerOp):
def config(self):
self.matrix_shape = [10, 10]
self.dtype = "complex128"
self.n = 2

def test_grad(self):
self.check_grad(["X"], "Out", max_relative_error=1e-5)


class TestMatrixPowerOpBatchedCP128(TestMatrixPowerOpCP128):
def config(self):
self.matrix_shape = [2, 8, 4, 4]
self.dtype = "complex128"
self.n = 2


class TestMatrixPowerOpLarge1CP128(TestMatrixPowerOpCP128):
def config(self):
self.matrix_shape = [32, 32]
self.dtype = "complex128"
self.n = 2


class TestMatrixPowerOpLarge2CP128(TestMatrixPowerOpCP128):
def config(self):
self.matrix_shape = [10, 10]
self.dtype = "complex128"
self.n = 32


class TestMatrixPowerOpCP128Minus(TestMatrixPowerOpCP128):
def config(self):
self.matrix_shape = [10, 10]
self.dtype = "complex128"
self.n = -1


class TestMatrixPowerAPI(unittest.TestCase):
def setUp(self):
np.random.seed(123)
Expand Down

0 comments on commit 176b2bf

Please sign in to comment.