diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index 7896c37d3..483e58f71 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -72,7 +72,9 @@ declare_mlir_python_sources(TTMLIRPythonSources.Passes declare_mlir_python_sources(TTMLIRPythonTestInfra.TestInfra ROOT_DIR "${TTMLIR_PYTHON_TEST_INFRA_ROOT_DIR}" ADD_TO_PARENT TTMLIRPythonTestInfra - SOURCES ttir_builder.py + SOURCES + ttir_builder.py + test_utils.py ) declare_mlir_python_extension(TTMLIRPythonExtensions.Main diff --git a/python/Passes.cpp b/python/Passes.cpp index c6010bee4..70fbbd667 100644 --- a/python/Passes.cpp +++ b/python/Passes.cpp @@ -5,7 +5,9 @@ #include "mlir/InitAllTranslations.h" #include "ttmlir/Bindings/Python/TTMLIRModule.h" #include "ttmlir/RegisterAll.h" +#include "ttmlir/Target/TTMetal/TTMetalToFlatbuffer.h" #include "ttmlir/Target/TTNN/TTNNToFlatbuffer.h" +#include PYBIND11_MAKE_OPAQUE(std::shared_ptr<void>); @@ -117,6 +119,30 @@ void populatePassesModule(py::module &m) { }, py::arg("module"), py::arg("options") = ""); + m.def( + "ttir_to_ttmetal_backend_pipeline", + [](MlirModule module, std::string options = "") { + mlir::Operation *moduleOp = unwrap(mlirModuleGetOperation(module)); + mlir::PassManager pm(moduleOp->getName()); + mlir::DialectRegistry registry; + mlir::tt::registerAllDialects(registry); + mlir::tt::registerAllExtensions(registry); + mlir::MLIRContext *ctx = unwrap(mlirModuleGetContext(module)); + ctx->appendDialectRegistry(registry); + const auto *pipeline = + mlir::PassPipelineInfo::lookup("ttir-to-ttmetal-backend-pipeline"); + mlir::function_ref<mlir::LogicalResult(const llvm::Twine &)> + err_handler = + [](const llvm::Twine &loc) { return mlir::failure(); }; + if (mlir::failed(pipeline->addToPipeline(pm, options, err_handler))) { + throw std::runtime_error("Failed to add pipeline to pass manager"); + } + if (mlir::failed(pm.run(moduleOp))) { + throw std::runtime_error("Failed to run pass manager"); + } + }, + py::arg("module"), py::arg("options") = ""); + py::class_<std::shared_ptr<void>>(m, "SharedVoidPtr") .def(py::init<>()) .def("from_ttnn", [](std::shared_ptr<void> data, MlirModule module) { @@ -154,6 +180,22 @@ void populatePassesModule(py::module &m) { filepath); } }); + + m.def("ttmetal_to_flatbuffer_file", + [](MlirModule module, std::string &filepath) { + mlir::Operation *moduleOp = unwrap(mlirModuleGetOperation(module)); + std::error_code fileError; + llvm::raw_fd_ostream file(filepath, fileError); + if (fileError) { + throw std::runtime_error("Failed to open file: " + filepath + + ". 
Error: " + fileError.message()); + } + if (mlir::failed(mlir::tt::ttmetal::translateTTMetalToFlatbuffer( + moduleOp, file))) { + throw std::runtime_error("Failed to write flatbuffer to file: " + + filepath); + } + }); } } // namespace mlir::ttmlir::python diff --git a/python/test_infra/test_ttir_ops_ttmetal.py b/python/test_infra/test_ttir_ops_ttmetal.py new file mode 100644 index 000000000..c166c3519 --- /dev/null +++ b/python/test_infra/test_ttir_ops_ttmetal.py @@ -0,0 +1,72 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +# RUN: %python %s + +import inspect +import os + +from ttmlir.test_utils import ( + compile_as_mlir_module, + translate_ttnn_to_flatbuffer, + ttir_to_ttnn, + translate_ttmetal_to_flatbuffer, + ttir_to_ttmetal, +) +from ttmlir.ttir_builder import Operand, TTIRBuilder + +system_desc_path = os.getenv("SYSTEM_DESC_PATH", "") + + +@translate_ttmetal_to_flatbuffer(output_file_name="test_exp.ttm") +@ttir_to_ttmetal( + output_file_name="test_exp.mlir", system_desc_path=f"{system_desc_path}" +) +@compile_as_mlir_module((128, 128)) +def test_exp_ttmetal(in0: Operand, builder: TTIRBuilder): + return builder.exp(in0) + + +@translate_ttmetal_to_flatbuffer(output_file_name="test_add.ttm") +@ttir_to_ttmetal( + output_file_name="test_add.mlir", system_desc_path=f"{system_desc_path}" +) +@compile_as_mlir_module((64, 128), (64, 128)) +def test_add_ttmetal(in0: Operand, in1: Operand, builder: TTIRBuilder): + return builder.add(in0, in1) + + +@translate_ttmetal_to_flatbuffer(output_file_name="test_multiply.ttm") +@ttir_to_ttmetal( + output_file_name="test_multiply.mlir", system_desc_path=f"{system_desc_path}" +) +@compile_as_mlir_module((64, 64), (64, 64)) +def test_multiply_ttmetal(in0: Operand, in1: Operand, builder: TTIRBuilder): + return builder.multiply(in0, in1) + + +@translate_ttmetal_to_flatbuffer(output_file_name="test_arbitrary_op_chain.ttm") +@ttir_to_ttmetal( + output_file_name="test_arbitrary_op_chain.mlir", + system_desc_path=f"{system_desc_path}", +) +@compile_as_mlir_module((32, 32), (32, 32), (32, 32)) +def test_arbitrary_op_chain_ttmetal( + in0: Operand, in1: Operand, in2: Operand, builder: TTIRBuilder +): + add = builder.add(in0, in1) + exp = builder.exp(in2) + mul = builder.multiply(add, exp) + in3 = builder.empty(builder.get_shape(mul)) + return builder.multiply(mul, in3) + + +if __name__ == "__main__": + test_functions = inspect.getmembers( + inspect.getmodule(inspect.currentframe()), inspect.isfunction + ) + + for function_name, func in test_functions: + if function_name.startswith("test_"): + func() diff --git a/python/test_infra/test_ttir_ops_ttnn.py b/python/test_infra/test_ttir_ops_ttnn.py new file mode 100644 index 000000000..7e3ba001c --- /dev/null +++ b/python/test_infra/test_ttir_ops_ttnn.py @@ -0,0 +1,217 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +# RUN: %python %s + +import inspect +import os + +from ttmlir.test_utils import ( + compile_as_mlir_module, + translate_ttnn_to_flatbuffer, + ttir_to_ttnn, + translate_ttmetal_to_flatbuffer, + ttir_to_ttmetal, +) +from ttmlir.ttir_builder import Operand, TTIRBuilder + +system_desc_path = os.getenv("SYSTEM_DESC_PATH", "") + + +@translate_ttnn_to_flatbuffer(output_file_name="test_exp.ttnn") +@ttir_to_ttnn(output_file_name="test_exp.mlir", system_desc_path=f"{system_desc_path}") +@compile_as_mlir_module((128, 128)) +def test_exp_ttnn(in0: Operand, builder: TTIRBuilder): + return builder.exp(in0) + + 
+@translate_ttnn_to_flatbuffer(output_file_name="test_abs.ttnn") +@ttir_to_ttnn(output_file_name="test_abs.mlir", system_desc_path=f"{system_desc_path}") +@compile_as_mlir_module((128, 128)) +def test_abs_ttnn(in0: Operand, builder: TTIRBuilder): + return builder.abs(in0) + + +@translate_ttnn_to_flatbuffer(output_file_name="test_logical_not.ttnn") +@ttir_to_ttnn( + output_file_name="test_logical_not.mlir", + system_desc_path=f"{system_desc_path}", +) +@compile_as_mlir_module((128, 128)) +def test_logical_not_ttnn(in0: Operand, builder: TTIRBuilder): + return builder.logical_not(in0) + + +@translate_ttnn_to_flatbuffer(output_file_name="test_neg.ttnn") +@ttir_to_ttnn(output_file_name="test_neg.mlir", system_desc_path=f"{system_desc_path}") +@compile_as_mlir_module((128, 128)) +def test_neg_ttnn(in0: Operand, builder: TTIRBuilder): + return builder.neg(in0) + + +@translate_ttnn_to_flatbuffer(output_file_name="test_relu.ttnn") +@ttir_to_ttnn(output_file_name="test_relu.mlir", system_desc_path=f"{system_desc_path}") +@compile_as_mlir_module((128, 128)) +def test_relu_ttnn(in0: Operand, builder: TTIRBuilder): + return builder.relu(in0) + + +@translate_ttnn_to_flatbuffer(output_file_name="test_sqrt.ttnn") +@ttir_to_ttnn(output_file_name="test_sqrt.mlir", system_desc_path=f"{system_desc_path}") +@compile_as_mlir_module((128, 128)) +def test_sqrt_ttnn(in0: Operand, builder: TTIRBuilder): + return builder.sqrt(in0) + + +@translate_ttnn_to_flatbuffer(output_file_name="test_rsqrt.ttnn") +@ttir_to_ttnn( + output_file_name="test_rsqrt.mlir", system_desc_path=f"{system_desc_path}" +) +@compile_as_mlir_module((128, 128)) +def test_rsqrt_ttnn(in0: Operand, builder: TTIRBuilder): + return builder.rsqrt(in0) + + +@translate_ttnn_to_flatbuffer(output_file_name="test_sigmoid.ttnn") +@ttir_to_ttnn( + output_file_name="test_sigmoid.mlir", system_desc_path=f"{system_desc_path}" +) +@compile_as_mlir_module((128, 128)) +def test_sigmoid_ttnn(in0: Operand, builder: TTIRBuilder): + return builder.sigmoid(in0) + + +@translate_ttnn_to_flatbuffer(output_file_name="test_reciprocal.ttnn") +@ttir_to_ttnn( + output_file_name="test_reciprocal.mlir", system_desc_path=f"{system_desc_path}" +) +@compile_as_mlir_module((128, 128)) +def test_reciprocal_ttnn(in0: Operand, builder: TTIRBuilder): + return builder.reciprocal(in0) + + +@translate_ttnn_to_flatbuffer(output_file_name="test_add.ttnn") +@ttir_to_ttnn(output_file_name="test_add.mlir", system_desc_path=f"{system_desc_path}") +@compile_as_mlir_module((64, 128), (64, 128)) +def test_add_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder): + return builder.add(in0, in1) + + +@translate_ttnn_to_flatbuffer(output_file_name="test_multiply.ttnn") +@ttir_to_ttnn( + output_file_name="test_multiply.mlir", system_desc_path=f"{system_desc_path}" +) +@compile_as_mlir_module((64, 64), (64, 64)) +def test_multiply_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder): + return builder.multiply(in0, in1) + + +@translate_ttnn_to_flatbuffer(output_file_name="test_logical_and.ttnn") +@ttir_to_ttnn( + output_file_name="test_logical_and.mlir", + system_desc_path=f"{system_desc_path}", +) +@compile_as_mlir_module((64, 64), (64, 64)) +def test_logical_and_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder): + return builder.logical_and(in0, in1) + + +@translate_ttnn_to_flatbuffer(output_file_name="test_logical_or.ttnn") +@ttir_to_ttnn( + output_file_name="test_logical_or.mlir", system_desc_path=f"{system_desc_path}" +) +@compile_as_mlir_module((64, 64), (64, 64)) +def 
test_logical_or_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder): + return builder.logical_or(in0, in1) + + +@translate_ttnn_to_flatbuffer(output_file_name="test_subtract.ttnn") +@ttir_to_ttnn( + output_file_name="test_subtract.mlir", system_desc_path=f"{system_desc_path}" +) +@compile_as_mlir_module((64, 64), (64, 64)) +def test_subtract_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder): + return builder.subtract(in0, in1) + + +@translate_ttnn_to_flatbuffer(output_file_name="test_eq.ttnn") +@ttir_to_ttnn(output_file_name="test_eq.mlir", system_desc_path=f"{system_desc_path}") +@compile_as_mlir_module((64, 64), (64, 64)) +def test_eq_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder): + return builder.eq(in0, in1) + + +@translate_ttnn_to_flatbuffer(output_file_name="test_ne.ttnn") +@ttir_to_ttnn(output_file_name="test_ne.mlir", system_desc_path=f"{system_desc_path}") +@compile_as_mlir_module((64, 64), (64, 64)) +def test_ne_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder): + return builder.ne(in0, in1) + + +@translate_ttnn_to_flatbuffer(output_file_name="test_ge.ttnn") +@ttir_to_ttnn(output_file_name="test_ge.mlir", system_desc_path=f"{system_desc_path}") +@compile_as_mlir_module((64, 64), (64, 64)) +def test_ge_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder): + return builder.ge(in0, in1) + + +@translate_ttnn_to_flatbuffer(output_file_name="test_gt.ttnn") +@ttir_to_ttnn(output_file_name="test_gt.mlir", system_desc_path=f"{system_desc_path}") +@compile_as_mlir_module((64, 64), (64, 64)) +def test_gt_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder): + return builder.gt(in0, in1) + + +@translate_ttnn_to_flatbuffer(output_file_name="test_le.ttnn") +@ttir_to_ttnn(output_file_name="test_le.mlir", system_desc_path=f"{system_desc_path}") +@compile_as_mlir_module((64, 64), (64, 64)) +def test_le_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder): + return builder.le(in0, in1) + + +@translate_ttnn_to_flatbuffer(output_file_name="test_lt.ttnn") +@ttir_to_ttnn(output_file_name="test_lt.mlir", system_desc_path=f"{system_desc_path}") +@compile_as_mlir_module((64, 64), (64, 64)) +def test_lt_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder): + return builder.lt(in0, in1) + + +@translate_ttnn_to_flatbuffer(output_file_name="test_div.ttnn") +@ttir_to_ttnn(output_file_name="test_div.mlir", system_desc_path=f"{system_desc_path}") +@compile_as_mlir_module((64, 64), (64, 64)) +def test_div_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder): + return builder.div(in0, in1) + + +@translate_ttnn_to_flatbuffer(output_file_name="test_maximum.ttnn") +@ttir_to_ttnn( + output_file_name="test_maximum.mlir", system_desc_path=f"{system_desc_path}" +) +@compile_as_mlir_module((64, 64), (64, 64)) +def test_maximum_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder): + return builder.maximum(in0, in1) + + +@translate_ttnn_to_flatbuffer(output_file_name="test_arbitrary_op_chain.ttnn") +@ttir_to_ttnn( + output_file_name="test_arbitrary_op_chain.mlir", + system_desc_path=f"{system_desc_path}", +) +@compile_as_mlir_module((32, 32), (32, 32), (32, 32)) +def test_arbitrary_op_chain_ttnn( + in0: Operand, in1: Operand, in2: Operand, builder: TTIRBuilder +): + add = builder.add(in0, in1) + exp = builder.exp(in2) + return builder.multiply(add, exp) + + +if __name__ == "__main__": + test_functions = inspect.getmembers( + inspect.getmodule(inspect.currentframe()), inspect.isfunction + ) + + for function_name, func in test_functions: + if function_name.startswith("test_"): + func() diff 
--git a/python/test_infra/test_utils.py b/python/test_infra/test_utils.py new file mode 100644 index 000000000..b5e6f36aa --- /dev/null +++ b/python/test_infra/test_utils.py @@ -0,0 +1,439 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +import os +from typing import Callable, Dict, Tuple + +import torch +from ttmlir.dialects import func +from ttmlir.ir import * +from ttmlir.passes import ( + ttir_to_ttnn_backend_pipeline, + ttnn_to_flatbuffer_file, + ttir_to_ttmetal_backend_pipeline, + ttmetal_to_flatbuffer_file, +) + +from .ttir_builder import Golden, Operand, Shape, TTIRBuilder + +TT_MLIR_HOME = os.environ.get("TT_MLIR_HOME", "") + + +# ----- Static helpers used in this file only ----- + + +def _dump_module(module: Module) -> None: + """Just prints the module to console.""" + print(module) + + +def _run_ttmlir_translate_ttmetal( + input_file_name: str, output_file_name: str = "ttmetal_fb.ttm" +): + """ + Utility function that runs the `ttmlir-translate` tool on a file containing a + dumped TTMetal module. It produces the flatbuffer file `output_file_name`. + """ + import subprocess + + res = subprocess.run( + " ".join( + [ + "ttmlir-translate", + "--ttmetal-to-flatbuffer", + input_file_name, + "-o", + output_file_name, + ] + ), + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) + assert ( + res.returncode == 0 + ), f"Running ttmlir-translate failed with: {res.stdout.decode('utf-8')}" + return res + + +def _run_ttmlir_translate_ttnn( + input_file_name: str, output_file_name: str = "ttnn_fb.ttnn" +): + """ + Utility function that runs the `ttmlir-translate` tool on a file containing a + dumped TTNN module. It produces the flatbuffer file `output_file_name`. + """ + import subprocess + + res = subprocess.run( + " ".join( + [ + "ttmlir-translate", + "--ttnn-to-flatbuffer", + input_file_name, + "-o", + output_file_name, + ] + ), + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) + assert ( + res.returncode == 0 + ), f"Running ttmlir-translate failed with: {res.stdout.decode('utf-8')}" + return res + + +# ----- Decorators for doing passes and compiling to flatbuffer ----- + + +def compile_as_mlir_module( + *inputs_shapes: Tuple[Shape], + module_dump: bool = False, +): + """ + Decorator to define an MLIR module specified as a Python function. + + It will wrap the decorated test function in an MLIR FuncOp, wrap that in an MLIR + module, and finally tie the arguments of that FuncOp to the test function's inputs. + It will also pass a `TTIRBuilder` object as the last argument of the test function. + + Arguments + --------- + inputs_shapes: Tuple[Shape] + Shapes of the respective ranked tensor inputs of the test function. + + module_dump: bool + Set to True to print out the generated MLIR module. + + Returns + ------- + MLIR module containing the MLIR op graph defined by the decorated test function. + + Example + ------- + + ```python + @compile_as_mlir_module((32, 32), (32, 32)) + def test_add(in0: Operand, in1: Operand, builder: TTIRBuilder): + return builder.add(in0, in1) + + + test_add() # NOTE Called without arguments. + ``` + + which returns + + ``` + #any = #tt.operand_constraint<...> + module { + func.func @test_add( + %arg0: tensor<32x32xf32>, + %arg1: tensor<32x32xf32> + ) -> tensor<32x32xf32> { + %0 = tensor.empty() : tensor<32x32xf32> + %1 = "ttir.add"(%arg0, %arg1, %0) ... 
+ return %1 : tensor<32x32xf32> + } + } + ``` + + Check out: + https://github.com/llvm/llvm-project/blob/main/mlir/test/python/dialects/tensor.py + """ + + def decorator(test_fn: Callable): + # test_fn should be called with no args. + def wrapper(): + ctx = Context() + loc = Location.unknown(ctx) + # Instantiate builder which is passed as the last argument to + # `test_fn` so the user can use it to build ops. + builder = TTIRBuilder(ctx, loc) + + with ctx, loc: + test_fn_input_types = [ + builder.ranked_tensor_type(input_shape) + for input_shape in inputs_shapes + ] + + # Wrap everything in an MLIR module. + module = Module.create() + + with InsertionPoint(module.body): + # Wrap everything in an MLIR function. + @func.func(*test_fn_input_types, name=test_fn.__name__) + def decorated_func(*inputs): + # Randomly generate golden tensors for function inputs. + for i in inputs: + builder.generate_and_store_random_golden(i) + + return test_fn(*inputs, builder=builder) + + print( + f"`{test_fn.__name__}` successfully transformed into an MLIR module." + ) + + if module_dump: + _dump_module(module) + + return module + + return wrapper + + return decorator + + +def ttir_to_ttnn( + dump_to_file: bool = True, + output_file_name: str = "test.mlir", + system_desc_path: str = "", +): + """ + Converts a TTIR module to a TTNN module and optionally dumps it to file. + + Wrapper around `ttir_to_ttnn_backend_pipeline` pybound pass. + + Arguments + --------- + dump_to_file: bool + Flag which indicates whether the generated TTNN module will be dumped to file. + + output_file_name: str + Name of the output file. + + system_desc_path: str + Path to the system descriptor file, forwarded to the pipeline as the + `system-desc-path` option. + """ + + def decorator(fn: Callable): + def wrapper(*args, **kwargs): + # First, call the decorated function to get the MLIR module. + module = fn(*args, **kwargs) + + assert isinstance(module, Module), ( + f"Make sure this decorator is used on top of " + f"`compile_as_mlir_module` decorator." + ) + + # Now, pass it through the TTIR to TTNN pipeline. Module gets + # modified in place. + ttir_to_ttnn_backend_pipeline( + module, f"system-desc-path={system_desc_path}" + ) + + print("`ttir_to_ttnn_backend_pipeline` passed successfully.") + + # Optionally dump to file. + if dump_to_file: + with open(output_file_name, "w") as f: + f.write(str(module)) + + return output_file_name + + return wrapper + + return decorator + + +def ttir_to_ttmetal( + dump_to_file: bool = True, + output_file_name: str = "test.mlir", + return_module: bool = False, + system_desc_path: str = "", +): + """ + Converts a TTIR module to a TTMetal module and optionally dumps it to file. + + Wrapper around `ttir_to_ttmetal_backend_pipeline` pybound pass. + + Arguments + --------- + dump_to_file: bool + Flag which indicates whether the generated TTMetal module will be dumped to file. + + output_file_name: str + Name of the output file. + + return_module: bool + If True, return the generated module; otherwise return the name of the file + to which the module was dumped (i.e. `output_file_name`). Exists only to + accommodate both `ttmetal_to_flatbuffer` and `translate_ttmetal_to_flatbuffer`. + + system_desc_path: str + Path to the system descriptor file, forwarded to the pipeline as the + `system-desc-path` option. + """ + + def decorator(fn: Callable): + def wrapper(*args, **kwargs): + # First, call the decorated function to get the MLIR module. + module = fn(*args, **kwargs) + + assert isinstance(module, Module), ( + f"Make sure this decorator is used on top of " + f"`compile_as_mlir_module` decorator." + ) + + # Now, pass it through the TTIR to TTMetal pipeline. Module gets + # modified in place. 
+ ttir_to_ttmetal_backend_pipeline( + module, f"system-desc-path={system_desc_path}" + ) + + print("`ttir_to_ttmetal_backend_pipeline` passed successfully.") + + # Optionally dump to file. + if dump_to_file: + with open(output_file_name, "w") as f: + f.write(str(module)) + + return module if return_module else output_file_name + + return wrapper + + return decorator + + +def ttmetal_to_flatbuffer( + output_file_name: str = "ttmetal_fb.ttmg", golden_info: Dict[Operand, Golden] = None +): + """ + NOTE NOT WORKING, DO NOT USE. + + Converts a TTMetal module to flatbuffer and saves it to file. Meant to be used as a + decorator on top of the `ttir_to_ttmetal` decorator. Take note that `ttir_to_ttmetal` + has to return a module instead of a file name if decorated with this decorator. + + Wrapper around `ttmetal_to_flatbuffer_file` pybound pass. + + TODO Optional golden info is passed to be embedded in flatbuffer as well. + + TODO Decorating a test function with this, i.e. calling + `ttmetal_to_flatbuffer_file`, will result in + + 'LLVM ERROR: Building op `emitc.constant` but it isn't known in this MLIRContext: + the dialect may not be loaded or this operation hasn't been added by the dialect.' + + To circumvent this, `ttmlir-translate` is run on the file that + `ttir_to_ttmetal_backend_pipeline` produces, to generate the TTMetal flatbuffer file + which this decorator was supposed to generate. Use `translate_ttmetal_to_flatbuffer` + to achieve this, and make `ttir_to_ttmetal` return the file name instead of the module. + """ + + def decorator(test_fn: Callable): + def wrapper(*args, **kwargs): + # Get the TTMetal module by calling the wrapped function. + module = test_fn(*args, **kwargs) + + assert isinstance(module, Module), ( + f"Make sure `ttir_to_ttmetal` which was decorated with this function " + f"returns a module, not a file name." + ) + + # Convert to flatbuffer file. + ttmetal_to_flatbuffer_file(module, output_file_name) + + print("`ttmetal_to_flatbuffer_file` passed successfully.") + + return wrapper + + return decorator + + +def translate_ttmetal_to_flatbuffer(output_file_name: str = "ttmetal_fb.ttm"): + """ + NOTE Substitutes the `ttmetal_to_flatbuffer` decorator. + + Runs `ttmlir-translate` on the input file to produce the TTMetal flatbuffer file + `output_file_name`. Meant to be used as a decorator on top of the `ttir_to_ttmetal` + decorator. Take note that `ttir_to_ttmetal` has to return a file name instead of a + module if decorated with this decorator. + + Wrapper around the `ttmlir-translate` call. + + Example + ------- + + ```python + @translate_ttmetal_to_flatbuffer(output_file_name="ttmetal_fb_test_add.ttm") + @ttir_to_ttmetal(dump_to_file=True, output_file_name="test_add.mlir", return_module=False) + @compile_as_mlir_module((32, 32), (32, 32)) + def test_add(in0: Operand, in1: Operand, builder: TTIRBuilder): + # CHECK: %0 = tensor.empty() : tensor<32x32xf32> + # CHECK: %1 = "ttir.add"(%arg0, %arg1, %0) + # CHECK: return %1 : tensor<32x32xf32> + + return builder.add(in0, in1) + ``` + """ + + def decorator(fn: Callable): + def wrapper(*args, **kwargs): + input_file_name = fn(*args, **kwargs) + + assert isinstance(input_file_name, str) and os.path.isfile( + input_file_name + ), ( + f"Make sure `ttir_to_ttmetal` which was decorated with this function " + f"returns a file name, not a module." + ) + + res = _run_ttmlir_translate_ttmetal(input_file_name, output_file_name) + + print( + f"Flatbuffer file for TTMetalBinary {output_file_name} successfully generated." 
+ ) + + return res.returncode + + return wrapper + + return decorator + + +def translate_ttnn_to_flatbuffer(output_file_name: str = "ttnn_fb.ttnn"): + """ + Runs `ttmlir-translate` on the input file to produce the TTNN flatbuffer file + `output_file_name`. Meant to be used as a decorator on top of the `ttir_to_ttnn` + decorator. + + Wrapper around the `ttmlir-translate` call. + + Example + ------- + + ```python + @translate_ttnn_to_flatbuffer(output_file_name="ttnn_fb_test_add.ttnn") + @ttir_to_ttnn(dump_to_file=True, output_file_name="test_add.mlir") + @compile_as_mlir_module((32, 32), (32, 32)) + def test_add(in0: Operand, in1: Operand, builder: TTIRBuilder): + # CHECK: %0 = tensor.empty() : tensor<32x32xf32> + # CHECK: %1 = "ttir.add"(%arg0, %arg1, %0) + # CHECK: return %1 : tensor<32x32xf32> + + return builder.add(in0, in1) + ``` + """ + + def decorator(fn: Callable): + def wrapper(*args, **kwargs): + input_file_name = fn(*args, **kwargs) + assert isinstance(input_file_name, str) and os.path.isfile( + input_file_name + ), ( + f"Make sure `ttir_to_ttnn` which was decorated with this function " + f"returns a file name, not a module." + ) + + res = _run_ttmlir_translate_ttnn(input_file_name, output_file_name) + + print( + f"Flatbuffer file for TTNNBinary {output_file_name} successfully generated." + ) + + return res.returncode + + return wrapper + + return decorator diff --git a/python/test_infra/ttir_builder.py b/python/test_infra/ttir_builder.py index 0803439b9..122cf576d 100644 --- a/python/test_infra/ttir_builder.py +++ b/python/test_infra/ttir_builder.py @@ -59,6 +59,10 @@ def __init__(self, ctx: Context, location: Location): # ----- Public helpers ----- + @property + def goldens(self) -> Dict[Operand, Golden]: + return self._goldens + def print_goldens(self) -> None: """ Prints saved operands and their respective goldens in descriptive form @@ -231,80 +235,194 @@ def empty( return op # ----- TTIR op factories ----- - - def add(self, in0: Operand, in1: Operand) -> OpView: - """Convenience wrapper constructing `ttir.AddOp`.""" - assert self.get_shape(in0) == self.get_shape( - in1 - ), "Elementwise `ttir.add` op expects inputs of same shape." 
- + def eltwise_proxy( + self, op_golden_function: Callable, op_ttir_function: Callable, inputs: List[Operand] + ) -> OpView: + """Generic helper constructing an elementwise TTIR op and its golden tensor.""" with self._ctx, self._loc: - output = self.empty(self.get_shape(in0)) + output = self.empty(self.get_shape(inputs[0])) - op = ttir.AddOp( + op = op_ttir_function( [self._get_type(output)], - [in0, in1], + inputs, [output], self._get_operand_constraint_attr(3), ) - golden = Golden( - torch.add(self._get_golden_tensor(in0), self._get_golden_tensor(in1)) - ) + goldens = [] + for inp in inputs: + goldens.append(self._get_golden_tensor(inp)) + + golden = Golden(op_golden_function(*goldens)) self._store_golden(op, golden) self._override_golden(output, golden) return op + def exp(self, in0: Operand) -> OpView: + return self.eltwise_proxy(torch.exp, ttir.ExpOp, [in0]) + + def abs(self, in0: Operand) -> OpView: + return self.eltwise_proxy(torch.abs, ttir.AbsOp, [in0]) + + def logical_not(self, in0: Operand) -> OpView: + return self.eltwise_proxy(torch.logical_not, ttir.LogicalNotOp, [in0]) + + def neg(self, in0: Operand) -> OpView: + return self.eltwise_proxy(torch.neg, ttir.NegOp, [in0]) + + def relu(self, in0: Operand) -> OpView: + return self.eltwise_proxy(torch.relu, ttir.ReluOp, [in0]) + + def sqrt(self, in0: Operand) -> OpView: + return self.eltwise_proxy(torch.sqrt, ttir.SqrtOp, [in0]) + + def rsqrt(self, in0: Operand) -> OpView: + return self.eltwise_proxy(torch.rsqrt, ttir.RsqrtOp, [in0]) + + def sigmoid(self, in0: Operand) -> OpView: + return self.eltwise_proxy(torch.sigmoid, ttir.SigmoidOp, [in0]) + + def reciprocal(self, in0: Operand) -> OpView: + return self.eltwise_proxy(torch.reciprocal, ttir.ReciprocalOp, [in0]) + + def add(self, in0: Operand, in1: Operand) -> OpView: + return self.eltwise_proxy(torch.add, ttir.AddOp, [in0, in1]) + def multiply(self, in0: Operand, in1: Operand) -> OpView: - """Convenience wrapper constructing `ttir.MultiplyOp`.""" - assert self.get_shape(in0) == self.get_shape( - in1 - ), "Elementwise `ttir.multiply` op expects inputs of same shape." 
+ return self.eltwise_proxy(torch.multiply, ttir.MultiplyOp, [in0, in1]) - with self._ctx, self._loc: - output = self.empty(self.get_shape(in0)) + def logical_and(self, in0: Operand, in1: Operand) -> OpView: + return self.eltwise_proxy(torch.logical_and, ttir.LogicalAndOp, [in0, in1]) - op = ttir.MultiplyOp( - [self._get_type(output)], - [in0, in1], - [output], - self._get_operand_constraint_attr(3), - ) + def logical_or(self, in0: Operand, in1: Operand) -> OpView: + return self.eltwise_proxy(torch.logical_or, ttir.LogicalOrOp, [in0, in1]) - golden = Golden( - torch.multiply( - self._get_golden_tensor(in0), self._get_golden_tensor(in1) - ) - ) - self._store_golden(op, golden) - self._override_golden(output, golden) + def subtract(self, in0: Operand, in1: Operand) -> OpView: + return self.eltwise_proxy(torch.subtract, ttir.SubtractOp, [in0, in1]) - return op + def eq(self, in0: Operand, in1: Operand) -> OpView: + return self.eltwise_proxy(torch.eq, ttir.EqualOp, [in0, in1]) - def exp(self, in0: Operand) -> OpView: - """Convenience wrapper constructing `ttir.ExpOp`.""" - with self._ctx, self._loc: - output = self.empty(self.get_shape(in0)) + def ne(self, in0: Operand, in1: Operand) -> OpView: + return self.eltwise_proxy(torch.ne, ttir.NotEqualOp, [in0, in1]) - op = ttir.ExpOp( - [self._get_type(output)], - [in0], - [output], - self._get_operand_constraint_attr(3), - ) + def ge(self, in0: Operand, in1: Operand) -> OpView: + return self.eltwise_proxy(torch.ge, ttir.GreaterEqualOp, [in0, in1]) - golden = Golden(torch.exp(self._get_golden_tensor(in0))) - self._store_golden(op, golden) - self._override_golden(output, golden) + def gt(self, in0: Operand, in1: Operand) -> OpView: + return self.eltwise_proxy(torch.gt, ttir.GreaterThanOp, [in0, in1]) - return op + def le(self, in0: Operand, in1: Operand) -> OpView: + return self.eltwise_proxy(torch.le, ttir.LessEqualOp, [in0, in1]) + + def lt(self, in0: Operand, in1: Operand) -> OpView: + return self.eltwise_proxy(torch.lt, ttir.LessThanOp, [in0, in1]) + + def div(self, in0: Operand, in1: Operand) -> OpView: + return self.eltwise_proxy(torch.div, ttir.DivOp, [in0, in1]) + + def maximum(self, in0: Operand, in1: Operand) -> OpView: + return self.eltwise_proxy(torch.maximum, ttir.MaximumOp, [in0, in1]) + + +def compile_as_mlir_module( + *inputs_shapes: Tuple[Shape], + module_dump: bool = True, + golden_dump: bool = False, +): + """ + Decorator to define a MLIR module specified as a python function. + + It will wrap decorated test function in a MLIR FuncOp wrapped in a MLIR + module, and tie arguments of that FuncOp to test function inputs. It will + also pass a `TTIRBuilder` object as the last argument of test function. + + Arguments + --------- + inputs_shapes: Tuple[Shape] + Shapes of the respective ranked tensor inputs of the test function. + + module_dump: bool + Set to True if printout of generated MLIR module is wished. + + golden_dump: bool + Set to True if printout of generated goldens is wished. + + Example + ------- + + ```python + @compile_as_mlir_module((32, 32), (32, 32)) + def test_add(in0: Operand, in1: Operand, builder: TTIRBuilder): + return builder.add(in0, in1) + + + test_add() # NOTE Called without arguments. + ``` + + which returns + + ``` + #any = #tt.operand_constraint<...> + module { + func.func @test_add( + %arg0: tensor<32x32xf32>, + %arg1: tensor<32x32xf32> + ) -> tensor<32x32xf32> { + %0 = tensor.empty() : tensor<32x32xf32> + %1 = "ttir.add"(%arg0, %arg1, %0) ... 
+ return %1 : tensor<32x32xf32> + } + } + ``` + + Check out: + https://github.com/llvm/llvm-project/blob/main/mlir/test/python/dialects/tensor.py + """ + + def decorator(test_fn: Callable): + # test_fn should be called with no args. + def wrapper(): + ctx = Context() + loc = Location.unknown(ctx) + # Instantiate builder which is passed as the last argument to + # `test_fn` so the user can use it to build ops. + builder = TTIRBuilder(ctx, loc) + + with ctx, loc: + test_fn_input_types = [ + builder.ranked_tensor_type(input_shape) + for input_shape in inputs_shapes + ] + + # Wrap everything in a mlir module. + module = Module.create() + + with InsertionPoint(module.body): + # Wrap everything in a mlir function. + @func.func(*test_fn_input_types, name=test_fn.__name__) + def decorated_func(*inputs): + # Randomly generate golden tensors for function inputs. + for i in inputs: + builder.generate_and_store_random_golden(i) + + return test_fn(*inputs, builder=builder) + + if module_dump: + print(module) + + if golden_dump: + builder.print_goldens() + + return module + + return wrapper + + return decorator def compile_as_mlir_module( *inputs_shapes: Tuple[Shape], module_dump: bool = True, - golden_dump: bool = False, ): """ Decorator to define a MLIR module specified as a python function. diff --git a/test/python/test_ttir_ops.py b/test/python/test_ttir_ops.py deleted file mode 100644 index 2590d517a..000000000 --- a/test/python/test_ttir_ops.py +++ /dev/null @@ -1,63 +0,0 @@ -# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC -# -# SPDX-License-Identifier: Apache-2.0 - -# RUN: %python %s | FileCheck %s - -from ttmlir.ttir_builder import TTIRBuilder, compile_as_mlir_module, Operand - - -@compile_as_mlir_module((32, 32), (32, 32)) -def test_add(in0: Operand, in1: Operand, builder: TTIRBuilder): - # CHECK: %0 = tensor.empty() : tensor<32x32xf32> - # CHECK: %1 = "ttir.add"(%arg0, %arg1, %0) - # CHECK: return %1 : tensor<32x32xf32> - - return builder.add(in0, in1) - - -@compile_as_mlir_module((64, 64), (64, 64)) -def test_multiply(in0: Operand, in1: Operand, builder: TTIRBuilder): - # CHECK: %0 = tensor.empty() : tensor<64x64xf32> - # CHECK: %1 = "ttir.multiply"(%arg0, %arg1, %0) - # CHECK: return %1 : tensor<64x64xf32> - - return builder.multiply(in0, in1) - - -@compile_as_mlir_module((128, 128)) -def test_exp(in0: Operand, builder: TTIRBuilder): - # CHECK: %0 = tensor.empty() : tensor<128x128xf32> - # CHECK: %1 = "ttir.exp"(%arg0, %0) - # CHECK: return %1 : tensor<128x128xf32> - - return builder.exp(in0) - - -@compile_as_mlir_module((32, 32), (32, 32), (32, 32)) -def test_arbitrary_op_chain( - in0: Operand, in1: Operand, in2: Operand, builder: TTIRBuilder -): - # CHECK: %0 = tensor.empty() : tensor<32x32xf32> - # CHECK: %1 = "ttir.add"(%arg0, %arg1, %0) - # CHECK: %2 = tensor.empty() : tensor<32x32xf32> - # CHECK: %3 = "ttir.exp"(%arg2, %2) - # CHECK: %4 = tensor.empty() : tensor<32x32xf32> - # CHECK: %5 = "ttir.multiply"(%1, %3, %4) - # CHECK: %6 = tensor.empty() : tensor<32x32xf32> - # CHECK: %7 = tensor.empty() : tensor<32x32xf32> - # CHECK: %8 = "ttir.multiply"(%5, %6, %7) - # CHECK: return %8 : tensor<32x32xf32> - - add = builder.add(in0, in1) - exp = builder.exp(in2) - mul = builder.multiply(add, exp) - in3 = builder.empty(builder.get_shape(mul)) - return builder.multiply(mul, in3) - - -if __name__ == "__main__": - test_add() - test_multiply() - test_exp() - test_arbitrary_op_chain()
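For reference, a minimal usage sketch of the entry points this diff adds: the pybound `ttir_to_ttmetal_backend_pipeline` and `ttmetal_to_flatbuffer_file` from `python/Passes.cpp`, driven by the `compile_as_mlir_module` decorator from `python/test_infra/test_utils.py`. All names below come from the diff itself; the direct `ttmetal_to_flatbuffer_file` call is illustrative only, since the `ttmetal_to_flatbuffer` docstring above notes it currently fails with the `emitc.constant` LLVM error, which is why the checked-in tests shell out to `ttmlir-translate` via `translate_ttmetal_to_flatbuffer` instead.

```python
import os

from ttmlir.passes import (
    ttir_to_ttmetal_backend_pipeline,
    ttmetal_to_flatbuffer_file,
)
from ttmlir.test_utils import compile_as_mlir_module
from ttmlir.ttir_builder import Operand, TTIRBuilder


@compile_as_mlir_module((32, 32), (32, 32))
def example_add(in0: Operand, in1: Operand, builder: TTIRBuilder):
    return builder.add(in0, in1)


# Build the TTIR module (the decorator returns it when called without args).
module = example_add()

# Lower TTIR -> TTMetal in place; the binding raises RuntimeError if adding
# or running the pipeline fails.
ttir_to_ttmetal_backend_pipeline(
    module, f"system-desc-path={os.getenv('SYSTEM_DESC_PATH', '')}"
)

# Serialize to flatbuffer. NOTE: per the docstrings above, this direct call
# currently dies on `emitc.constant`; the tests use
# translate_ttmetal_to_flatbuffer (ttmlir-translate) as a workaround.
ttmetal_to_flatbuffer_file(module, "example_add.ttm")
```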