Skip to content

Commit

Permalink
#1069: Added support for writing ttnn and ttmetal ops through python …
Browse files Browse the repository at this point in the history
…infrastructure and lowering it to flatbuffer files (#1096)
  • Loading branch information
tapspatel authored Oct 30, 2024
1 parent e1ccf6f commit 6b43a5a
Show file tree
Hide file tree
Showing 7 changed files with 939 additions and 112 deletions.
4 changes: 3 additions & 1 deletion python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ declare_mlir_python_sources(TTMLIRPythonSources.Passes
declare_mlir_python_sources(TTMLIRPythonTestInfra.TestInfra
ROOT_DIR "${TTMLIR_PYTHON_TEST_INFRA_ROOT_DIR}"
ADD_TO_PARENT TTMLIRPythonTestInfra
SOURCES ttir_builder.py
SOURCES
ttir_builder.py
test_utils.py
)

declare_mlir_python_extension(TTMLIRPythonExtensions.Main
Expand Down
42 changes: 42 additions & 0 deletions python/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
#include "mlir/InitAllTranslations.h"
#include "ttmlir/Bindings/Python/TTMLIRModule.h"
#include "ttmlir/RegisterAll.h"
#include "ttmlir/Target/TTMetal/TTMetalToFlatbuffer.h"
#include "ttmlir/Target/TTNN/TTNNToFlatbuffer.h"
#include <cstdint>

PYBIND11_MAKE_OPAQUE(std::shared_ptr<void>);

Expand Down Expand Up @@ -117,6 +119,30 @@ void populatePassesModule(py::module &m) {
},
py::arg("module"), py::arg("options") = "");

m.def(
"ttir_to_ttmetal_backend_pipeline",
[](MlirModule module, std::string options = "") {
mlir::Operation *moduleOp = unwrap(mlirModuleGetOperation(module));
mlir::PassManager pm(moduleOp->getName());
mlir::DialectRegistry registry;
mlir::tt::registerAllDialects(registry);
mlir::tt::registerAllExtensions(registry);
mlir::MLIRContext *ctx = unwrap(mlirModuleGetContext(module));
ctx->appendDialectRegistry(registry);
const auto *pipeline =
mlir::PassPipelineInfo::lookup("ttir-to-ttmetal-backend-pipeline");
mlir::function_ref<mlir::LogicalResult(const llvm::Twine &)>
err_handler =
[](const llvm::Twine &loc) { return mlir::failure(); };
if (mlir::failed(pipeline->addToPipeline(pm, options, err_handler))) {
throw std::runtime_error("Failed to add pipeline to pass manager");
}
if (mlir::failed(pm.run(moduleOp))) {
throw std::runtime_error("Failed to run pass manager");
}
},
py::arg("module"), py::arg("options") = "");

py::class_<std::shared_ptr<void>>(m, "SharedVoidPtr")
.def(py::init<>())
.def("from_ttnn", [](std::shared_ptr<void> data, MlirModule module) {
Expand Down Expand Up @@ -154,6 +180,22 @@ void populatePassesModule(py::module &m) {
filepath);
}
});

m.def("ttmetal_to_flatbuffer_file",
[](MlirModule module, std::string &filepath) {
mlir::Operation *moduleOp = unwrap(mlirModuleGetOperation(module));
std::error_code fileError;
llvm::raw_fd_ostream file(filepath, fileError);
if (fileError) {
throw std::runtime_error("Failed to open file: " + filepath +
". Error: " + fileError.message());
}
if (mlir::failed(mlir::tt::ttmetal::translateTTMetalToFlatbuffer(
moduleOp, file))) {
throw std::runtime_error("Failed to write flatbuffer to file: " +
filepath);
}
});
}

} // namespace mlir::ttmlir::python
72 changes: 72 additions & 0 deletions python/test_infra/test_ttir_ops_ttmetal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

# RUN: %python %s

import inspect
import os

from ttmlir.test_utils import (
compile_as_mlir_module,
translate_ttnn_to_flatbuffer,
ttir_to_ttnn,
translate_ttmetal_to_flatbuffer,
ttir_to_ttmetal,
)
from ttmlir.ttir_builder import Operand, TTIRBuilder

system_desc_path = os.getenv("SYSTEM_DESC_PATH", "")


@translate_ttmetal_to_flatbuffer(output_file_name="test_exp.ttm")
@ttir_to_ttmetal(
output_file_name="test_exp.mlir", system_desc_path=f"{system_desc_path}"
)
@compile_as_mlir_module((128, 128))
def test_exp_ttmetal(in0: Operand, builder: TTIRBuilder):
return builder.exp(in0)


