Skip to content

Commit

Permalink
Change ttnn->emitc transform to use dialect conversion API (#190)
Browse files Browse the repository at this point in the history
* Change ttnn->emitc transform to use dialect conversion API

* pr fixes v1

* update docs

* fix ttnn-serialize-to-binary pipeline
  • Loading branch information
svuckovicTT authored Jul 18, 2024
1 parent 2fcb239 commit 9f3982e
Show file tree
Hide file tree
Showing 19 changed files with 353 additions and 180 deletions.
4 changes: 2 additions & 2 deletions docs/src/adding-an-op.md
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,8 @@ TTIRToTTNNBinaryOpRewriter<ttir::MatmulOp, MatmulOp>
```

We also need to add this op to the C++ emitter,
`lib/Dialect/TTNN/Transforms/TTNNToCpp.cpp` see
`TTNNToEmitCOpaqueRewriter<MatmulOp>`.
`lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp` see
`populateTTNNToEmitCPatterns(...)`.

## 4. Add a unit test for the Op

Expand Down
1 change: 1 addition & 0 deletions include/ttmlir/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
add_subdirectory(Conversion)
add_subdirectory(Dialect)
add_subdirectory(Target)
add_dependencies(mlir-headers FBS_GENERATION)
3 changes: 3 additions & 0 deletions include/ttmlir/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name TTMLIRConversion)
add_public_tablegen_target(TTMLIRConversionPassIncGen)
20 changes: 20 additions & 0 deletions include/ttmlir/Conversion/Passes.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_CONVERSION_PASSES_H
#define TTMLIR_CONVERSION_PASSES_H

#include "ttmlir/Conversion/TTNNToEmitC/TTNNToEmitC.h"
#include "ttmlir/Dialect/TTNN/IR/TTNN.h"

#include "mlir/Dialect/EmitC/IR/EmitC.h"

namespace mlir::tt {

#define GEN_PASS_REGISTRATION
#include "ttmlir/Conversion/Passes.h.inc"

} // namespace mlir::tt

#endif // TTMLIR_CONVERSION_PASSES_H
16 changes: 16 additions & 0 deletions include/ttmlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_CONVERSION_PASSES
#define TTMLIR_CONVERSION_PASSES

include "mlir/Pass/PassBase.td"

def ConvertTTNNToEmitC : Pass<"convert-ttnn-to-emitc", "::mlir::func::FuncOp"> {
let summary = "Convert TTNN dialect to EmitC dialect.";
let constructor = "createConvertTTNNToEmitCPass()";
let dependentDialects = ["mlir::emitc::EmitCDialect", "mlir::tt::ttnn::TTNNDialect"];
}

#endif // TTMLIR_CONVERSION_PASSES
18 changes: 18 additions & 0 deletions include/ttmlir/Conversion/TTNNToEmitC/TTNNToEmitC.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_CONVERSION_TTNNTOEMITC_TTNNTOEMITC_H
#define TTMLIR_CONVERSION_TTNNTOEMITC_TTNNTOEMITC_H

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"

namespace mlir::tt {

std::unique_ptr<OperationPass<func::FuncOp>> createConvertTTNNToEmitCPass();

} // namespace mlir::tt

#endif // TTMLIR_CONVERSION_TTNNTOEMITC_TTNNTOEMITC_H
7 changes: 0 additions & 7 deletions include/ttmlir/Dialect/TTNN/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,6 @@ def ConvertTTIRToTTNN: Pass<"convert-ttir-to-ttnn", "::mlir::ModuleOp"> {
}];
}

def ConvertTTNNToEmitC: Pass<"convert-ttnn-to-emitc", "::mlir::ModuleOp"> {
let summary = "";
let description = [{
todo
}];
}

def TTNNSerializeToBinary: Pass<"ttnn-serialize-to-binary", "::mlir::ModuleOp"> {
let summary = "";
let description = [{
Expand Down
1 change: 1 addition & 0 deletions lib/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
add_subdirectory(CAPI)
add_subdirectory(Conversion)
add_subdirectory(Dialect)

add_mlir_library(TTMLIR STATIC RegisterAll.cpp
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
add_subdirectory(TTNNToEmitC)
25 changes: 25 additions & 0 deletions lib/Conversion/PassDetail.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_LIB_CONVERSION_PASSDETAIL_H
#define TTMLIR_LIB_CONVERSION_PASSDETAIL_H

#include "ttmlir/Dialect/TTNN/IR/TTNN.h"

#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"

namespace mlir::emitc {
class EmitCDialect;
}

namespace mlir::tt::ttnn {

#define GEN_PASS_CLASSES
#include "ttmlir/Conversion/Passes.h.inc"

} // namespace mlir::tt::ttnn

#endif // TTMLIR_LIB_CONVERSION_PASSDETAIL_H
14 changes: 14 additions & 0 deletions lib/Conversion/TTNNToEmitC/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
add_mlir_library(TTMLIRTTNNToEmitC
TTNNToEmitC.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/ttmlir/Conversion/TTNNToEmitC

DEPENDS
TTMLIRConversionPassIncGen

LINK_LIBS PUBLIC
MLIRIR
MLIRPass
MLIRTransformUtils
)
70 changes: 70 additions & 0 deletions lib/Conversion/TTNNToEmitC/PopulatePatterns.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_LIB_CONVERSION_TTNNTOEMITC_POPULATEPATTERNS_H
#define TTMLIR_LIB_CONVERSION_TTNNTOEMITC_POPULATEPATTERNS_H

#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

namespace {

template <typename SrcOp, typename Adaptor = typename SrcOp::Adaptor>
class DefaultOpConversionPattern : public OpConversionPattern<SrcOp> {
using OpConversionPattern<SrcOp>::OpConversionPattern;

public:
// Default op conversion pattern, used to convert most ops
//
DefaultOpConversionPattern(MLIRContext *ctx)
: OpConversionPattern<SrcOp>(ctx) {}

DefaultOpConversionPattern(const TypeConverter &typeConverter,
MLIRContext *context, PatternBenefit benefit = 1)
: OpConversionPattern<SrcOp>(typeConverter, context, benefit) {}

// Converts op name by removing the dialect prefix ("ttnn.") and replacing
// with namespace prefix ("ttnn::")
//
std::string convertOpName(SrcOp op) const {
auto name = op.getOperationName();
assert(
name.starts_with("ttnn.") &&
"DefaultOpConversionPattern only supports ops from the TTNN dialect");

return name.str().replace(0, 5, "ttnn::");
}

LogicalResult
matchAndRewrite(SrcOp srcOp, Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
int numReturnTypes = srcOp->getResultTypes().size();
assert(numReturnTypes <= 1 &&
"DefaultOpConversionPattern does not support multiple return types");

// If srcOp has a return type, cast it before converting
//
if (numReturnTypes == 1) {
auto resultTy = cast<emitc::OpaqueType>(
this->getTypeConverter()->convertType(srcOp->getResult(0).getType()));
rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>(
srcOp, resultTy, convertOpName(srcOp), nullptr, nullptr,
adaptor.getOperands());
} else {
// No return type, only convert the op
//
rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>(
srcOp, srcOp->getResultTypes(), convertOpName(srcOp), nullptr,
nullptr, adaptor.getOperands());
}

return success();
}
};

} // namespace

#endif // TTMLIR_LIB_CONVERSION_TTNNTOEMITC_POPULATEPATTERNS_H
137 changes: 137 additions & 0 deletions lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "ttmlir/Conversion/TTNNToEmitC/TTNNToEmitC.h"

#include "../PassDetail.h"
#include "PopulatePatterns.h"
#include "TypeConverter.h"

#include "ttmlir/Dialect/TT/IR/TTOpsDialect.h.inc"
#include "ttmlir/Dialect/TTNN/IR/TTNN.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"

#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::tt;

namespace {

void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
mlir::RewritePatternSet &patterns,
TypeConverter &typeConverter) {
// Device ops
//
patterns.add<DefaultOpConversionPattern<ttnn::OpenDeviceOp>>(typeConverter,
ctx, true);
patterns.add<DefaultOpConversionPattern<ttnn::CloseDeviceOp>>(typeConverter,
ctx);

// Memory ops
//
patterns.add<DefaultOpConversionPattern<ttnn::ToMemoryConfigOp>>(
typeConverter, ctx);

// Tensor ops
//
patterns.add<DefaultOpConversionPattern<ttnn::FullOp>>(typeConverter, ctx);

// Math ops
//
patterns.add<DefaultOpConversionPattern<ttnn::AddOp>>(typeConverter, ctx);
patterns.add<DefaultOpConversionPattern<ttnn::SubtractOp>>(typeConverter,
ctx);
patterns.add<DefaultOpConversionPattern<ttnn::SumOp>>(typeConverter, ctx);
patterns.add<DefaultOpConversionPattern<ttnn::MultiplyOp>>(typeConverter,
ctx);
patterns.add<DefaultOpConversionPattern<ttnn::MatmulOp>>(typeConverter, ctx);
patterns.add<DefaultOpConversionPattern<ttnn::ReluOp>>(typeConverter, ctx);
}

struct ConvertTTNNToEmitCPass
: public ttnn::ConvertTTNNToEmitCBase<ConvertTTNNToEmitCPass> {
void runOnOperation() override {
mlir::ConversionTarget target(getContext());

target.addLegalDialect<func::FuncDialect>();
target.addLegalDialect<emitc::EmitCDialect>();
target.addIllegalDialect<ttnn::TTNNDialect>();

// Add header imports to front of module
//
{
auto module = getOperation();
OpBuilder builder(module);

builder.create<emitc::IncludeOp>(module.getLoc(), "ttnn/device.h",
/*isStandard=*/false);
builder.create<emitc::IncludeOp>(
module.getLoc(), "ttnn/operations/eltwise/binary/binary.hpp",
/*isStandard=*/false);
builder.create<emitc::IncludeOp>(
module.getLoc(), "ttnn/operations/core.hpp", /*isStandard=*/false);
builder.create<emitc::IncludeOp>(module.getLoc(),
"ttnn/operations/creation.hpp",
/*isStandard=*/false);
builder.create<emitc::IncludeOp>(
module.getLoc(),
"ttnn/operations/reduction/generic/generic_reductions.hpp",
/*isStandard=*/false);
}

// TTNN -> EmitC
//
{
TTNNToEmitCTypeConverter typeConverter(&getContext());
RewritePatternSet patterns(&getContext());

// Func dialect handling
//
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
patterns, typeConverter);
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
return typeConverter.isSignatureLegal(op.getFunctionType()) &&
typeConverter.isLegal(&op.getBody());
});
populateReturnOpTypeConversionPattern(patterns, typeConverter);
target.addDynamicallyLegalOp<func::ReturnOp>(
[&](func::ReturnOp op) { return typeConverter.isLegal(op); });
populateCallOpTypeConversionPattern(patterns, typeConverter);
target.addDynamicallyLegalOp<func::CallOp>(
[&](func::CallOp op) { return typeConverter.isLegal(op); });

// TTNN -> EmitC patterns
//
populateTTNNToEmitCPatterns(&getContext(), patterns, typeConverter);

// Apply conversion
//
if (failed(applyFullConversion(getOperation(), target,
std::move(patterns)))) {
signalPassFailure();
return;
}
}
};
};

} // namespace

namespace mlir::tt {

std::unique_ptr<OperationPass<func::FuncOp>> createConvertTTNNToEmitCPass() {
return std::make_unique<ConvertTTNNToEmitCPass>();
}

} // namespace mlir::tt
32 changes: 32 additions & 0 deletions lib/Conversion/TTNNToEmitC/TypeConverter.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_LIB_CONVERSION_TTNNTOEMITC_TYPECONVERTER_H
#define TTMLIR_LIB_CONVERSION_TTNNTOEMITC_TYPECONVERTER_H

#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"

#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

namespace {

class TTNNToEmitCTypeConverter : public TypeConverter {
public:
TTNNToEmitCTypeConverter(MLIRContext *ctx) {
addConversion([](Type type) { return type; });
addConversion([ctx](mlir::tt::DeviceType type) -> Type {
return emitc::OpaqueType::get(ctx, "ttnn::Device");
});
addConversion([ctx](TensorType type) -> Type {
return emitc::OpaqueType::get(ctx, "ttnn::Tensor");
});
}
};

} // namespace

#endif // TTMLIR_LIB_CONVERSION_TTNNTOEMITC_TYPECONVERTER_H
1 change: 1 addition & 0 deletions lib/Dialect/TTNN/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ add_mlir_dialect_library(MLIRTTNNTransforms
)

target_include_directories(MLIRTTNNTransforms PUBLIC ${PROJECT_BINARY_DIR}/include/ttmlir/Target/Common)
target_link_libraries(MLIRTTNNTransforms PRIVATE TTMLIRTTNNToEmitC)
Loading

0 comments on commit 9f3982e

Please sign in to comment.