Skip to content

Commit

Permalink
Add negate op (#675)
Browse files Browse the repository at this point in the history
Add e2e negate op along with conversion for stablehlo.
  • Loading branch information
mmanzoorTT authored Sep 13, 2024
1 parent f98d0dd commit 0153683
Show file tree
Hide file tree
Showing 11 changed files with 54 additions and 1 deletion.
7 changes: 7 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,13 @@ def TTIR_SqrtOp : TTIR_ElementwiseUnaryOp<"sqrt"> {
}];
}

def TTIR_NegOp: TTIR_ElementwiseUnaryOp<"neg"> {
let summary = "Eltwise negate op.";
let description = [{
Eltwise negate operation.
}];
}

def TTIR_ReciprocalOp : TTIR_ElementwiseUnaryOp<"reciprocal"> {
let summary = "Eltwise reciprocal.";
let description = [{
Expand Down
7 changes: 7 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,13 @@ def TTNN_SqrtOp : TTNN_ElementwiseUnaryOp<"sqrt"> {
}];
}

def TTNN_NegOp : TTNN_ElementwiseUnaryOp<"neg"> {
let summary = "Eltwise negate.";
let description = [{
Eltwise negate operation.
}];
}

def TTNN_ReciprocalOp : TTNN_ElementwiseUnaryOp<"reciprocal"> {
let summary = "Eltwise reciprocal.";
let description = [{
Expand Down
3 changes: 2 additions & 1 deletion include/ttmlir/Target/TTNN/program.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ enum EltwiseOpType: uint32 {
Reciprocal = 8,
Exp = 9,
Maximum = 10,
Abs = 11
Abs = 11,
Neg = 12,
}

table EltwiseOp {
Expand Down
2 changes: 2 additions & 0 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,8 @@ void addElementwiseUnaryOpsConversionPatterns(MLIRContext *ctx,
mlir::stablehlo::AbsOp, mlir::tt::ttir::AbsOp>>(typeConverter, ctx);
patterns.add<StableHLOToTTIROpDefaultConversionPattern<
mlir::stablehlo::ExpOp, mlir::tt::ttir::ExpOp>>(typeConverter, ctx);
patterns.add<StableHLOToTTIROpDefaultConversionPattern<
mlir::stablehlo::NegOp, mlir::tt::ttir::NegOp>>(typeConverter, ctx);
}

void addElementwiseBinaryOpsConversionPatterns(MLIRContext *ctx,
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
ElementwiseOpConversionPattern<ttir::MultiplyOp, ttnn::MultiplyOp>,
ElementwiseOpConversionPattern<ttir::GreaterEqualOp, ttnn::GreaterEqualOp>,
ElementwiseOpConversionPattern<ttir::MaximumOp, ttnn::MaximumOp>,
ElementwiseOpConversionPattern<ttir::NegOp, ttnn::NegOp>,
ElementwiseOpConversionPattern<ttir::ReluOp, ttnn::ReluOp>,
ElementwiseOpConversionPattern<ttir::SqrtOp, ttnn::SqrtOp>,
ElementwiseOpConversionPattern<ttir::SigmoidOp, ttnn::SigmoidOp>,
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
// Eltwise unary ops
//
patterns.add<DefaultOpConversionPattern<ttnn::AbsOp>,
DefaultOpConversionPattern<ttnn::NegOp>,
DefaultOpConversionPattern<ttnn::ReluOp>,
DefaultOpConversionPattern<ttnn::SqrtOp>,
DefaultOpConversionPattern<ttnn::SigmoidOp>,
Expand Down
3 changes: 3 additions & 0 deletions lib/Conversion/TosaToTTIR/TosaToTTIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ struct ConvertTosaToTTIRPass
patterns.add<
TosaToTTIROpConversionPattern<tosa::MulOp, mlir::tt::ttir::MultiplyOp>>(
typeConverter, &getContext());
patterns.add<
TosaToTTIROpConversionPattern<tosa::NegateOp, mlir::tt::ttir::NegOp>>(
typeConverter, &getContext());
patterns.add<
TosaToTTIROpConversionPattern<tosa::SubOp, mlir::tt::ttir::SubtractOp>>(
typeConverter, &getContext());
Expand Down
5 changes: 5 additions & 0 deletions lib/Target/TTNN/TTNNToFlatbuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,8 @@ createEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) {
type = ::tt::target::ttnn::EltwiseOpType::Add;
} else if constexpr (std::is_same_v<EltwiseOp, MultiplyOp>) {
type = ::tt::target::ttnn::EltwiseOpType::Multiply;
} else if constexpr (std::is_same_v<EltwiseOp, NegOp>) {
type = ::tt::target::ttnn::EltwiseOpType::Neg;
} else if constexpr (std::is_same_v<EltwiseOp, SubtractOp>) {
type = ::tt::target::ttnn::EltwiseOpType::Subtract;
} else if constexpr (std::is_same_v<EltwiseOp, GreaterEqualOp>) {
Expand Down Expand Up @@ -428,6 +430,9 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
return createOperation(cache, createEltwiseOp(cache, multiplyOp),
debugString);
}
if (auto negOp = dyn_cast<NegOp>(op); negOp) {
return createOperation(cache, createEltwiseOp(cache, negOp), debugString);
}
if (auto subtractOp = dyn_cast<SubtractOp>(op); subtractOp) {
return createOperation(cache, createEltwiseOp(cache, subtractOp),
debugString);
Expand Down
4 changes: 4 additions & 0 deletions runtime/lib/ttnn/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,10 @@ static void run(::tt::target::ttnn::EltwiseOp const *op,
runEltwiseUnaryOP(op, tensorPool, ::ttnn::abs);
break;
}
case ::tt::target::ttnn::EltwiseOpType::Neg: {
runEltwiseUnaryOP(op, tensorPool, ::ttnn::neg);
break;
}
case ::tt::target::ttnn::EltwiseOpType::Relu: {
runEltwiseUnaryOP(op, tensorPool, ::ttnn::relu);
break;
Expand Down
11 changes: 11 additions & 0 deletions test/ttmlir/Conversion/StableHLOToTTIR/negate_op.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// REQUIRES: stablehlo
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module @jit_eltwise_neg attributes {} {
func.func public @test_neg(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
%0 = stablehlo.negate %arg0 : tensor<13x21x3xf32>
// CHECK: %[[C:.*]] = tensor.empty[[C:.*]]
// CHECK: %[[C:.*]] = "ttir.neg"[[C:.*]]
return %0 : tensor<13x21x3xf32>
}
}
11 changes: 11 additions & 0 deletions test/ttmlir/Dialect/TTNN/eltwise/unary/negate/simple_neg.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module attributes {} {
func.func @forward(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
// CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
%0 = tensor.empty() : tensor<64x128xf32>
// CHECK: %[[C:.*]] = "ttnn.neg"[[C:.*]]
%1 = "ttir.neg"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
return %1 : tensor<64x128xf32>
}
}

0 comments on commit 0153683

Please sign in to comment.