Skip to content

Commit

Permalink
Add absolute op
Browse files Browse the repository at this point in the history
Add absolute op end to end along with required conversion for stablehlo.
  • Loading branch information
mmanzoorTT committed Sep 12, 2024
1 parent e51f5f7 commit 98ffefa
Show file tree
Hide file tree
Showing 12 changed files with 55 additions and 2 deletions.
7 changes: 7 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,13 @@ def TTIR_MaximumOp : TTIR_ElementwiseBinaryOp<"maximum"> {
}];
}

def TTIR_AbsOp: TTIR_ElementwiseUnaryOp<"abs"> {
let summary = "Eltwise absolute op.";
let description = [{
Eltwise absolute operation.
}];
}

def TTIR_SqrtOp : TTIR_ElementwiseUnaryOp<"sqrt"> {
let summary = "Eltwise square root.";
let description = [{
Expand Down
7 changes: 7 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,13 @@ class TTNN_ElementwiseBinaryOp<string mnemonic, list<Trait> traits = []> :
];
}

def TTNN_AbsOp : TTNN_ElementwiseUnaryOp<"abs"> {
let summary = "Eltwise absolute.";
let description = [{
Eltwise absolute operation.
}];
}

def TTNN_SqrtOp : TTNN_ElementwiseUnaryOp<"sqrt"> {
let summary = "Eltwise sqrt.";
let description = [{
Expand Down
1 change: 1 addition & 0 deletions include/ttmlir/Target/TTNN/program.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ enum EltwiseOpType: uint32 {
Reciprocal = 8,
Exp = 9,
Maximum = 10,
Abs = 11
}

table EltwiseOp {
Expand Down
2 changes: 2 additions & 0 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@ void addElementwiseUnaryOpsConversionPatterns(MLIRContext *ctx,
RewritePatternSet &patterns,
TypeConverter &typeConverter) {

patterns.add<StableHLOToTTIROpDefaultConversionPattern<
mlir::stablehlo::AbsOp, mlir::tt::ttir::AbsOp>>(typeConverter, ctx);
patterns.add<StableHLOToTTIROpDefaultConversionPattern<
mlir::stablehlo::ExpOp, mlir::tt::ttir::ExpOp>>(typeConverter, ctx);
}
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
patterns
.add<TensorEmptyConversionPattern,
ToLayoutOpConversionPattern,
ElementwiseOpConversionPattern<ttir::AbsOp, ttnn::AbsOp>,
ElementwiseOpConversionPattern<ttir::AddOp, ttnn::AddOp>,
ElementwiseOpConversionPattern<ttir::SubtractOp, ttnn::SubtractOp>,
ElementwiseOpConversionPattern<ttir::MultiplyOp, ttnn::MultiplyOp>,
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,

// Eltwise unary ops
//
patterns.add<DefaultOpConversionPattern<ttnn::AbsOp>>(typeConverter, ctx);
patterns.add<DefaultOpConversionPattern<ttnn::ReluOp>>(typeConverter, ctx);
patterns.add<DefaultOpConversionPattern<ttnn::SqrtOp>>(typeConverter, ctx);
patterns.add<DefaultOpConversionPattern<ttnn::SigmoidOp>>(typeConverter, ctx);
Expand Down
3 changes: 3 additions & 0 deletions lib/Conversion/TosaToTTIR/TosaToTTIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ struct ConvertTosaToTTIRPass
RewritePatternSet patterns(&getContext());

// Add conversion patterns.
patterns
.add<TosaToTTIROpConversionPattern<tosa::AbsOp, mlir::tt::ttir::AbsOp>>(
typeConverter, &getContext());
patterns
.add<TosaToTTIROpConversionPattern<tosa::AddOp, mlir::tt::ttir::AddOp>>(
typeConverter, &getContext());
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/TTIR/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ class TTIRNamedToKernelRewriter : public OpRewritePattern<TTIROpTy> {
StringRef kernelName;
StringRef kernelKind;
if constexpr (std::is_same<TTIROpTy, ttir::MultiplyOp>::value) {
kernelName = "mulitply";
kernelName = "multiply";
kernelKind = "eltwise";
} else if constexpr (std::is_same<TTIROpTy, ttir::AddOp>::value) {
kernelName = "add";
Expand Down
7 changes: 6 additions & 1 deletion lib/Target/TTNN/TTNNToFlatbuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,9 @@ template <typename EltwiseOp>
::flatbuffers::Offset<::tt::target::ttnn::EltwiseOp>
createEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) {
::tt::target::ttnn::EltwiseOpType type;
if constexpr (std::is_same_v<EltwiseOp, AddOp>) {
if constexpr (std::is_same_v<EltwiseOp, AbsOp>) {
type = ::tt::target::ttnn::EltwiseOpType::Abs;
} else if constexpr (std::is_same_v<EltwiseOp, AddOp>) {
type = ::tt::target::ttnn::EltwiseOpType::Add;
} else if constexpr (std::is_same_v<EltwiseOp, MultiplyOp>) {
type = ::tt::target::ttnn::EltwiseOpType::Multiply;
Expand Down Expand Up @@ -319,6 +321,9 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
if (auto fullOp = dyn_cast<FullOp>(op); fullOp) {
return createOperation(cache, createOp(cache, fullOp), debugString);
}
if (auto absOp = dyn_cast<AbsOp>(op); absOp) {
return createOperation(cache, createEltwiseOp(cache, absOp), debugString);
}
if (auto addOp = dyn_cast<AddOp>(op); addOp) {
return createOperation(cache, createEltwiseOp(cache, addOp), debugString);
}
Expand Down
4 changes: 4 additions & 0 deletions runtime/lib/ttnn/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,10 @@ static void run(::tt::target::ttnn::EltwiseOp const *op,
break;
}
/* Eltwise Unary */
case ::tt::target::ttnn::EltwiseOpType::Abs: {
runEltwiseUnaryOP(op, tensorPool, ::ttnn::abs);
break;
}
case ::tt::target::ttnn::EltwiseOpType::Relu: {
runEltwiseUnaryOP(op, tensorPool, ::ttnn::relu);
break;
Expand Down
11 changes: 11 additions & 0 deletions test/ttmlir/Conversion/StableHLOToTTIR/absolute_op.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// REQUIRES: stablehlo
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module @jit_eltwise_abs attributes {} {
func.func public @test_abs(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
%0 = stablehlo.abs %arg0 : tensor<13x21x3xf32>
// CHECK: %[[C:.*]] = tensor.empty[[C:.*]]
// CHECK: %[[C:.*]] = "ttir.abs"[[C:.*]]
return %0 : tensor<13x21x3xf32>
}
}
11 changes: 11 additions & 0 deletions test/ttmlir/Dialect/TTNN/eltwise/unary/abs/simple_abs.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module attributes {} {
func.func @forward(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
// CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
%0 = tensor.empty() : tensor<64x128xf32>
// CHECK: %[[C:.*]] = "ttnn.abs"[[C:.*]]
%1 = "ttir.abs"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
return %1 : tensor<64x128xf32>
}
}

0 comments on commit 98ffefa

Please sign in to comment.