Skip to content

Commit

Permalink
Combine test_infra Tests (#1186)
Browse files Browse the repository at this point in the history
This change combines does two main things:
	- In `python/test_infra/test_ttir_ops_tt{nn,metal}.py`, each
	  test function used to have 3 decorators that always went
	  together. These have been combined into a single decorator.

	- `python/test_infra/test_ttir_ops_tt{nn,metal}.py` have been
	  combined into a single file, with an additional parameter
	  `targets` being used to select which backend is tested for.
  • Loading branch information
ctodTT committed Nov 12, 2024
1 parent 14cd5d0 commit 27841f6
Show file tree
Hide file tree
Showing 4 changed files with 375 additions and 359 deletions.
235 changes: 235 additions & 0 deletions python/test_infra/test_ttir_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

# RUN: %python %s

import inspect

from ttmlir.test_utils import compile_and_convert
from ttmlir.ttir_builder import Operand, TTIRBuilder


@compile_and_convert(((128, 128),), test_name="test_exp")
def test_exp(in0: Operand, builder: TTIRBuilder):
return builder.exp(in0)


@compile_and_convert(((128, 128),), test_name="test_abs", targets=["ttnn"])
def test_abs(in0: Operand, builder: TTIRBuilder):
return builder.abs(in0)


@compile_and_convert(((128, 128),), test_name="test_logical_not", targets=["ttnn"])
def test_logical_not(in0: Operand, builder: TTIRBuilder):
return builder.logical_not(in0)


@compile_and_convert(((128, 128),), test_name="test_neg", targets=["ttnn"])
def test_neg(in0: Operand, builder: TTIRBuilder):
return builder.neg(in0)


@compile_and_convert(((128, 128),), test_name="test_relu", targets=["ttnn"])
def test_relu(in0: Operand, builder: TTIRBuilder):
return builder.relu(in0)


@compile_and_convert(((128, 128),), test_name="test_sqrt", targets=["ttnn"])
def test_sqrt(in0: Operand, builder: TTIRBuilder):
return builder.sqrt(in0)


@compile_and_convert(((128, 128),), test_name="test_rsqrt", targets=["ttnn"])
def test_rsqrt(in0: Operand, builder: TTIRBuilder):
return builder.rsqrt(in0)


@compile_and_convert(((128, 128),), test_name="test_sigmoid", targets=["ttnn"])
def test_sigmoid(in0: Operand, builder: TTIRBuilder):
return builder.sigmoid(in0)


@compile_and_convert(((128, 128),), test_name="test_reciprocal", targets=["ttnn"])
def test_reciprocal(in0: Operand, builder: TTIRBuilder):
return builder.reciprocal(in0)


@compile_and_convert(
(
(64, 128),
(64, 128),
),
test_name="test_add",
)
def test_add(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.add(in0, in1)


@compile_and_convert(
(
(64, 64),
(64, 64),
),
test_name="test_multiply",
)
def test_multiply(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.multiply(in0, in1)


@compile_and_convert(
(
(64, 64),
(64, 64),
),
test_name="test_logical_and",
targets=["ttnn"],
)
def test_logical_and(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.logical_and(in0, in1)


@compile_and_convert(
(
(64, 64),
(64, 64),
),
test_name="test_logical_or",
targets=["ttnn"],
)
def test_logical_or(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.logical_or(in0, in1)


@compile_and_convert(
(
(64, 64),
(64, 64),
),
test_name="test_subtract",
targets=["ttnn"],
)
def test_subtract(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.subtract(in0, in1)


@compile_and_convert(
(
(64, 64),
(64, 64),
),
test_name="test_eq",
targets=["ttnn"],
)
def test_eq(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.eq(in0, in1)


@compile_and_convert(
(
(64, 64),
(64, 64),
),
test_name="test_ne",
targets=["ttnn"],
)
def test_ne(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.ne(in0, in1)


@compile_and_convert(
(
(64, 64),
(64, 64),
),
test_name="test_ge",
targets=["ttnn"],
)
def test_ge(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.ge(in0, in1)


@compile_and_convert(
(
(64, 64),
(64, 64),
),
test_name="test_gt",
targets=["ttnn"],
)
def test_gt(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.gt(in0, in1)


@compile_and_convert(
(
(64, 64),
(64, 64),
),
test_name="test_le",
targets=["ttnn"],
)
def test_le(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.le(in0, in1)


@compile_and_convert(
(
(64, 64),
(64, 64),
),
test_name="test_lt",
targets=["ttnn"],
)
def test_lt(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.lt(in0, in1)


@compile_and_convert(
(
(64, 64),
(64, 64),
),
test_name="test_div",
targets=["ttnn"],
)
def test_div(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.div(in0, in1)


@compile_and_convert(
(
(64, 64),
(64, 64),
),
test_name="test_maximum",
targets=["ttnn"],
)
def test_maximum(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.maximum(in0, in1)


@compile_and_convert(
(
(32, 32),
(32, 32),
(32, 32),
),
test_name="test_arbitrary_op_chain",
)
def test_arbitrary_op_chain(
in0: Operand, in1: Operand, in2: Operand, builder: TTIRBuilder
):
add = builder.add(in0, in1)
exp = builder.exp(in2)
return builder.multiply(add, exp)


if __name__ == "__main__":
test_functions = inspect.getmembers(
inspect.getmodule(inspect.currentframe()), inspect.isfunction
)

for function_name, func in test_functions:
if function_name.startswith("test_"):
func()
58 changes: 0 additions & 58 deletions python/test_infra/test_ttir_ops_ttmetal.py

This file was deleted.

Loading

0 comments on commit 27841f6

Please sign in to comment.