Relu op support in tt-mlir (#114)
Fixes #78.
rpavlovicTT authored Jul 9, 2024
1 parent 023703b commit 521be7a
Showing 8 changed files with 51 additions and 9 deletions.
7 changes: 7 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -179,6 +179,13 @@ def TTIR_SumOp : TTIR_ReductionOp<"sum"> {
   }];
 }
 
+def TTIR_ReluOp : TTIR_ElementwiseOp<"relu"> {
+  let summary = "Eltwise ReLU.";
+  let description = [{
+    Eltwise ReLU operation.
+  }];
+}
+
 // ANCHOR: adding_an_op_matmul_ttir
 def TTIR_MatmulOp : TTIR_DPSOp<"matmul"> {
   let summary = "Matrix multiply operation.";
7 changes: 7 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -117,6 +117,13 @@ def TTNN_SumOp : TTNN_ReductionOp<"sum"> {
   }];
 }
 
+def TTNN_ReluOp : TTNN_ElementwiseOp<"relu"> {
+  let summary = "Eltwise ReLU.";
+  let description = [{
+    Eltwise ReLU operation.
+  }];
+}
+
 // ANCHOR: adding_an_op_matmul_ttnn
 def TTNN_MatmulOp : TTNN_NamedDPSOp<"matmul"> {
   let arguments = (ins AnyRankedTensor:$a,
1 change: 1 addition & 0 deletions include/ttmlir/Target/TTNN/program.fbs
@@ -28,6 +28,7 @@ enum EltwiseOpType: uint32 {
   Add = 0,
   Multiply = 1,
   Subtract = 2,
+  Relu = 3,
 }
 
 table EltwiseOp {
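The new Relu = 3 entry is the stable tag that identifies the op once a program has been serialized to a flatbuffer. To illustrate why the enum and its consumers must grow in lockstep, here is a minimal, self-contained C++ sketch of enum-tagged dispatch; the mirrored enum matches the schema above, but apply and its scalar semantics are hypothetical stand-ins for the actual runtime, not tt-metal API:

#include <cstdint>
#include <iostream>

// Mirrors the flatbuffer enum above (normally provided by flatc-generated code).
enum class EltwiseOpType : uint32_t { Add = 0, Multiply = 1, Subtract = 2, Relu = 3 };

// Hypothetical scalar dispatch: every tag the compiler can serialize
// must be handled here, which is why each side grows a case in lockstep.
float apply(EltwiseOpType type, float a, float b) {
  switch (type) {
  case EltwiseOpType::Add:      return a + b;
  case EltwiseOpType::Multiply: return a * b;
  case EltwiseOpType::Subtract: return a - b;
  case EltwiseOpType::Relu:     return a > 0.0f ? a : 0.0f; // unary: b is ignored
  }
  return 0.0f; // unreachable for valid tags
}

int main() {
  std::cout << apply(EltwiseOpType::Relu, -2.0f, 0.0f) << "\n"; // prints 0
}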
23 changes: 14 additions & 9 deletions lib/Dialect/TTIR/Transforms/Passes.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "ttmlir/Dialect/TT/IR/TT.h"
 #include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
+#include "ttmlir/Dialect/TTIR/IR/TTIROps.h"
 
 #include "ttmlir/Dialect/TTIR/Analysis/GridAnalysis.h"
 #include "ttmlir/Dialect/TTIR/Passes.h"
@@ -114,6 +115,9 @@ class TTIRNamedToKernelRewriter : public OpRewritePattern<TTIROp> {
     } else if constexpr (std::is_same<TTIROp, ttir::SubtractOp>::value) {
       kernelName = "subtract";
       kernelKind = "eltwise";
+    } else if constexpr (std::is_same<TTIROp, ttir::ReluOp>::value) {
+      kernelName = "relu";
+      kernelKind = "eltwise";
     } else {
       return rewriter.notifyMatchFailure(op,
                                          "Unsupported Tosa operation for TTIR");
@@ -267,7 +271,8 @@ class TTIRGeneric : public impl::TTIRGenericBase<TTIRGeneric> {
     patterns.add<TTIRLinalgGenericRewriter, TTIRKernelGenericRewriter,
                  TTIRNamedToKernelRewriter<AddOp>,
                  TTIRNamedToKernelRewriter<MultiplyOp>,
-                 TTIRNamedToKernelRewriter<SubtractOp>>(&getContext());
+                 TTIRNamedToKernelRewriter<SubtractOp>,
+                 TTIRNamedToKernelRewriter<ReluOp>>(&getContext());
     FrozenRewritePatternSet patternSet(std::move(patterns));
     if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) {
       signalPassFailure();
@@ -588,14 +593,14 @@ class TTIRLayout : public impl::TTIRLayoutBase<TTIRLayout> {
     }
     {
       RewritePatternSet patterns(&getContext());
-      patterns
-          .add<TTIRLayoutOperandsRewriter<GenericOp>,
-               TTIRLayoutOperandsRewriter<AddOp>,
-               TTIRLayoutOperandsRewriter<MultiplyOp>,
-               TTIRLayoutOperandsRewriter<SubtractOp>,
-               TTIRLayoutOperandsRewriter<MatmulOp>,
-               TTIRLayoutOperandsRewriter<SumOp>, TTIRLayoutFuncReturnRewriter>(
-              &getContext());
+      patterns.add<
+          TTIRLayoutOperandsRewriter<GenericOp>,
+          TTIRLayoutOperandsRewriter<AddOp>,
+          TTIRLayoutOperandsRewriter<MultiplyOp>,
+          TTIRLayoutOperandsRewriter<SubtractOp>,
+          TTIRLayoutOperandsRewriter<ReluOp>, TTIRLayoutOperandsRewriter<SumOp>,
+          TTIRLayoutOperandsRewriter<MatmulOp>, TTIRLayoutFuncReturnRewriter>(
+          &getContext());
       FrozenRewritePatternSet patternSet(std::move(patterns));
       if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) {
         signalPassFailure();
1 change: 1 addition & 0 deletions lib/Dialect/TTNN/Transforms/Passes.cpp
@@ -147,6 +147,7 @@ class ConvertTTIRToTTNN
         .add<TTIRToTTNNLayoutRewriter, TTIRToTTNNOpRewriter<ttir::AddOp, AddOp>,
              TTIRToTTNNOpRewriter<ttir::MultiplyOp, MultiplyOp>,
              TTIRToTTNNOpRewriter<ttir::SubtractOp, SubtractOp>,
+             TTIRToTTNNOpRewriter<ttir::ReluOp, ReluOp>,
              TTIRToTTNNBinaryOpRewriter<ttir::MatmulOp, MatmulOp>,
              TTIRToTTNNReductionOpRewriter<ttir::SumOp, SumOp>,
              TensorEmptyToFullRewriter>(&getContext());
5 changes: 5 additions & 0 deletions lib/Dialect/TTNN/Transforms/SerializeToBinary.cpp
@@ -112,6 +112,8 @@ createEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) {
     type = ::tt::target::ttnn::EltwiseOpType::Multiply;
   } else if constexpr (std::is_same_v<EltwiseOp, SubtractOp>) {
     type = ::tt::target::ttnn::EltwiseOpType::Subtract;
+  } else if constexpr (std::is_same_v<EltwiseOp, ReluOp>) {
+    type = ::tt::target::ttnn::EltwiseOpType::Relu;
   } else {
     llvm_unreachable("unhandled EltwiseOp");
   }
@@ -176,6 +178,9 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
     return createOperation(cache, createEltwiseOp(cache, subtractOp),
                            debugString);
   }
+  if (auto reluOp = dyn_cast<ReluOp>(op); reluOp) {
+    return createOperation(cache, createEltwiseOp(cache, reluOp), debugString);
+  }
   if (auto matmulOp = dyn_cast<MatmulOp>(op); matmulOp) {
     return createOperation(cache, createOp(cache, matmulOp), debugString);
   }
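Unlike the compile-time dispatch in createEltwiseOp, emitTTNNOperation discovers the concrete op type at runtime with LLVM-style dyn_cast in a C++17 init-statement. A self-contained sketch of that casting idiom, compiled against LLVM headers; the Op hierarchy here is a hypothetical stand-in, since real MLIR operations carry their own RTTI machinery:

#include "llvm/Support/Casting.h" // LLVM's isa<>/dyn_cast<> support
#include <iostream>

struct Op {
  enum Kind { K_Relu, K_Matmul };
  Kind kind;
  explicit Op(Kind k) : kind(k) {}
};
struct ReluLikeOp : Op {
  ReluLikeOp() : Op(K_Relu) {}
  static bool classof(const Op *op) { return op->kind == K_Relu; }
};
struct MatmulLikeOp : Op {
  MatmulLikeOp() : Op(K_Matmul) {}
  static bool classof(const Op *op) { return op->kind == K_Matmul; }
};

// Same shape as emitTTNNOperation: probe each concrete type in turn.
const char *emit(Op *op) {
  if (auto relu = llvm::dyn_cast<ReluLikeOp>(op); relu)
    return "eltwise(relu)";
  if (auto matmul = llvm::dyn_cast<MatmulLikeOp>(op); matmul)
    return "matmul";
  return "unhandled";
}

int main() {
  ReluLikeOp relu;
  std::cout << emit(&relu) << "\n"; // prints "eltwise(relu)"
}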
1 change: 1 addition & 0 deletions lib/Dialect/TTNN/Transforms/TTNNToCpp.cpp
@@ -143,6 +143,7 @@ class ConvertTTNNToEmitC
              TTNNToEmitCOpaqueRewriter<ToMemoryConfigOp>,
              TTNNToEmitCOpaqueRewriter<MultiplyOp>,
              TTNNToEmitCOpaqueRewriter<SubtractOp>,
+             TTNNToEmitCOpaqueRewriter<ReluOp>,
              TTNNToEmitCOpaqueRewriter<MatmulOp>,
              TTNNToEmitCOpaqueRewriter<SumOp>,
              TTNNToEmitCOpaqueRewriter<CloseDeviceOp>>(&getContext());
15 changes: 15 additions & 0 deletions test/ttmlir/Dialect/TTNN/simple_relu.mlir
@@ -0,0 +1,15 @@
+// RUN: ttmlir-opt --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
+#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
+module attributes {torch.debug_module_name = "_lambda", tt.system_desc = #tt.system_desc<[{arch = <wormhole_b0>, grid = <8x8>, l1_size = 1048576, num_dram_channels = 12, dram_channel_size = 1048576}], [0], [<pcie|host_mmio>], [<0, 0, 0, 0>]>} {
+  func.func @forward(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
+    // CHECK: %[[C:.*]] = "ttnn.open_device"[[C:.*]]
+    // CHECK: %[[C:.*]] = "ttnn.full"[[C:.*]]
+    %0 = tensor.empty() : tensor<64x128xf32>
+    // CHECK: %[[C:.*]] = "ttnn.to_memory_config"[[C:.*]]
+    // CHECK: %[[C:.*]] = "ttnn.relu"[[C:.*]]
+    %1 = "ttir.relu"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
+    // CHECK: %[[C:.*]] = "ttnn.to_memory_config"[[C:.*]]
+    // CHECK: "ttnn.close_device"[[C:.*]]
+    return %1 : tensor<64x128xf32>
+  }
+}
