diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.h b/include/ttmlir/Dialect/TTIR/IR/TTIROps.h index 438f55faf9..26bcf1a192 100644 --- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.h +++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.h @@ -14,7 +14,7 @@ #include "mlir/Interfaces/SideEffectInterfaces.h" #include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" -#include "ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.h.inc" +#include "TTIROpsInterfaces.h" #define GET_OP_CLASSES #include "ttmlir/Dialect/TTIR/IR/TTIROps.h.inc" diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td index c9f1b34a71..69e329bc07 100644 --- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td +++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td @@ -116,7 +116,7 @@ def TTIR_DeallocOp : TTIR_Op<"dealloc"> { //===----------------------------------------------------------------------===// class TTIR_ElementwiseOp traits = []> : - TTIR_DPSOp { + TTIR_DPSOp { let arguments = (ins Variadic:$inputs, Variadic:$outputs, @@ -124,12 +124,36 @@ class TTIR_ElementwiseOp traits = []> : let results = (outs Variadic:$results); } +class TTIR_ElementwiseUnaryOp traits = []> : + TTIR_ElementwiseOp { + let summary = "Eltwise unary op."; + let description = [{ + Eltwise unary op. + }]; + + let builders = + [ + OpBuilder<(ins "Value": $in, "Value": $out, "ArrayAttr": $operand_constraints), + [{ + build($_builder, $_state, {out.getType()}, in, out, operand_constraints); + }]> + ]; +} + class TTIR_ElementwiseBinaryOp traits = []> : TTIR_ElementwiseOp { let summary = "Eltwise binary op."; let description = [{ Eltwise binary op. }]; + + let builders = + [ + OpBuilder<(ins "Value": $lhs, "Value": $rhs, "Value": $out, "ArrayAttr": $operand_constraints), + [{ + build($_builder, $_state, {out.getType()}, {lhs, rhs}, out, operand_constraints); + }]> + ]; } def TTIR_AddOp : TTIR_ElementwiseBinaryOp<"add"> { @@ -160,6 +184,13 @@ def TTIR_GreaterEqualOp : TTIR_ElementwiseBinaryOp<"ge"> { }]; } +def TTIR_ReluOp : TTIR_ElementwiseUnaryOp<"relu"> { + let summary = "Eltwise ReLU."; + let description = [{ + Eltwise ReLU operation. + }]; +} + class TTIR_ReductionOp traits = []> : TTIR_DPSOp { let summary = "Reduction op."; let description = [{ @@ -186,13 +217,6 @@ def TTIR_SumOp : TTIR_ReductionOp<"sum"> { }]; } -def TTIR_ReluOp : TTIR_ElementwiseOp<"relu"> { - let summary = "Eltwise ReLU."; - let description = [{ - Eltwise ReLU operation. - }]; -} - def TTIR_SoftmaxOp : TTIR_DPSOp<"softmax"> { let summary = "Softmax operation."; let description = [{ diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.h b/include/ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.h new file mode 100644 index 0000000000..aad064211f --- /dev/null +++ b/include/ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.h @@ -0,0 +1,22 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef TTMLIR_DIALECT_TTIR_IR_TTIROPSINTERFACES_H +#define TTMLIR_DIALECT_TTIR_IR_TTIROPSINTERFACES_H + +#include "ttmlir/Dialect/TTIR/IR/TTIR.h" + +namespace mlir { +namespace tt { +namespace ttir { +namespace detail { +mlir::LogicalResult verifyElementwiseOp(mlir::Operation *op); +} // namespace detail +} // namespace ttir +} // namespace tt +} // namespace mlir + +#include "ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.h.inc" + +#endif diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.td b/include/ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.td index 1d42fbb841..6b6d2dda1c 100644 --- a/include/ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.td +++ b/include/ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.td @@ -24,4 +24,12 @@ def TTIROpInterface : OpInterface<"TTIROpInterface"> { ]; } +def TTIR_ElementwiseOpInterface : OpInterface<"TTIR_ElementwiseOpInterface"> { + let cppNamespace = "::mlir::tt::ttir"; + + let verify = [{ + return detail::verifyElementwiseOp($_op); + }]; +} + #endif diff --git a/lib/Dialect/TTIR/IR/CMakeLists.txt b/lib/Dialect/TTIR/IR/CMakeLists.txt index fd2504f6dd..2dad1a49a2 100644 --- a/lib/Dialect/TTIR/IR/CMakeLists.txt +++ b/lib/Dialect/TTIR/IR/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_dialect_library(MLIRTTIRDialect TTIRDialect.cpp TTIROps.cpp + TTIROpsInterfaces.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/ttmlir diff --git a/lib/Dialect/TTIR/IR/TTIROpsInterfaces.cpp b/lib/Dialect/TTIR/IR/TTIROpsInterfaces.cpp new file mode 100644 index 0000000000..7dfa5c083c --- /dev/null +++ b/lib/Dialect/TTIR/IR/TTIROpsInterfaces.cpp @@ -0,0 +1,16 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.h" +#include "ttmlir/Dialect/TTIR/IR/TTIR.h" + +#include + +mlir::LogicalResult +mlir::tt::ttir::detail::verifyElementwiseOp(mlir::Operation *op) { + // Currently, elementwise trait already performs the basic verification. + // Let this be a placeholder for future extensions. + + return success(); +}