diff --git a/include/ttmlir/Dialect/TTNN/Transforms/TTNNToSerializedBinary.h b/include/ttmlir/Dialect/TTNN/Transforms/TTNNToSerializedBinary.h new file mode 100644 index 000000000..b93e65f5f --- /dev/null +++ b/include/ttmlir/Dialect/TTNN/Transforms/TTNNToSerializedBinary.h @@ -0,0 +1,15 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef TTMLIR_DIALECT_TTNN_TRANSFORMS_TTNNTOSERIALIZEDBINARY_H +#define TTMLIR_DIALECT_TTNN_TRANSFORMS_TTNNTOSERIALIZEDBINARY_H + +#include + +#include "mlir/IR/BuiltinOps.h" + +namespace mlir::tt::ttnn { +std::shared_ptr emitTTNNAsFlatbuffer(OwningOpRef &moduleOp); +} // namespace mlir::tt::ttnn +#endif diff --git a/lib/Dialect/TTNN/Transforms/CMakeLists.txt b/lib/Dialect/TTNN/Transforms/CMakeLists.txt index 975ce2e3b..1e5f37c70 100644 --- a/lib/Dialect/TTNN/Transforms/CMakeLists.txt +++ b/lib/Dialect/TTNN/Transforms/CMakeLists.txt @@ -1,7 +1,7 @@ add_mlir_dialect_library(MLIRTTNNTransforms Passes.cpp TTNNToCpp.cpp - SerializeToBinary.cpp + TTNNToSerializedBinary.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/ttmlir diff --git a/lib/Dialect/TTNN/Transforms/SerializeToBinary.cpp b/lib/Dialect/TTNN/Transforms/TTNNToSerializedBinary.cpp similarity index 94% rename from lib/Dialect/TTNN/Transforms/SerializeToBinary.cpp rename to lib/Dialect/TTNN/Transforms/TTNNToSerializedBinary.cpp index 73904feed..568d29f9d 100644 --- a/lib/Dialect/TTNN/Transforms/SerializeToBinary.cpp +++ b/lib/Dialect/TTNN/Transforms/TTNNToSerializedBinary.cpp @@ -17,6 +17,7 @@ #include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h" #include "ttmlir/Dialect/TTNN/Passes.h" #include "ttmlir/Dialect/TTNN/Transforms/TTNNToCpp.h" +#include "ttmlir/Dialect/TTNN/Transforms/TTNNToSerializedBinary.h" #include "ttmlir/Target/TTNN/Target.h" #include "ttmlir/Target/Utils/FlatbufferObjectCache.h" #include "ttmlir/Target/Utils/FuncOpToProgram.h" @@ -263,4 +264,23 @@ class TTNNSerializeToBinary std::shared_ptr serializedBinary; }; +std::shared_ptr emitTTNNAsFlatbuffer(OwningOpRef &moduleOp) { + auto pm = PassManager::on(moduleOp.get().getContext()); + auto pass = createTTNNSerializeToBinary(); + Pass *basePass = pass.get(); + pm.addPass(std::move(pass)); + + // Run the pass manager. + if (failed(pm.run(moduleOp.get()))) { + return nullptr; + } + + auto *derivedPass = llvm::dyn_cast(basePass); + if (!derivedPass) { + return nullptr; + } + + return derivedPass->getBinary(); +} + } // namespace mlir::tt::ttnn