Skip to content

Commit

Permalink
Snoop test_infra Names (#1185)
Browse files Browse the repository at this point in the history
This change removes the necessity to specify the name of a test in the
`compile_and_convert` decorator, using `__name__` of the decorated
function to supply this instead. The option to override this choice
still exists
  • Loading branch information
ctodTT committed Nov 12, 2024
1 parent 27841f6 commit 5f4971f
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 31 deletions.
32 changes: 9 additions & 23 deletions python/test_infra/test_ttir_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,47 +10,47 @@
from ttmlir.ttir_builder import Operand, TTIRBuilder


@compile_and_convert(((128, 128),), test_name="test_exp")
@compile_and_convert(((128, 128),))
def test_exp(in0: Operand, builder: TTIRBuilder):
return builder.exp(in0)


@compile_and_convert(((128, 128),), test_name="test_abs", targets=["ttnn"])
@compile_and_convert(((128, 128),), targets=["ttnn"])
def test_abs(in0: Operand, builder: TTIRBuilder):
return builder.abs(in0)


@compile_and_convert(((128, 128),), test_name="test_logical_not", targets=["ttnn"])
@compile_and_convert(((128, 128),), targets=["ttnn"])
def test_logical_not(in0: Operand, builder: TTIRBuilder):
return builder.logical_not(in0)


@compile_and_convert(((128, 128),), test_name="test_neg", targets=["ttnn"])
@compile_and_convert(((128, 128),), targets=["ttnn"])
def test_neg(in0: Operand, builder: TTIRBuilder):
return builder.neg(in0)


@compile_and_convert(((128, 128),), test_name="test_relu", targets=["ttnn"])
@compile_and_convert(((128, 128),), targets=["ttnn"])
def test_relu(in0: Operand, builder: TTIRBuilder):
return builder.relu(in0)


@compile_and_convert(((128, 128),), test_name="test_sqrt", targets=["ttnn"])
@compile_and_convert(((128, 128),), targets=["ttnn"])
def test_sqrt(in0: Operand, builder: TTIRBuilder):
return builder.sqrt(in0)


@compile_and_convert(((128, 128),), test_name="test_rsqrt", targets=["ttnn"])
@compile_and_convert(((128, 128),), targets=["ttnn"])
def test_rsqrt(in0: Operand, builder: TTIRBuilder):
return builder.rsqrt(in0)


@compile_and_convert(((128, 128),), test_name="test_sigmoid", targets=["ttnn"])
@compile_and_convert(((128, 128),), targets=["ttnn"])
def test_sigmoid(in0: Operand, builder: TTIRBuilder):
return builder.sigmoid(in0)


@compile_and_convert(((128, 128),), test_name="test_reciprocal", targets=["ttnn"])
@compile_and_convert(((128, 128),), targets=["ttnn"])
def test_reciprocal(in0: Operand, builder: TTIRBuilder):
return builder.reciprocal(in0)

Expand All @@ -60,7 +60,6 @@ def test_reciprocal(in0: Operand, builder: TTIRBuilder):
(64, 128),
(64, 128),
),
test_name="test_add",
)
def test_add(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.add(in0, in1)
Expand All @@ -71,7 +70,6 @@ def test_add(in0: Operand, in1: Operand, builder: TTIRBuilder):
(64, 64),
(64, 64),
),
test_name="test_multiply",
)
def test_multiply(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.multiply(in0, in1)
Expand All @@ -82,7 +80,6 @@ def test_multiply(in0: Operand, in1: Operand, builder: TTIRBuilder):
(64, 64),
(64, 64),
),
test_name="test_logical_and",
targets=["ttnn"],
)
def test_logical_and(in0: Operand, in1: Operand, builder: TTIRBuilder):
Expand All @@ -94,7 +91,6 @@ def test_logical_and(in0: Operand, in1: Operand, builder: TTIRBuilder):
(64, 64),
(64, 64),
),
test_name="test_logical_or",
targets=["ttnn"],
)
def test_logical_or(in0: Operand, in1: Operand, builder: TTIRBuilder):
Expand All @@ -106,7 +102,6 @@ def test_logical_or(in0: Operand, in1: Operand, builder: TTIRBuilder):
(64, 64),
(64, 64),
),
test_name="test_subtract",
targets=["ttnn"],
)
def test_subtract(in0: Operand, in1: Operand, builder: TTIRBuilder):
Expand All @@ -118,7 +113,6 @@ def test_subtract(in0: Operand, in1: Operand, builder: TTIRBuilder):
(64, 64),
(64, 64),
),
test_name="test_eq",
targets=["ttnn"],
)
def test_eq(in0: Operand, in1: Operand, builder: TTIRBuilder):
Expand All @@ -130,7 +124,6 @@ def test_eq(in0: Operand, in1: Operand, builder: TTIRBuilder):
(64, 64),
(64, 64),
),
test_name="test_ne",
targets=["ttnn"],
)
def test_ne(in0: Operand, in1: Operand, builder: TTIRBuilder):
Expand All @@ -142,7 +135,6 @@ def test_ne(in0: Operand, in1: Operand, builder: TTIRBuilder):
(64, 64),
(64, 64),
),
test_name="test_ge",
targets=["ttnn"],
)
def test_ge(in0: Operand, in1: Operand, builder: TTIRBuilder):
Expand All @@ -154,7 +146,6 @@ def test_ge(in0: Operand, in1: Operand, builder: TTIRBuilder):
(64, 64),
(64, 64),
),
test_name="test_gt",
targets=["ttnn"],
)
def test_gt(in0: Operand, in1: Operand, builder: TTIRBuilder):
Expand All @@ -166,7 +157,6 @@ def test_gt(in0: Operand, in1: Operand, builder: TTIRBuilder):
(64, 64),
(64, 64),
),
test_name="test_le",
targets=["ttnn"],
)
def test_le(in0: Operand, in1: Operand, builder: TTIRBuilder):
Expand All @@ -178,7 +168,6 @@ def test_le(in0: Operand, in1: Operand, builder: TTIRBuilder):
(64, 64),
(64, 64),
),
test_name="test_lt",
targets=["ttnn"],
)
def test_lt(in0: Operand, in1: Operand, builder: TTIRBuilder):
Expand All @@ -190,7 +179,6 @@ def test_lt(in0: Operand, in1: Operand, builder: TTIRBuilder):
(64, 64),
(64, 64),
),
test_name="test_div",
targets=["ttnn"],
)
def test_div(in0: Operand, in1: Operand, builder: TTIRBuilder):
Expand All @@ -202,7 +190,6 @@ def test_div(in0: Operand, in1: Operand, builder: TTIRBuilder):
(64, 64),
(64, 64),
),
test_name="test_maximum",
targets=["ttnn"],
)
def test_maximum(in0: Operand, in1: Operand, builder: TTIRBuilder):
Expand All @@ -215,7 +202,6 @@ def test_maximum(in0: Operand, in1: Operand, builder: TTIRBuilder):
(32, 32),
(32, 32),
),
test_name="test_arbitrary_op_chain",
)
def test_arbitrary_op_chain(
in0: Operand, in1: Operand, in2: Operand, builder: TTIRBuilder
Expand Down
24 changes: 16 additions & 8 deletions python/test_infra/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def ttmetal_to_flatbuffer(

def compile_and_convert(
inputs_shapes: Tuple[Shape],
test_name: str,
test_name: Optional[str] = None,
targets: List[str] = ["ttmetal", "ttnn"],
module_dump: bool = False,
):
Expand All @@ -272,9 +272,10 @@ def compile_and_convert(
inputs_shapes: Tuple[Shape]
Shapes of the respective ranked tensor inputs of the test function.
test_name: str
The name of the decorated function. Used as the base name for dumped
files during the process
test_name: Optional[str]
The string to be used as the base name for dumped files throughout the
process. If `None` is provided, then the `__name__` of the decorated
function will be used.
targets: List[str]
A list that can only contain the following strings: 'ttnn' or
Expand All @@ -297,6 +298,13 @@ def test_add(in0: Operand, in1: Operand, builder: TTIRBuilder):
"""

def decorator(test_fn: Callable):

# Snoop the name of `test_fn` if no override to the test name is provided
if test_name is None:
test_base = test_fn.__name__
else:
test_base = test_name

def wrapper():

# NOTE: since `ttir_to_tt{nn,metal} modifies the module in place,
Expand All @@ -305,13 +313,13 @@ def wrapper():

if "ttmetal" in targets:
module, builder = compile_as_mlir_module(test_fn, inputs_shapes)
module = ttir_to_ttmetal(module, builder, test_name + ".mlir")
ttmetal_to_flatbuffer(module, builder, test_name + ".ttm")
module = ttir_to_ttmetal(module, builder, test_base + ".mlir")
ttmetal_to_flatbuffer(module, builder, test_base + ".ttm")

if "ttnn" in targets:
module, builder = compile_as_mlir_module(test_fn, inputs_shapes)
module = ttir_to_ttnn(module, builder, test_name + ".mlir")
ttnn_to_flatbuffer(module, builder, test_name + ".ttnn")
module = ttir_to_ttnn(module, builder, test_base + ".mlir")
ttnn_to_flatbuffer(module, builder, test_base + ".ttnn")

return wrapper

Expand Down

0 comments on commit 5f4971f

Please sign in to comment.