@translate_ttmetal_to_flatbuffer(output_file_name="test_add.ttm")
@ttir_to_ttmetal(
output_file_name="test_add.mlir", system_desc_path=f"{system_desc_path}"
)
@compile_as_mlir_module((64, 128), (64, 128))
def test_add_ttmetal(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.add(in0, in1)


@translate_ttmetal_to_flatbuffer(output_file_name="test_multiply.ttm")
@ttir_to_ttmetal(
output_file_name="test_multiply.mlir", system_desc_path=f"{system_desc_path}"
)
@compile_as_mlir_module((64, 64), (64, 64))
def test_multiply_ttmetal(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.multiply(in0, in1)


@translate_ttmetal_to_flatbuffer(output_file_name="test_arbitrary_op_chain.ttm")
@ttir_to_ttmetal(
output_file_name="test_arbitrary_op_chain.mlir",
system_desc_path=f"{system_desc_path}",
)
@compile_as_mlir_module((32, 32), (32, 32), (32, 32))
def test_arbitrary_op_chain_ttmetal(
in0: Operand, in1: Operand, in2: Operand, builder: TTIRBuilder
):
add = builder.add(in0, in1)
exp = builder.exp(in2)
mul = builder.multiply(add, exp)
in3 = builder.empty(builder.get_shape(mul))
return builder.multiply(mul, in3)


if __name__ == "__main__":
test_functions = inspect.getmembers(
inspect.getmodule(inspect.currentframe()), inspect.isfunction
)

for function_name, func in test_functions:
if function_name.startswith("test_"):
func()
217 changes: 217 additions & 0 deletions python/test_infra/test_ttir_ops_ttnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

# RUN: %python %s

import inspect
import os

from ttmlir.test_utils import (
compile_as_mlir_module,
translate_ttnn_to_flatbuffer,
ttir_to_ttnn,
translate_ttmetal_to_flatbuffer,
ttir_to_ttmetal,
)
from ttmlir.ttir_builder import Operand, TTIRBuilder

system_desc_path = os.getenv("SYSTEM_DESC_PATH", "")


@translate_ttnn_to_flatbuffer(output_file_name="test_exp.ttnn")
@ttir_to_ttnn(output_file_name="test_exp.mlir", system_desc_path=f"{system_desc_path}")
@compile_as_mlir_module((128, 128))
def test_exp_ttnn(in0: Operand, builder: TTIRBuilder):
return builder.exp(in0)


@translate_ttnn_to_flatbuffer(output_file_name="test_abs.ttnn")
@ttir_to_ttnn(output_file_name="test_abs.mlir", system_desc_path=f"{system_desc_path}")
@compile_as_mlir_module((128, 128))
def test_abs_ttnn(in0: Operand, builder: TTIRBuilder):
return builder.abs(in0)


@translate_ttnn_to_flatbuffer(output_file_name="test_logical_not.ttnn")
@ttir_to_ttnn(
output_file_name="test_logical_not.mlir",
system_desc_path=f"{system_desc_path}",
)
@compile_as_mlir_module((128, 128))
def test_logical_not_ttnn(in0: Operand, builder: TTIRBuilder):
return builder.logical_not(in0)


@translate_ttnn_to_flatbuffer(output_file_name="test_neg.ttnn")
@ttir_to_ttnn(output_file_name="test_neg.mlir", system_desc_path=f"{system_desc_path}")
@compile_as_mlir_module((128, 128))
def test_neg_ttnn(in0: Operand, builder: TTIRBuilder):
return builder.neg(in0)


@translate_ttnn_to_flatbuffer(output_file_name="test_relu.ttnn")
@ttir_to_ttnn(output_file_name="test_relu.mlir", system_desc_path=f"{system_desc_path}")
@compile_as_mlir_module((128, 128))
def test_relu_ttnn(in0: Operand, builder: TTIRBuilder):
return builder.relu(in0)


@translate_ttnn_to_flatbuffer(output_file_name="test_sqrt.ttnn")
@ttir_to_ttnn(output_file_name="test_sqrt.mlir", system_desc_path=f"{system_desc_path}")
@compile_as_mlir_module((128, 128))
def test_sqrt_ttnn(in0: Operand, builder: TTIRBuilder):
return builder.sqrt(in0)


@translate_ttnn_to_flatbuffer(output_file_name="test_rsqrt.ttnn")
@ttir_to_ttnn(
output_file_name="test_rsqrt.mlir", system_desc_path=f"{system_desc_path}"
)
@compile_as_mlir_module((128, 128))
def test_rsqrt_ttnn(in0: Operand, builder: TTIRBuilder):
return builder.rsqrt(in0)


@translate_ttnn_to_flatbuffer(output_file_name="test_sigmoid.ttnn")
@ttir_to_ttnn(
output_file_name="test_sigmoid.mlir", system_desc_path=f"{system_desc_path}"
)
@compile_as_mlir_module((128, 128))
def test_sigmoid_ttnn(in0: Operand, builder: TTIRBuilder):
return builder.sigmoid(in0)


@translate_ttnn_to_flatbuffer(output_file_name="test_reciprocal.ttnn")
@ttir_to_ttnn(
output_file_name="test_reciprocal.mlir", system_desc_path=f"{system_desc_path}"
)
@compile_as_mlir_module((128, 128))
def test_reciprocal_ttnn(in0: Operand, builder: TTIRBuilder):
return builder.reciprocal(in0)


@translate_ttnn_to_flatbuffer(output_file_name="test_add.ttnn")
@ttir_to_ttnn(output_file_name="test_add.mlir", system_desc_path=f"{system_desc_path}")
@compile_as_mlir_module((64, 128), (64, 128))
def test_add_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.add(in0, in1)


@translate_ttnn_to_flatbuffer(output_file_name="test_multiply.ttnn")
@ttir_to_ttnn(
output_file_name="test_multiply.mlir", system_desc_path=f"{system_desc_path}"
)
@compile_as_mlir_module((64, 64), (64, 64))
def test_multiply_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.multiply(in0, in1)


@translate_ttnn_to_flatbuffer(output_file_name="test_logical_and.ttnn")
@ttir_to_ttnn(
output_file_name="test_logical_and.mlir",
system_desc_path=f"{system_desc_path}",
)
@compile_as_mlir_module((64, 64), (64, 64))
def test_logical_and_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.logical_and(in0, in1)


@translate_ttnn_to_flatbuffer(output_file_name="test_logical_or.ttnn")
@ttir_to_ttnn(
output_file_name="test_logical_or.mlir", system_desc_path=f"{system_desc_path}"
)
@compile_as_mlir_module((64, 64), (64, 64))
def test_logical_or_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.logical_or(in0, in1)


@translate_ttnn_to_flatbuffer(output_file_name="test_subtract.ttnn")
@ttir_to_ttnn(
output_file_name="test_subtract.mlir", system_desc_path=f"{system_desc_path}"
)
@compile_as_mlir_module((64, 64), (64, 64))
def test_subtract_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.subtract(in0, in1)


@translate_ttnn_to_flatbuffer(output_file_name="test_eq.ttnn")
@ttir_to_ttnn(output_file_name="test_eq.mlir", system_desc_path=f"{system_desc_path}")
@compile_as_mlir_module((64, 64), (64, 64))
def test_eq_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.eq(in0, in1)


@translate_ttnn_to_flatbuffer(output_file_name="test_ne.ttnn")
@ttir_to_ttnn(output_file_name="test_ne.mlir", system_desc_path=f"{system_desc_path}")
@compile_as_mlir_module((64, 64), (64, 64))
def test_ne_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.ne(in0, in1)


@translate_ttnn_to_flatbuffer(output_file_name="test_ge.ttnn")
@ttir_to_ttnn(output_file_name="test_ge.mlir", system_desc_path=f"{system_desc_path}")
@compile_as_mlir_module((64, 64), (64, 64))
def test_ge_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.ge(in0, in1)


@translate_ttnn_to_flatbuffer(output_file_name="test_gt.ttnn")
@ttir_to_ttnn(output_file_name="test_gt.mlir", system_desc_path=f"{system_desc_path}")
@compile_as_mlir_module((64, 64), (64, 64))
def test_gt_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.gt(in0, in1)


@translate_ttnn_to_flatbuffer(output_file_name="test_le.ttnn")
@ttir_to_ttnn(output_file_name="test_le.mlir", system_desc_path=f"{system_desc_path}")
@compile_as_mlir_module((64, 64), (64, 64))
def test_le_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.le(in0, in1)


@translate_ttnn_to_flatbuffer(output_file_name="test_lt.ttnn")
@ttir_to_ttnn(output_file_name="test_lt.mlir", system_desc_path=f"{system_desc_path}")
@compile_as_mlir_module((64, 64), (64, 64))
def test_lt_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.lt(in0, in1)


@translate_ttnn_to_flatbuffer(output_file_name="test_div.ttnn")
@ttir_to_ttnn(output_file_name="test_div.mlir", system_desc_path=f"{system_desc_path}")
@compile_as_mlir_module((64, 64), (64, 64))
def test_div_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.div(in0, in1)


@translate_ttnn_to_flatbuffer(output_file_name="test_maximum.ttnn")
@ttir_to_ttnn(
output_file_name="test_maximum.mlir", system_desc_path=f"{system_desc_path}"
)
@compile_as_mlir_module((64, 64), (64, 64))
def test_maximum_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.maximum(in0, in1)


@translate_ttnn_to_flatbuffer(output_file_name="test_arbitrary_op_chain.ttnn")
@ttir_to_ttnn(
output_file_name="test_arbitrary_op_chain.mlir",
system_desc_path=f"{system_desc_path}",
)
@compile_as_mlir_module((32, 32), (32, 32), (32, 32))
def test_arbitrary_op_chain_ttnn(
in0: Operand, in1: Operand, in2: Operand, builder: TTIRBuilder
):
add = builder.add(in0, in1)
exp = builder.exp(in2)
return builder.multiply(add, exp)


if __name__ == "__main__":
test_functions = inspect.getmembers(
inspect.getmodule(inspect.currentframe()), inspect.isfunction
)

for function_name, func in test_functions:
if function_name.startswith("test_"):
func()
Loading

0 comments on commit 6b43a5a

Please sign in to comment.