Skip to content

Commit

Permalink
TTNNOpsBackendInterface definition (#1131)
Browse files Browse the repository at this point in the history
* TTNNOpsBackendInterface definition.

This is a ground work to enable backend interface in the TTNN ops dialect, it contains TTNNOpsBackendInterface definition, default implementation for eltwise ops with a stub override for a relu op.
  • Loading branch information
mbezuljTT authored Nov 4, 2024
1 parent 6100428 commit e6c60fd
Show file tree
Hide file tree
Showing 8 changed files with 101 additions and 2 deletions.
2 changes: 2 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ add_mlir_dialect(TTNNOps ttnn)
add_mlir_doc(TTNNBase TTNNDialect src/autogen/md/Dialect/ -gen-dialect-doc)
add_mlir_doc(TTNNOps TTNNOp src/autogen/md/Dialect/ -gen-op-doc)

add_mlir_interface(TTNNOpsBackendInterfaces)

set(LLVM_TARGET_DEFINITIONS TTNNOpsEnums.td)
mlir_tablegen(TTNNOpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(TTNNOpsEnums.cpp.inc -gen-enum-defs)
Expand Down
3 changes: 2 additions & 1 deletion include/ttmlir/Dialect/TTNN/IR/TTNNBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#define TTMLIR_TTMLIR_DIALECT_TTNN_TTNNDIALECT_TD

include "mlir/IR/OpBase.td"
include "ttmlir/Dialect/TTNN/IR/TTNNOpsBackendInterfaces.td"

//===----------------------------------------------------------------------===//
// TTNN dialect definition.
Expand Down Expand Up @@ -43,6 +44,6 @@ def TTNN_Dialect : Dialect {
//===----------------------------------------------------------------------===//

class TTNN_Op<string mnemonic, list<Trait> traits = []> :
Op<TTNN_Dialect, mnemonic, traits>;
Op<TTNN_Dialect, mnemonic, !listconcat(traits, [TTNNOpBackendInterface])>;

#endif
1 change: 1 addition & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsBackendInterfaces.h.inc"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h"

#define GET_OP_CLASSES
Expand Down
4 changes: 3 additions & 1 deletion include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,9 @@ def TTNN_ReciprocalOp : TTNN_ElementwiseUnaryOp<"reciprocal"> {
}];
}

def TTNN_ReluOp : TTNN_ElementwiseUnaryOp<"relu"> {
def TTNN_ReluOp : TTNN_ElementwiseUnaryOp<"relu",
[DeclareOpInterfaceMethods<TTNNOpBackendInterface, ["getOpPerfCycles", "getOpL1Usage", "isOpLegal"]>]
> {
let summary = "Eltwise ReLU.";
let description = [{
Eltwise ReLU operation.
Expand Down
50 changes: 50 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOpsBackendInterfaces.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_TTMLIR_DIALECT_TTNN_TTNNOPSPERFINTERFACES_TD
#define TTMLIR_TTMLIR_DIALECT_TTNN_TTNNOPSPERFINTERFACES_TD

include "mlir/IR/OpBase.td"

def TTNNOpBackendInterface : OpInterface<"TTNNOpBackend"> {
let description = [{
Interface to access a registered method to infer the return types for an
operation that can be used during type inference.
}];
let cppNamespace = "::mlir::tt::ttnn";
let methods = [
InterfaceMethod<
/*desc=*/[{
Return the op kernel estimate in clock cycles.
}],
/*retTy=*/"size_t",
/*methodName=*/"getOpPerfCycles",
/*args=*/(ins "const std::vector<tt::LayoutAttr>&":$input_layouts, "const tt::LayoutAttr&":$output_layout),
/*methodBody=*/"",
/*defaultImplementation=*/"return std::numeric_limits<size_t>::max();"
>,
InterfaceMethod<
/*desc=*/[{
Return the op kernel estimate in clock cycles.
}],
/*retTy=*/"size_t",
/*methodName=*/"getOpL1Usage",
/*args=*/(ins "const std::vector<tt::LayoutAttr>&":$input_layouts, "const tt::LayoutAttr&":$output_layout),
/*methodBody=*/"",
/*defaultImplementation=*/"return 0;"
>,
InterfaceMethod<
/*desc=*/[{
Return the op kernel estimate in clock cycles.
}],
/*retTy=*/"bool",
/*methodName=*/"isOpLegal",
/*args=*/(ins "const std::vector<tt::LayoutAttr>&":$input_layouts, "const tt::LayoutAttr&":$output_layout),
/*methodBody=*/"",
/*defaultImplementation=*/"return true;"
>,
];
}

#endif // TTMLIR_TTMLIR_DIALECT_TTNN_TTNNOPSPERFINTERFACES_TD
8 changes: 8 additions & 0 deletions lib/Dialect/TTNN/Analysis/ShardSolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include "ttmlir/Dialect/TTNN/Analysis/ShardSolver.h"
#include "ttmlir/Dialect/TTNN/Analysis/L1ChainConfig.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"
#include <mlir/Interfaces/DestinationStyleOpInterface.h>
#include <mlir/Support/LLVM.h>
#include <unordered_set>
Expand Down Expand Up @@ -503,6 +504,13 @@ bool ShardSolver::checkShardCompatible(
// TEMP : Dummy mock implementation, will be replaced.
//

if (TTNNOpBackend backend = dyn_cast<TTNNOpBackend>(consumerOp)) {
if (false ==
backend.isOpLegal(std::vector{producerLayout}, consumerLayout)) {
return false;
}
}

// May need to fetch other inputs for consumerOp(weights/join node).
//

Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/TTNN/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRTTNNDialect
TTNNDialect.cpp
TTNNOps.cpp
TTNNOpsBackendInterfaces.cpp
TTNNOpsTypes.cpp

ADDITIONAL_HEADER_DIRS
Expand Down
34 changes: 34 additions & 0 deletions lib/Dialect/TTNN/IR/TTNNOpsBackendInterfaces.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"

#include "ttmlir/Dialect/TTNN/IR/TTNNOpsBackendInterfaces.cpp.inc"

namespace mlir::tt::ttnn {

//===----------------------------------------------------------------------===//
// ReluOp
//===----------------------------------------------------------------------===//

// // Relu backend interface
size_t ReluOp::getOpPerfCycles(const std::vector<tt::LayoutAttr> &input_layouts,
const tt::LayoutAttr &output_layout) {
// Implement a custom estimate for relu op cycles.
return 5;
}

size_t ReluOp::getOpL1Usage(const std::vector<tt::LayoutAttr> &input_layouts,
const tt::LayoutAttr &output_layout) {
// Implement a custom estimate for relu op L1 usage.
return 10;
}

bool ReluOp::isOpLegal(const std::vector<tt::LayoutAttr> &input_layouts,
const tt::LayoutAttr &output_layout) {
// Implement a custom check for relu op legality.
return true;
}

} // namespace mlir::tt::ttnn

0 comments on commit e6c60fd

Please sign in to comment.