Skip to content

Commit

Permalink
Cleanup system description discovery (#1184)
Browse files Browse the repository at this point in the history
This change moves grabbing the path for the `ttrt` generated system
description from the env into decorators, so that it doesn't need to be
passed into every invocation of the decorator. This is a simple cleanup
change that does not change behavior.
  • Loading branch information
ctodTT committed Nov 7, 2024
1 parent a145ead commit 6ef8ed8
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 67 deletions.
20 changes: 4 additions & 16 deletions python/test_infra/test_ttir_ops_ttmetal.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,41 +16,29 @@
)
from ttmlir.ttir_builder import Operand, TTIRBuilder

system_desc_path = os.getenv("SYSTEM_DESC_PATH", "")


@translate_ttmetal_to_flatbuffer(output_file_name="test_exp.ttm")
@ttir_to_ttmetal(
output_file_name="test_exp.mlir", system_desc_path=f"{system_desc_path}"
)
@ttir_to_ttmetal(output_file_name="test_exp.mlir")
@compile_as_mlir_module((128, 128))
def test_exp_ttmetal(in0: Operand, builder: TTIRBuilder):
return builder.exp(in0)


@translate_ttmetal_to_flatbuffer(output_file_name="test_add.ttm")
@ttir_to_ttmetal(
output_file_name="test_add.mlir", system_desc_path=f"{system_desc_path}"
)
@ttir_to_ttmetal(output_file_name="test_add.mlir")
@compile_as_mlir_module((64, 128), (64, 128))
def test_add_ttmetal(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.add(in0, in1)


@translate_ttmetal_to_flatbuffer(output_file_name="test_multiply.ttm")
@ttir_to_ttmetal(
output_file_name="test_multiply.mlir", system_desc_path=f"{system_desc_path}"
)
@ttir_to_ttmetal(output_file_name="test_multiply.mlir")
@compile_as_mlir_module((64, 64), (64, 64))
def test_multiply_ttmetal(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.multiply(in0, in1)


@translate_ttmetal_to_flatbuffer(output_file_name="test_arbitrary_op_chain.ttm")
@ttir_to_ttmetal(
output_file_name="test_arbitrary_op_chain.mlir",
system_desc_path=f"{system_desc_path}",
)
@ttir_to_ttmetal(output_file_name="test_arbitrary_op_chain.mlir")
@compile_as_mlir_module((32, 32), (32, 32), (32, 32))
def test_arbitrary_op_chain_ttmetal(
in0: Operand, in1: Operand, in2: Operand, builder: TTIRBuilder
Expand Down
71 changes: 23 additions & 48 deletions python/test_infra/test_ttir_ops_ttnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,188 +16,163 @@
)
from ttmlir.ttir_builder import Operand, TTIRBuilder

system_desc_path = os.getenv("SYSTEM_DESC_PATH", "")


@translate_ttnn_to_flatbuffer(output_file_name="test_exp.ttnn")
@ttir_to_ttnn(output_file_name="test_exp.mlir", system_desc_path=f"{system_desc_path}")
@ttir_to_ttnn(output_file_name="test_exp.mlir")
@compile_as_mlir_module((128, 128))
def test_exp_ttnn(in0: Operand, builder: TTIRBuilder):
return builder.exp(in0)


@translate_ttnn_to_flatbuffer(output_file_name="test_abs.ttnn")
@ttir_to_ttnn(output_file_name="test_abs.mlir", system_desc_path=f"{system_desc_path}")
@ttir_to_ttnn(output_file_name="test_abs.mlir")
@compile_as_mlir_module((128, 128))
def test_abs_ttnn(in0: Operand, builder: TTIRBuilder):
return builder.abs(in0)


@translate_ttnn_to_flatbuffer(output_file_name="test_logical_not.ttnn")
@ttir_to_ttnn(
output_file_name="test_logical_not.mlir",
system_desc_path=f"{system_desc_path}",
)
@ttir_to_ttnn(output_file_name="test_logical_not.mlir")
@compile_as_mlir_module((128, 128))
def test_logical_not_ttnn(in0: Operand, builder: TTIRBuilder):
return builder.logical_not(in0)


@translate_ttnn_to_flatbuffer(output_file_name="test_neg.ttnn")
@ttir_to_ttnn(output_file_name="test_neg.mlir", system_desc_path=f"{system_desc_path}")
@ttir_to_ttnn(output_file_name="test_neg.mlir")
@compile_as_mlir_module((128, 128))
def test_neg_ttnn(in0: Operand, builder: TTIRBuilder):
return builder.neg(in0)


@translate_ttnn_to_flatbuffer(output_file_name="test_relu.ttnn")
@ttir_to_ttnn(output_file_name="test_relu.mlir", system_desc_path=f"{system_desc_path}")
@ttir_to_ttnn(output_file_name="test_relu.mlir")
@compile_as_mlir_module((128, 128))
def test_relu_ttnn(in0: Operand, builder: TTIRBuilder):
return builder.relu(in0)


@translate_ttnn_to_flatbuffer(output_file_name="test_sqrt.ttnn")
@ttir_to_ttnn(output_file_name="test_sqrt.mlir", system_desc_path=f"{system_desc_path}")
@ttir_to_ttnn(output_file_name="test_sqrt.mlir")
@compile_as_mlir_module((128, 128))
def test_sqrt_ttnn(in0: Operand, builder: TTIRBuilder):
return builder.sqrt(in0)


@translate_ttnn_to_flatbuffer(output_file_name="test_rsqrt.ttnn")
@ttir_to_ttnn(
output_file_name="test_rsqrt.mlir", system_desc_path=f"{system_desc_path}"
)
@ttir_to_ttnn(output_file_name="test_rsqrt.mlir")
@compile_as_mlir_module((128, 128))
def test_rsqrt_ttnn(in0: Operand, builder: TTIRBuilder):
return builder.rsqrt(in0)


@translate_ttnn_to_flatbuffer(output_file_name="test_sigmoid.ttnn")
@ttir_to_ttnn(
output_file_name="test_sigmoid.mlir", system_desc_path=f"{system_desc_path}"
)
@ttir_to_ttnn(output_file_name="test_sigmoid.mlir")
@compile_as_mlir_module((128, 128))
def test_sigmoid_ttnn(in0: Operand, builder: TTIRBuilder):
return builder.sigmoid(in0)


@translate_ttnn_to_flatbuffer(output_file_name="test_reciprocal.ttnn")
@ttir_to_ttnn(
output_file_name="test_reciprocal.mlir", system_desc_path=f"{system_desc_path}"
)
@ttir_to_ttnn(output_file_name="test_reciprocal.mlir")
@compile_as_mlir_module((128, 128))
def test_reciprocal_ttnn(in0: Operand, builder: TTIRBuilder):
return builder.reciprocal(in0)


@translate_ttnn_to_flatbuffer(output_file_name="test_add.ttnn")
@ttir_to_ttnn(output_file_name="test_add.mlir", system_desc_path=f"{system_desc_path}")
@ttir_to_ttnn(output_file_name="test_add.mlir")
@compile_as_mlir_module((64, 128), (64, 128))
def test_add_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.add(in0, in1)


@translate_ttnn_to_flatbuffer(output_file_name="test_multiply.ttnn")
@ttir_to_ttnn(
output_file_name="test_multiply.mlir", system_desc_path=f"{system_desc_path}"
)
@ttir_to_ttnn(output_file_name="test_multiply.mlir")
@compile_as_mlir_module((64, 64), (64, 64))
def test_multiply_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.multiply(in0, in1)


@translate_ttnn_to_flatbuffer(output_file_name="test_logical_and.ttnn")
@ttir_to_ttnn(
output_file_name="test_logical_and.mlir",
system_desc_path=f"{system_desc_path}",
)
@ttir_to_ttnn(output_file_name="test_logical_and.mlir")
@compile_as_mlir_module((64, 64), (64, 64))
def test_logical_and_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.logical_and(in0, in1)


@translate_ttnn_to_flatbuffer(output_file_name="test_logical_or.ttnn")
@ttir_to_ttnn(
output_file_name="test_logical_or.mlir", system_desc_path=f"{system_desc_path}"
)
@ttir_to_ttnn(output_file_name="test_logical_or.mlir")
@compile_as_mlir_module((64, 64), (64, 64))
def test_logical_or_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.logical_or(in0, in1)


@translate_ttnn_to_flatbuffer(output_file_name="test_subtract.ttnn")
@ttir_to_ttnn(
output_file_name="test_subtract.mlir", system_desc_path=f"{system_desc_path}"
)
@ttir_to_ttnn(output_file_name="test_subtract.mlir")
@compile_as_mlir_module((64, 64), (64, 64))
def test_subtract_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.subtract(in0, in1)


@translate_ttnn_to_flatbuffer(output_file_name="test_eq.ttnn")
@ttir_to_ttnn(output_file_name="test_eq.mlir", system_desc_path=f"{system_desc_path}")
@ttir_to_ttnn(output_file_name="test_eq.mlir")
@compile_as_mlir_module((64, 64), (64, 64))
def test_eq_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.eq(in0, in1)


@translate_ttnn_to_flatbuffer(output_file_name="test_ne.ttnn")
@ttir_to_ttnn(output_file_name="test_ne.mlir", system_desc_path=f"{system_desc_path}")
@ttir_to_ttnn(output_file_name="test_ne.mlir")
@compile_as_mlir_module((64, 64), (64, 64))
def test_ne_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.ne(in0, in1)


@translate_ttnn_to_flatbuffer(output_file_name="test_ge.ttnn")
@ttir_to_ttnn(output_file_name="test_ge.mlir", system_desc_path=f"{system_desc_path}")
@ttir_to_ttnn(output_file_name="test_ge.mlir")
@compile_as_mlir_module((64, 64), (64, 64))
def test_ge_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.ge(in0, in1)


@translate_ttnn_to_flatbuffer(output_file_name="test_gt.ttnn")
@ttir_to_ttnn(output_file_name="test_gt.mlir", system_desc_path=f"{system_desc_path}")
@ttir_to_ttnn(output_file_name="test_gt.mlir")
@compile_as_mlir_module((64, 64), (64, 64))
def test_gt_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.gt(in0, in1)


@translate_ttnn_to_flatbuffer(output_file_name="test_le.ttnn")
@ttir_to_ttnn(output_file_name="test_le.mlir", system_desc_path=f"{system_desc_path}")
@ttir_to_ttnn(output_file_name="test_le.mlir")
@compile_as_mlir_module((64, 64), (64, 64))
def test_le_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.le(in0, in1)


@translate_ttnn_to_flatbuffer(output_file_name="test_lt.ttnn")
@ttir_to_ttnn(output_file_name="test_lt.mlir", system_desc_path=f"{system_desc_path}")
@ttir_to_ttnn(output_file_name="test_lt.mlir")
@compile_as_mlir_module((64, 64), (64, 64))
def test_lt_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.lt(in0, in1)


@translate_ttnn_to_flatbuffer(output_file_name="test_div.ttnn")
@ttir_to_ttnn(output_file_name="test_div.mlir", system_desc_path=f"{system_desc_path}")
@ttir_to_ttnn(output_file_name="test_div.mlir")
@compile_as_mlir_module((64, 64), (64, 64))
def test_div_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.div(in0, in1)


@translate_ttnn_to_flatbuffer(output_file_name="test_maximum.ttnn")
@ttir_to_ttnn(
output_file_name="test_maximum.mlir", system_desc_path=f"{system_desc_path}"
)
@ttir_to_ttnn(output_file_name="test_maximum.mlir")
@compile_as_mlir_module((64, 64), (64, 64))
def test_maximum_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.maximum(in0, in1)


@translate_ttnn_to_flatbuffer(output_file_name="test_arbitrary_op_chain.ttnn")
@ttir_to_ttnn(
output_file_name="test_arbitrary_op_chain.mlir",
system_desc_path=f"{system_desc_path}",
)
@ttir_to_ttnn(output_file_name="test_arbitrary_op_chain.mlir")
@compile_as_mlir_module((32, 32), (32, 32), (32, 32))
def test_arbitrary_op_chain_ttnn(
in0: Operand, in1: Operand, in2: Operand, builder: TTIRBuilder
Expand Down
21 changes: 18 additions & 3 deletions python/test_infra/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0

import os
from typing import Callable, Dict, Tuple
from typing import Callable, Dict, Tuple, Optional

import torch
from ttmlir.dialects import func
Expand Down Expand Up @@ -193,7 +193,7 @@ def decorated_func(*inputs):
def ttir_to_ttnn(
dump_to_file: bool = True,
output_file_name: str = "test.mlir",
system_desc_path: str = "",
system_desc_path: Optional[str] = None,
):
"""
Converts TTIR module to TTNN module and optionally dumps to file.
Expand All @@ -209,8 +209,13 @@ def ttir_to_ttnn(
Name of the output file.
"""

# Default to the `SYSTEM_DESC_PATH` envvar
if system_desc_path is None:
system_desc_path = os.getenv("SYSTEM_DESC_PATH", "")

def decorator(fn: Callable):
def wrapper(*args, **kwargs):

# First, call the decorated function to get the MLIR module.
module = fn(*args, **kwargs)

Expand Down Expand Up @@ -243,7 +248,7 @@ def ttir_to_ttmetal(
dump_to_file: bool = True,
output_file_name: str = "test.mlir",
return_module: bool = False,
system_desc_path: str = "",
system_desc_path: Optional[str] = None,
):
"""
Converts TTIR module to TTMetal module and optionally dumps to file.
Expand All @@ -264,16 +269,26 @@ def ttir_to_ttmetal(
accommodate both `ttmetal_to_flatbuffer` and `translate_ttmetal_to_flatbuffer`.
"""

# Default to the `SYSTEM_DESC_PATH` envvar
if system_desc_path is None:
system_desc_path = os.getenv("SYSTEM_DESC_PATH", "")

def decorator(fn: Callable):
def wrapper(*args, **kwargs):

if dump_to_file is not None:
print("yay")

# First, call the decorated function to get the MLIR module.
module = fn(*args, **kwargs)


assert isinstance(module, Module), (
f"Make sure this decorator is used on top of "
f"`compile_as_mlir_module` decorator."
)


# Now, pass it through the TTIR to TTMetal pipeline. Module gets
# modified in place.
ttir_to_ttmetal_backend_pipeline(
Expand Down

0 comments on commit 6ef8ed8

Please sign in to comment.