-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Change ttnn->emitc transform to use dialect conversion API (#190)
* Change ttnn->emitc transform to use dialect conversion API * pr fixes v1 * update docs * fix ttnn-serialize-to-binary pipeline
- Loading branch information
1 parent
2fcb239
commit 9f3982e
Showing
19 changed files
with
353 additions
and
180 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
add_subdirectory(Conversion) | ||
add_subdirectory(Dialect) | ||
add_subdirectory(Target) | ||
add_dependencies(mlir-headers FBS_GENERATION) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
set(LLVM_TARGET_DEFINITIONS Passes.td) | ||
mlir_tablegen(Passes.h.inc -gen-pass-decls -name TTMLIRConversion) | ||
add_public_tablegen_target(TTMLIRConversionPassIncGen) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#ifndef TTMLIR_CONVERSION_PASSES_H | ||
#define TTMLIR_CONVERSION_PASSES_H | ||
|
||
#include "ttmlir/Conversion/TTNNToEmitC/TTNNToEmitC.h" | ||
#include "ttmlir/Dialect/TTNN/IR/TTNN.h" | ||
|
||
#include "mlir/Dialect/EmitC/IR/EmitC.h" | ||
|
||
namespace mlir::tt { | ||
|
||
#define GEN_PASS_REGISTRATION | ||
#include "ttmlir/Conversion/Passes.h.inc" | ||
|
||
} // namespace mlir::tt | ||
|
||
#endif // TTMLIR_CONVERSION_PASSES_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#ifndef TTMLIR_CONVERSION_PASSES | ||
#define TTMLIR_CONVERSION_PASSES | ||
|
||
include "mlir/Pass/PassBase.td" | ||
|
||
def ConvertTTNNToEmitC : Pass<"convert-ttnn-to-emitc", "::mlir::func::FuncOp"> { | ||
let summary = "Convert TTNN dialect to EmitC dialect."; | ||
let constructor = "createConvertTTNNToEmitCPass()"; | ||
let dependentDialects = ["mlir::emitc::EmitCDialect", "mlir::tt::ttnn::TTNNDialect"]; | ||
} | ||
|
||
#endif // TTMLIR_CONVERSION_PASSES |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#ifndef TTMLIR_CONVERSION_TTNNTOEMITC_TTNNTOEMITC_H | ||
#define TTMLIR_CONVERSION_TTNNTOEMITC_TTNNTOEMITC_H | ||
|
||
#include "mlir/Dialect/Func/IR/FuncOps.h" | ||
#include "mlir/IR/BuiltinOps.h" | ||
#include "mlir/Pass/Pass.h" | ||
|
||
namespace mlir::tt { | ||
|
||
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTTNNToEmitCPass(); | ||
|
||
} // namespace mlir::tt | ||
|
||
#endif // TTMLIR_CONVERSION_TTNNTOEMITC_TTNNTOEMITC_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
add_subdirectory(TTNNToEmitC) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#ifndef TTMLIR_LIB_CONVERSION_PASSDETAIL_H | ||
#define TTMLIR_LIB_CONVERSION_PASSDETAIL_H | ||
|
||
#include "ttmlir/Dialect/TTNN/IR/TTNN.h" | ||
|
||
#include "mlir/Dialect/EmitC/IR/EmitC.h" | ||
#include "mlir/Dialect/Func/IR/FuncOps.h" | ||
#include "mlir/Pass/Pass.h" | ||
|
||
namespace mlir::emitc { | ||
class EmitCDialect; | ||
} | ||
|
||
namespace mlir::tt::ttnn { | ||
|
||
#define GEN_PASS_CLASSES | ||
#include "ttmlir/Conversion/Passes.h.inc" | ||
|
||
} // namespace mlir::tt::ttnn | ||
|
||
#endif // TTMLIR_LIB_CONVERSION_PASSDETAIL_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
add_mlir_library(TTMLIRTTNNToEmitC | ||
TTNNToEmitC.cpp | ||
|
||
ADDITIONAL_HEADER_DIRS | ||
${PROJECT_SOURCE_DIR}/include/ttmlir/Conversion/TTNNToEmitC | ||
|
||
DEPENDS | ||
TTMLIRConversionPassIncGen | ||
|
||
LINK_LIBS PUBLIC | ||
MLIRIR | ||
MLIRPass | ||
MLIRTransformUtils | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#ifndef TTMLIR_LIB_CONVERSION_TTNNTOEMITC_POPULATEPATTERNS_H | ||
#define TTMLIR_LIB_CONVERSION_TTNNTOEMITC_POPULATEPATTERNS_H | ||
|
||
#include "mlir/Dialect/EmitC/IR/EmitC.h" | ||
#include "mlir/Transforms/DialectConversion.h" | ||
|
||
using namespace mlir; | ||
|
||
namespace { | ||
|
||
template <typename SrcOp, typename Adaptor = typename SrcOp::Adaptor> | ||
class DefaultOpConversionPattern : public OpConversionPattern<SrcOp> { | ||
using OpConversionPattern<SrcOp>::OpConversionPattern; | ||
|
||
public: | ||
// Default op conversion pattern, used to convert most ops | ||
// | ||
DefaultOpConversionPattern(MLIRContext *ctx) | ||
: OpConversionPattern<SrcOp>(ctx) {} | ||
|
||
DefaultOpConversionPattern(const TypeConverter &typeConverter, | ||
MLIRContext *context, PatternBenefit benefit = 1) | ||
: OpConversionPattern<SrcOp>(typeConverter, context, benefit) {} | ||
|
||
// Converts op name by removing the dialect prefix ("ttnn.") and replacing | ||
// with namespace prefix ("ttnn::") | ||
// | ||
std::string convertOpName(SrcOp op) const { | ||
auto name = op.getOperationName(); | ||
assert( | ||
name.starts_with("ttnn.") && | ||
"DefaultOpConversionPattern only supports ops from the TTNN dialect"); | ||
|
||
return name.str().replace(0, 5, "ttnn::"); | ||
} | ||
|
||
LogicalResult | ||
matchAndRewrite(SrcOp srcOp, Adaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const override { | ||
int numReturnTypes = srcOp->getResultTypes().size(); | ||
assert(numReturnTypes <= 1 && | ||
"DefaultOpConversionPattern does not support multiple return types"); | ||
|
||
// If srcOp has a return type, cast it before converting | ||
// | ||
if (numReturnTypes == 1) { | ||
auto resultTy = cast<emitc::OpaqueType>( | ||
this->getTypeConverter()->convertType(srcOp->getResult(0).getType())); | ||
rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>( | ||
srcOp, resultTy, convertOpName(srcOp), nullptr, nullptr, | ||
adaptor.getOperands()); | ||
} else { | ||
// No return type, only convert the op | ||
// | ||
rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>( | ||
srcOp, srcOp->getResultTypes(), convertOpName(srcOp), nullptr, | ||
nullptr, adaptor.getOperands()); | ||
} | ||
|
||
return success(); | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
#endif // TTMLIR_LIB_CONVERSION_TTNNTOEMITC_POPULATEPATTERNS_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include "ttmlir/Conversion/TTNNToEmitC/TTNNToEmitC.h" | ||
|
||
#include "../PassDetail.h" | ||
#include "PopulatePatterns.h" | ||
#include "TypeConverter.h" | ||
|
||
#include "ttmlir/Dialect/TT/IR/TTOpsDialect.h.inc" | ||
#include "ttmlir/Dialect/TTNN/IR/TTNN.h" | ||
#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h" | ||
|
||
#include "mlir/Dialect/EmitC/IR/EmitC.h" | ||
#include "mlir/Dialect/Func/IR/FuncOps.h" | ||
#include "mlir/Dialect/Func/Transforms/FuncConversions.h" | ||
#include "mlir/IR/BuiltinOps.h" | ||
#include "mlir/IR/MLIRContext.h" | ||
#include "mlir/IR/PatternMatch.h" | ||
#include "mlir/Pass/Pass.h" | ||
#include "mlir/Pass/PassManager.h" | ||
#include "mlir/Support/LogicalResult.h" | ||
#include "mlir/Transforms/DialectConversion.h" | ||
|
||
using namespace mlir; | ||
using namespace mlir::tt; | ||
|
||
namespace { | ||
|
||
void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx, | ||
mlir::RewritePatternSet &patterns, | ||
TypeConverter &typeConverter) { | ||
// Device ops | ||
// | ||
patterns.add<DefaultOpConversionPattern<ttnn::OpenDeviceOp>>(typeConverter, | ||
ctx, true); | ||
patterns.add<DefaultOpConversionPattern<ttnn::CloseDeviceOp>>(typeConverter, | ||
ctx); | ||
|
||
// Memory ops | ||
// | ||
patterns.add<DefaultOpConversionPattern<ttnn::ToMemoryConfigOp>>( | ||
typeConverter, ctx); | ||
|
||
// Tensor ops | ||
// | ||
patterns.add<DefaultOpConversionPattern<ttnn::FullOp>>(typeConverter, ctx); | ||
|
||
// Math ops | ||
// | ||
patterns.add<DefaultOpConversionPattern<ttnn::AddOp>>(typeConverter, ctx); | ||
patterns.add<DefaultOpConversionPattern<ttnn::SubtractOp>>(typeConverter, | ||
ctx); | ||
patterns.add<DefaultOpConversionPattern<ttnn::SumOp>>(typeConverter, ctx); | ||
patterns.add<DefaultOpConversionPattern<ttnn::MultiplyOp>>(typeConverter, | ||
ctx); | ||
patterns.add<DefaultOpConversionPattern<ttnn::MatmulOp>>(typeConverter, ctx); | ||
patterns.add<DefaultOpConversionPattern<ttnn::ReluOp>>(typeConverter, ctx); | ||
} | ||
|
||
struct ConvertTTNNToEmitCPass | ||
: public ttnn::ConvertTTNNToEmitCBase<ConvertTTNNToEmitCPass> { | ||
void runOnOperation() override { | ||
mlir::ConversionTarget target(getContext()); | ||
|
||
target.addLegalDialect<func::FuncDialect>(); | ||
target.addLegalDialect<emitc::EmitCDialect>(); | ||
target.addIllegalDialect<ttnn::TTNNDialect>(); | ||
|
||
// Add header imports to front of module | ||
// | ||
{ | ||
auto module = getOperation(); | ||
OpBuilder builder(module); | ||
|
||
builder.create<emitc::IncludeOp>(module.getLoc(), "ttnn/device.h", | ||
/*isStandard=*/false); | ||
builder.create<emitc::IncludeOp>( | ||
module.getLoc(), "ttnn/operations/eltwise/binary/binary.hpp", | ||
/*isStandard=*/false); | ||
builder.create<emitc::IncludeOp>( | ||
module.getLoc(), "ttnn/operations/core.hpp", /*isStandard=*/false); | ||
builder.create<emitc::IncludeOp>(module.getLoc(), | ||
"ttnn/operations/creation.hpp", | ||
/*isStandard=*/false); | ||
builder.create<emitc::IncludeOp>( | ||
module.getLoc(), | ||
"ttnn/operations/reduction/generic/generic_reductions.hpp", | ||
/*isStandard=*/false); | ||
} | ||
|
||
// TTNN -> EmitC | ||
// | ||
{ | ||
TTNNToEmitCTypeConverter typeConverter(&getContext()); | ||
RewritePatternSet patterns(&getContext()); | ||
|
||
// Func dialect handling | ||
// | ||
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>( | ||
patterns, typeConverter); | ||
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { | ||
return typeConverter.isSignatureLegal(op.getFunctionType()) && | ||
typeConverter.isLegal(&op.getBody()); | ||
}); | ||
populateReturnOpTypeConversionPattern(patterns, typeConverter); | ||
target.addDynamicallyLegalOp<func::ReturnOp>( | ||
[&](func::ReturnOp op) { return typeConverter.isLegal(op); }); | ||
populateCallOpTypeConversionPattern(patterns, typeConverter); | ||
target.addDynamicallyLegalOp<func::CallOp>( | ||
[&](func::CallOp op) { return typeConverter.isLegal(op); }); | ||
|
||
// TTNN -> EmitC patterns | ||
// | ||
populateTTNNToEmitCPatterns(&getContext(), patterns, typeConverter); | ||
|
||
// Apply conversion | ||
// | ||
if (failed(applyFullConversion(getOperation(), target, | ||
std::move(patterns)))) { | ||
signalPassFailure(); | ||
return; | ||
} | ||
} | ||
}; | ||
}; | ||
|
||
} // namespace | ||
|
||
namespace mlir::tt { | ||
|
||
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTTNNToEmitCPass() { | ||
return std::make_unique<ConvertTTNNToEmitCPass>(); | ||
} | ||
|
||
} // namespace mlir::tt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#ifndef TTMLIR_LIB_CONVERSION_TTNNTOEMITC_TYPECONVERTER_H | ||
#define TTMLIR_LIB_CONVERSION_TTNNTOEMITC_TYPECONVERTER_H | ||
|
||
#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h" | ||
|
||
#include "mlir/Dialect/EmitC/IR/EmitC.h" | ||
#include "mlir/Transforms/DialectConversion.h" | ||
|
||
using namespace mlir; | ||
|
||
namespace { | ||
|
||
class TTNNToEmitCTypeConverter : public TypeConverter { | ||
public: | ||
TTNNToEmitCTypeConverter(MLIRContext *ctx) { | ||
addConversion([](Type type) { return type; }); | ||
addConversion([ctx](mlir::tt::DeviceType type) -> Type { | ||
return emitc::OpaqueType::get(ctx, "ttnn::Device"); | ||
}); | ||
addConversion([ctx](TensorType type) -> Type { | ||
return emitc::OpaqueType::get(ctx, "ttnn::Tensor"); | ||
}); | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
#endif // TTMLIR_LIB_CONVERSION_TTNNTOEMITC_TYPECONVERTER_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.