diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
index 9e192a527..5558d4e43 100644
--- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
+++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -179,6 +179,13 @@ def TTIR_SumOp : TTIR_ReductionOp<"sum"> {
   }];
 }
 
+def TTIR_ReluOp : TTIR_ElementwiseOp<"relu"> {
+  let summary = "Eltwise ReLU.";
+  let description = [{
+    Eltwise ReLU operation.
+  }];
+}
+
 // ANCHOR: adding_an_op_matmul_ttir
 def TTIR_MatmulOp : TTIR_DPSOp<"matmul"> {
   let summary = "Matrix multiply operation.";
diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
index 80508e86a..97c4a61aa 100644
--- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
+++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -117,6 +117,13 @@ def TTNN_SumOp : TTNN_ReductionOp<"sum"> {
   }];
 }
 
+def TTNN_ReluOp : TTNN_ElementwiseOp<"relu"> {
+  let summary = "Eltwise ReLU.";
+  let description = [{
+    Eltwise ReLU operation.
+  }];
+}
+
 // ANCHOR: adding_an_op_matmul_ttnn
 def TTNN_MatmulOp : TTNN_NamedDPSOp<"matmul"> {
   let arguments = (ins AnyRankedTensor:$a,
diff --git a/include/ttmlir/Target/TTNN/program.fbs b/include/ttmlir/Target/TTNN/program.fbs
index f5868aa35..adc7f158a 100644
--- a/include/ttmlir/Target/TTNN/program.fbs
+++ b/include/ttmlir/Target/TTNN/program.fbs
@@ -28,6 +28,7 @@ enum EltwiseOpType: uint32 {
   Add = 0,
   Multiply = 1,
   Subtract = 2,
+  Relu = 3,
 }
 
 table EltwiseOp {
diff --git a/lib/Dialect/TTIR/Transforms/Passes.cpp b/lib/Dialect/TTIR/Transforms/Passes.cpp
index 6835e61ba..705952603 100644
--- a/lib/Dialect/TTIR/Transforms/Passes.cpp
+++ b/lib/Dialect/TTIR/Transforms/Passes.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "ttmlir/Dialect/TT/IR/TT.h"
 #include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
+#include "ttmlir/Dialect/TTIR/IR/TTIROps.h"
 #include "ttmlir/Dialect/TTIR/Analysis/GridAnalysis.h"
 #include "ttmlir/Dialect/TTIR/Passes.h"
 
@@ -114,6 +115,9 @@ class TTIRNamedToKernelRewriter : public OpRewritePattern<TTIROp> {
     } else if constexpr (std::is_same<TTIROp, SubtractOp>::value) {
       kernelName = "subtract";
       kernelKind = "eltwise";
+    } else if constexpr (std::is_same<TTIROp, ReluOp>::value) {
+      kernelName = "relu";
+      kernelKind = "eltwise";
     } else {
       return rewriter.notifyMatchFailure(op, "Unsupported Tosa operation for TTIR");
     }
@@ -267,7 +271,8 @@ class TTIRGeneric : public impl::TTIRGenericBase<TTIRGeneric> {
     patterns.add<TTIRNamedToKernelRewriter<AddOp>,
                  TTIRNamedToKernelRewriter<MultiplyOp>,
-                 TTIRNamedToKernelRewriter<SubtractOp>>(&getContext());
+                 TTIRNamedToKernelRewriter<SubtractOp>,
+                 TTIRNamedToKernelRewriter<ReluOp>>(&getContext());
     FrozenRewritePatternSet patternSet(std::move(patterns));
     if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) {
       signalPassFailure();
     }
@@ -588,14 +593,14 @@ class TTIRLayout : public impl::TTIRLayoutBase<TTIRLayout> {
   }
   {
     RewritePatternSet patterns(&getContext());
-    patterns
-        .add<TTIRLayoutOperandsRewriter<GenericOp>,
-             TTIRLayoutOperandsRewriter<AddOp>,
-             TTIRLayoutOperandsRewriter<MultiplyOp>,
-             TTIRLayoutOperandsRewriter<SubtractOp>,
-             TTIRLayoutOperandsRewriter<MatmulOp>,
-             TTIRLayoutOperandsRewriter<SumOp>, TTIRLayoutFuncReturnRewriter>(
-            &getContext());
+    patterns.add<
+        TTIRLayoutOperandsRewriter<GenericOp>,
+        TTIRLayoutOperandsRewriter<AddOp>,
+        TTIRLayoutOperandsRewriter<MultiplyOp>,
+        TTIRLayoutOperandsRewriter<SubtractOp>,
+        TTIRLayoutOperandsRewriter<ReluOp>, TTIRLayoutOperandsRewriter<MatmulOp>,
+        TTIRLayoutOperandsRewriter<SumOp>, TTIRLayoutFuncReturnRewriter>(
+        &getContext());
     FrozenRewritePatternSet patternSet(std::move(patterns));
     if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) {
       signalPassFailure();
diff --git a/lib/Dialect/TTNN/Transforms/Passes.cpp b/lib/Dialect/TTNN/Transforms/Passes.cpp
index cdf7c958f..c27c1a39f 100644
--- a/lib/Dialect/TTNN/Transforms/Passes.cpp
+++ b/lib/Dialect/TTNN/Transforms/Passes.cpp
@@ -147,6 +147,7 @@ class ConvertTTIRToTTNN
         .add<TTIRToTTNNOpRewriter<ttir::AddOp, AddOp>,
              TTIRToTTNNOpRewriter<ttir::MultiplyOp, MultiplyOp>,
              TTIRToTTNNOpRewriter<ttir::SubtractOp, SubtractOp>,
+             TTIRToTTNNOpRewriter<ttir::ReluOp, ReluOp>,
              TTIRToTTNNBinaryOpRewriter<ttir::MatmulOp, MatmulOp>,
              TTIRToTTNNReductionOpRewriter<ttir::SumOp, SumOp>,
              TensorEmptyToFullRewriter>(&getContext());
diff --git a/lib/Dialect/TTNN/Transforms/SerializeToBinary.cpp b/lib/Dialect/TTNN/Transforms/SerializeToBinary.cpp
index 2d72ea126..beb7c6909 100644
--- a/lib/Dialect/TTNN/Transforms/SerializeToBinary.cpp
+++ b/lib/Dialect/TTNN/Transforms/SerializeToBinary.cpp
@@ -112,6 +112,8 @@ createEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) {
     type = ::tt::target::ttnn::EltwiseOpType::Multiply;
   } else if constexpr (std::is_same_v<EltwiseOp, SubtractOp>) {
     type = ::tt::target::ttnn::EltwiseOpType::Subtract;
+  } else if constexpr (std::is_same_v<EltwiseOp, ReluOp>) {
+    type = ::tt::target::ttnn::EltwiseOpType::Relu;
   } else {
     llvm_unreachable("unhandled EltwiseOp");
   }
@@ -176,6 +178,9 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
     return createOperation(cache, createEltwiseOp(cache, subtractOp),
                            debugString);
   }
+  if (auto reluOp = dyn_cast<ReluOp>(op); reluOp) {
+    return createOperation(cache, createEltwiseOp(cache, reluOp), debugString);
+  }
   if (auto matmulOp = dyn_cast<MatmulOp>(op); matmulOp) {
     return createOperation(cache, createOp(cache, matmulOp), debugString);
   }
diff --git a/lib/Dialect/TTNN/Transforms/TTNNToCpp.cpp b/lib/Dialect/TTNN/Transforms/TTNNToCpp.cpp
index a332af49a..b2c17d97e 100644
--- a/lib/Dialect/TTNN/Transforms/TTNNToCpp.cpp
+++ b/lib/Dialect/TTNN/Transforms/TTNNToCpp.cpp
@@ -143,6 +143,7 @@ class ConvertTTNNToEmitC
              TTNNToEmitCOpaqueRewriter<AddOp>,
              TTNNToEmitCOpaqueRewriter<MultiplyOp>,
              TTNNToEmitCOpaqueRewriter<SubtractOp>,
+             TTNNToEmitCOpaqueRewriter<ReluOp>,
              TTNNToEmitCOpaqueRewriter<MatmulOp>,
              TTNNToEmitCOpaqueRewriter<SumOp>,
              TTNNToEmitCOpaqueRewriter<CloseDeviceOp>>(&getContext());
diff --git a/test/ttmlir/Dialect/TTNN/simple_relu.mlir b/test/ttmlir/Dialect/TTNN/simple_relu.mlir
new file mode 100644
index 000000000..5c9d7643f
--- /dev/null
+++ b/test/ttmlir/Dialect/TTNN/simple_relu.mlir
@@ -0,0 +1,15 @@
+// RUN: ttmlir-opt --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
+#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
+module attributes {torch.debug_module_name = "_lambda", tt.system_desc = #tt.system_desc<[{arch = <wormhole_b0>, grid = <8x8>, l1_size = 1048576, num_dram_channels = 12, dram_channel_size = 1048576}], [0], [], [<0, 0, 0, 0>]>} {
+  func.func @forward(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
+    // CHECK: %[[C:.*]] = "ttnn.open_device"[[C:.*]]
+    // CHECK: %[[C:.*]] = "ttnn.full"[[C:.*]]
+    %0 = tensor.empty() : tensor<64x128xf32>
+    // CHECK: %[[C:.*]] = "ttnn.to_memory_config"[[C:.*]]
+    // CHECK: %[[C:.*]] = "ttnn.relu"[[C:.*]]
+    %1 = "ttir.relu"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
+    // CHECK: %[[C:.*]] = "ttnn.to_memory_config"[[C:.*]]
+    // CHECK: "ttnn.close_device"[[C:.*]]
+    return %1 : tensor<64x128xf32>
+  }
+}
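To reproduce the new lit test by hand, expand the RUN line's %s substitution to the test file's path; this is exactly what llvm-lit does, and it assumes a build tree where ttmlir-opt and FileCheck are on PATH:

ttmlir-opt --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn \
  test/ttmlir/Dialect/TTNN/simple_relu.mlir | FileCheck test/ttmlir/Dialect/TTNN/simple_relu.mlir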