-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This change combines does two main things: - In `python/test_infra/test_ttir_ops_tt{nn,metal}.py`, each test function used to have 3 decorators that always went together. These have been combined into a single decorator. - `python/test_infra/test_ttir_ops_tt{nn,metal}.py` have been combined into a single file, with an additional parameter `targets` being used to select which backend is tested for.
- Loading branch information
Showing
4 changed files
with
375 additions
and
359 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,235 @@ | ||
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
# RUN: %python %s | ||
|
||
import inspect | ||
|
||
from ttmlir.test_utils import compile_and_convert | ||
from ttmlir.ttir_builder import Operand, TTIRBuilder | ||
|
||
|
||
@compile_and_convert(((128, 128),), test_name="test_exp") | ||
def test_exp(in0: Operand, builder: TTIRBuilder): | ||
return builder.exp(in0) | ||
|
||
|
||
@compile_and_convert(((128, 128),), test_name="test_abs", targets=["ttnn"]) | ||
def test_abs(in0: Operand, builder: TTIRBuilder): | ||
return builder.abs(in0) | ||
|
||
|
||
@compile_and_convert(((128, 128),), test_name="test_logical_not", targets=["ttnn"]) | ||
def test_logical_not(in0: Operand, builder: TTIRBuilder): | ||
return builder.logical_not(in0) | ||
|
||
|
||
@compile_and_convert(((128, 128),), test_name="test_neg", targets=["ttnn"]) | ||
def test_neg(in0: Operand, builder: TTIRBuilder): | ||
return builder.neg(in0) | ||
|
||
|
||
@compile_and_convert(((128, 128),), test_name="test_relu", targets=["ttnn"]) | ||
def test_relu(in0: Operand, builder: TTIRBuilder): | ||
return builder.relu(in0) | ||
|
||
|
||
@compile_and_convert(((128, 128),), test_name="test_sqrt", targets=["ttnn"]) | ||
def test_sqrt(in0: Operand, builder: TTIRBuilder): | ||
return builder.sqrt(in0) | ||
|
||
|
||
@compile_and_convert(((128, 128),), test_name="test_rsqrt", targets=["ttnn"]) | ||
def test_rsqrt(in0: Operand, builder: TTIRBuilder): | ||
return builder.rsqrt(in0) | ||
|
||
|
||
@compile_and_convert(((128, 128),), test_name="test_sigmoid", targets=["ttnn"]) | ||
def test_sigmoid(in0: Operand, builder: TTIRBuilder): | ||
return builder.sigmoid(in0) | ||
|
||
|
||
@compile_and_convert(((128, 128),), test_name="test_reciprocal", targets=["ttnn"]) | ||
def test_reciprocal(in0: Operand, builder: TTIRBuilder): | ||
return builder.reciprocal(in0) | ||
|
||
|
||
@compile_and_convert( | ||
( | ||
(64, 128), | ||
(64, 128), | ||
), | ||
test_name="test_add", | ||
) | ||
def test_add(in0: Operand, in1: Operand, builder: TTIRBuilder): | ||
return builder.add(in0, in1) | ||
|
||
|
||
@compile_and_convert( | ||
( | ||
(64, 64), | ||
(64, 64), | ||
), | ||
test_name="test_multiply", | ||
) | ||
def test_multiply(in0: Operand, in1: Operand, builder: TTIRBuilder): | ||
return builder.multiply(in0, in1) | ||
|
||
|
||
@compile_and_convert( | ||
( | ||
(64, 64), | ||
(64, 64), | ||
), | ||
test_name="test_logical_and", | ||
targets=["ttnn"], | ||
) | ||
def test_logical_and(in0: Operand, in1: Operand, builder: TTIRBuilder): | ||
return builder.logical_and(in0, in1) | ||
|
||
|
||
@compile_and_convert( | ||
( | ||
(64, 64), | ||
(64, 64), | ||
), | ||
test_name="test_logical_or", | ||
targets=["ttnn"], | ||
) | ||
def test_logical_or(in0: Operand, in1: Operand, builder: TTIRBuilder): | ||
return builder.logical_or(in0, in1) | ||
|
||
|
||
@compile_and_convert( | ||
( | ||
(64, 64), | ||
(64, 64), | ||
), | ||
test_name="test_subtract", | ||
targets=["ttnn"], | ||
) | ||
def test_subtract(in0: Operand, in1: Operand, builder: TTIRBuilder): | ||
return builder.subtract(in0, in1) | ||
|
||
|
||
@compile_and_convert( | ||
( | ||
(64, 64), | ||
(64, 64), | ||
), | ||
test_name="test_eq", | ||
targets=["ttnn"], | ||
) | ||
def test_eq(in0: Operand, in1: Operand, builder: TTIRBuilder): | ||
return builder.eq(in0, in1) | ||
|
||
|
||
@compile_and_convert( | ||
( | ||
(64, 64), | ||
(64, 64), | ||
), | ||
test_name="test_ne", | ||
targets=["ttnn"], | ||
) | ||
def test_ne(in0: Operand, in1: Operand, builder: TTIRBuilder): | ||
return builder.ne(in0, in1) | ||
|
||
|
||
@compile_and_convert( | ||
( | ||
(64, 64), | ||
(64, 64), | ||
), | ||
test_name="test_ge", | ||
targets=["ttnn"], | ||
) | ||
def test_ge(in0: Operand, in1: Operand, builder: TTIRBuilder): | ||
return builder.ge(in0, in1) | ||
|
||
|
||
@compile_and_convert( | ||
( | ||
(64, 64), | ||
(64, 64), | ||
), | ||
test_name="test_gt", | ||
targets=["ttnn"], | ||
) | ||
def test_gt(in0: Operand, in1: Operand, builder: TTIRBuilder): | ||
return builder.gt(in0, in1) | ||
|
||
|
||
@compile_and_convert( | ||
( | ||
(64, 64), | ||
(64, 64), | ||
), | ||
test_name="test_le", | ||
targets=["ttnn"], | ||
) | ||
def test_le(in0: Operand, in1: Operand, builder: TTIRBuilder): | ||
return builder.le(in0, in1) | ||
|
||
|
||
@compile_and_convert( | ||
( | ||
(64, 64), | ||
(64, 64), | ||
), | ||
test_name="test_lt", | ||
targets=["ttnn"], | ||
) | ||
def test_lt(in0: Operand, in1: Operand, builder: TTIRBuilder): | ||
return builder.lt(in0, in1) | ||
|
||
|
||
@compile_and_convert( | ||
( | ||
(64, 64), | ||
(64, 64), | ||
), | ||
test_name="test_div", | ||
targets=["ttnn"], | ||
) | ||
def test_div(in0: Operand, in1: Operand, builder: TTIRBuilder): | ||
return builder.div(in0, in1) | ||
|
||
|
||
@compile_and_convert( | ||
( | ||
(64, 64), | ||
(64, 64), | ||
), | ||
test_name="test_maximum", | ||
targets=["ttnn"], | ||
) | ||
def test_maximum(in0: Operand, in1: Operand, builder: TTIRBuilder): | ||
return builder.maximum(in0, in1) | ||
|
||
|
||
@compile_and_convert( | ||
( | ||
(32, 32), | ||
(32, 32), | ||
(32, 32), | ||
), | ||
test_name="test_arbitrary_op_chain", | ||
) | ||
def test_arbitrary_op_chain( | ||
in0: Operand, in1: Operand, in2: Operand, builder: TTIRBuilder | ||
): | ||
add = builder.add(in0, in1) | ||
exp = builder.exp(in2) | ||
return builder.multiply(add, exp) | ||
|
||
|
||
if __name__ == "__main__": | ||
test_functions = inspect.getmembers( | ||
inspect.getmodule(inspect.currentframe()), inspect.isfunction | ||
) | ||
|
||
for function_name, func in test_functions: | ||
if function_name.startswith("test_"): | ||
func() |
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.