Skip to content

Commit

Permalink
Eltwise interface and builders (#214)
Browse files Browse the repository at this point in the history
Fixes #110
  • Loading branch information
rpavlovicTT authored Jul 23, 2024
1 parent f2c8b0b commit 2574010
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 9 deletions.
2 changes: 1 addition & 1 deletion include/ttmlir/Dialect/TTIR/IR/TTIROps.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"

#include "ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.h.inc"
#include "TTIROpsInterfaces.h"

#define GET_OP_CLASSES
#include "ttmlir/Dialect/TTIR/IR/TTIROps.h.inc"
Expand Down
40 changes: 32 additions & 8 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -116,20 +116,44 @@ def TTIR_DeallocOp : TTIR_Op<"dealloc"> {
//===----------------------------------------------------------------------===//

class TTIR_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
TTIR_DPSOp<mnemonic, !listconcat(traits, [Elementwise, AttrSizedOperandSegments])> {
TTIR_DPSOp<mnemonic, !listconcat(traits, [Elementwise, AttrSizedOperandSegments, TTIR_ElementwiseOpInterface])> {

let arguments = (ins Variadic<AnyRankedTensor>:$inputs,
Variadic<AnyRankedTensor>:$outputs,
TT_OperandConstraintArrayAttr:$operand_constraints);
let results = (outs Variadic<AnyRankedTensor>:$results);
}

class TTIR_ElementwiseUnaryOp<string mnemonic, list<Trait> traits = []> :
TTIR_ElementwiseOp<mnemonic, traits> {
let summary = "Eltwise unary op.";
let description = [{
Eltwise unary op.
}];

let builders =
[
OpBuilder<(ins "Value": $in, "Value": $out, "ArrayAttr": $operand_constraints),
[{
build($_builder, $_state, {out.getType()}, in, out, operand_constraints);
}]>
];
}

class TTIR_ElementwiseBinaryOp<string mnemonic, list<Trait> traits = []> :
TTIR_ElementwiseOp<mnemonic, traits> {
let summary = "Eltwise binary op.";
let description = [{
Eltwise binary op.
}];

let builders =
[
OpBuilder<(ins "Value": $lhs, "Value": $rhs, "Value": $out, "ArrayAttr": $operand_constraints),
[{
build($_builder, $_state, {out.getType()}, {lhs, rhs}, out, operand_constraints);
}]>
];
}

def TTIR_AddOp : TTIR_ElementwiseBinaryOp<"add"> {
Expand Down Expand Up @@ -160,6 +184,13 @@ def TTIR_GreaterEqualOp : TTIR_ElementwiseBinaryOp<"ge"> {
}];
}

def TTIR_ReluOp : TTIR_ElementwiseUnaryOp<"relu"> {
let summary = "Eltwise ReLU.";
let description = [{
Eltwise ReLU operation.
}];
}

class TTIR_ReductionOp<string mnemonic, list<Trait> traits = []> : TTIR_DPSOp<mnemonic, traits> {
let summary = "Reduction op.";
let description = [{
Expand All @@ -186,13 +217,6 @@ def TTIR_SumOp : TTIR_ReductionOp<"sum"> {
}];
}

def TTIR_ReluOp : TTIR_ElementwiseOp<"relu"> {
let summary = "Eltwise ReLU.";
let description = [{
Eltwise ReLU operation.
}];
}

def TTIR_SoftmaxOp : TTIR_DPSOp<"softmax"> {
let summary = "Softmax operation.";
let description = [{
Expand Down
22 changes: 22 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_DIALECT_TTIR_IR_TTIROPSINTERFACES_H
#define TTMLIR_DIALECT_TTIR_IR_TTIROPSINTERFACES_H

#include "ttmlir/Dialect/TTIR/IR/TTIR.h"

namespace mlir {
namespace tt {
namespace ttir {
namespace detail {
mlir::LogicalResult verifyElementwiseOp(mlir::Operation *op);
} // namespace detail
} // namespace ttir
} // namespace tt
} // namespace mlir

#include "ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.h.inc"

#endif
8 changes: 8 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,12 @@ def TTIROpInterface : OpInterface<"TTIROpInterface"> {
];
}

def TTIR_ElementwiseOpInterface : OpInterface<"TTIR_ElementwiseOpInterface"> {
let cppNamespace = "::mlir::tt::ttir";

let verify = [{
return detail::verifyElementwiseOp($_op);
}];
}

#endif
1 change: 1 addition & 0 deletions lib/Dialect/TTIR/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRTTIRDialect
TTIRDialect.cpp
TTIROps.cpp
TTIROpsInterfaces.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/ttmlir
Expand Down
16 changes: 16 additions & 0 deletions lib/Dialect/TTIR/IR/TTIROpsInterfaces.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.h"
#include "ttmlir/Dialect/TTIR/IR/TTIR.h"

#include <mlir/Support/LogicalResult.h>

mlir::LogicalResult
mlir::tt::ttir::detail::verifyElementwiseOp(mlir::Operation *op) {
// Currently, elementwise trait already performs the basic verification.
// Let this be a placeholder for future extensions.

return success();
}

0 comments on commit 2574010

Please sign in to comment.