Skip to content

Commit

Permalink
Workarounding flatbuffer generation API (#121)
Browse files Browse the repository at this point in the history
  • Loading branch information
sdjordjevicTT authored Jul 12, 2024
1 parent d8b6ced commit dbe0aec
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 1 deletion.
15 changes: 15 additions & 0 deletions include/ttmlir/Dialect/TTNN/Transforms/TTNNToSerializedBinary.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_DIALECT_TTNN_TRANSFORMS_TTNNTOSERIALIZEDBINARY_H
#define TTMLIR_DIALECT_TTNN_TRANSFORMS_TTNNTOSERIALIZEDBINARY_H

#include <memory>

#include "mlir/IR/BuiltinOps.h"

namespace mlir::tt::ttnn {
std::shared_ptr<void> emitTTNNAsFlatbuffer(OwningOpRef<ModuleOp> &moduleOp);
} // namespace mlir::tt::ttnn
#endif
2 changes: 1 addition & 1 deletion lib/Dialect/TTNN/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
add_mlir_dialect_library(MLIRTTNNTransforms
Passes.cpp
TTNNToCpp.cpp
SerializeToBinary.cpp
TTNNToSerializedBinary.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/ttmlir
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h"
#include "ttmlir/Dialect/TTNN/Passes.h"
#include "ttmlir/Dialect/TTNN/Transforms/TTNNToCpp.h"
#include "ttmlir/Dialect/TTNN/Transforms/TTNNToSerializedBinary.h"
#include "ttmlir/Target/TTNN/Target.h"
#include "ttmlir/Target/Utils/FlatbufferObjectCache.h"
#include "ttmlir/Target/Utils/FuncOpToProgram.h"
Expand Down Expand Up @@ -263,4 +264,23 @@ class TTNNSerializeToBinary
std::shared_ptr<void> serializedBinary;
};

std::shared_ptr<void> emitTTNNAsFlatbuffer(OwningOpRef<ModuleOp> &moduleOp) {
auto pm = PassManager::on<ModuleOp>(moduleOp.get().getContext());
auto pass = createTTNNSerializeToBinary();
Pass *basePass = pass.get();
pm.addPass(std::move(pass));

// Run the pass manager.
if (failed(pm.run(moduleOp.get()))) {
return nullptr;
}

auto *derivedPass = llvm::dyn_cast<TTNNSerializeToBinary>(basePass);
if (!derivedPass) {
return nullptr;
}

return derivedPass->getBinary();
}

} // namespace mlir::tt::ttnn

0 comments on commit dbe0aec

Please sign in to comment.