Commit
Add support for minimum op.
* Add end-to-end implementation of the op
* Add StableHLO to TTIR conversion
mmanzoorTT committed Oct 29, 2024
1 parent 648739a commit 98823bf
Showing 19 changed files with 106 additions and 26 deletions.
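
For orientation, here is a minimal sketch of how the new op flows through the stack, assembled from the conversion and silicon tests added in this commit. Attribute and layout details are elided, and the exact printed forms may differ:

  // StableHLO level
  %r = stablehlo.minimum %lhs, %rhs : tensor<13x21x3xf32>

  // TTIR level, after --stablehlo-to-ttir-pipeline (destination-passing style with an explicit output tensor)
  %out = tensor.empty() : tensor<13x21x3xf32>
  %r = "ttir.minimum"(%lhs, %rhs, %out) ... : (tensor<13x21x3xf32>, tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>

  // TTNN level, after --ttir-to-ttnn-backend-pipeline
  %out = "ttnn.empty"(...) ...
  %r = "ttnn.minimum"(%lhs, %rhs, %out) ...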
27 changes: 20 additions & 7 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -269,13 +269,6 @@ class TTIR_ElementwiseBinaryOp<string mnemonic, list<Trait> traits = []> :
];
}

def TTIR_SubtractOp : TTIR_ElementwiseBinaryOp<"subtract"> {
let summary = "Eltwise subtract.";
let description = [{
Eltwise subtract operation.
}];
}

def TTIR_EqualOp : TTIR_ElementwiseBinaryOp<"eq"> {
let summary = "Eltwise equal to.";
let description = [{
@@ -344,6 +337,26 @@ def TTIR_MaximumOp : TTIR_ElementwiseBinaryOp<"maximum"> {
}];
}

def TTIR_MinimumOp : TTIR_ElementwiseBinaryOp<"minimum"> {
let summary = "Eltwise minimum OP.";
let description = [{
Calculates minimum of input tensors' values element-wise and stores result
in output tensor.

Example:
%lhs: [[3, 2, 7], [1, 4, 4]]
%rhs: [[1, 4, 2], [1, 2, 3]]
"ttir.minimum"(%lhs, %rhs, %out) -> %out: [[1, 2, 2], [1, 2, 3]]
}];
}

def TTIR_SubtractOp : TTIR_ElementwiseBinaryOp<"subtract"> {
let summary = "Eltwise subtract.";
let description = [{
Eltwise subtract operation.
}];
}

class TTIR_ReductionOp<string mnemonic, list<Trait> traits = []> :
TTIR_DPSOp<mnemonic, !listconcat(traits, [TTIR_GenericRegionOpInterface])> {

41 changes: 27 additions & 14 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -246,20 +246,6 @@ def TTNN_DivOp : TTNN_ElementwiseBinaryOp<"div"> {
}];
}

def TTNN_SubtractOp : TTNN_ElementwiseBinaryOp<"subtract"> {
let summary = "Eltwise subtract.";
let description = [{
Eltwise subtract operation.
}];
}

def TTNN_MultiplyOp : TTNN_ElementwiseBinaryOp<"multiply"> {
let summary = "Eltwise multiply.";
let description = [{
Eltwise multiply operation.
}];
}

def TTNN_EqualOp : TTNN_ElementwiseBinaryOp<"eq"> {
let summary = "Eltwise equal to.";
let description = [{
@@ -328,6 +314,33 @@ def TTNN_MaximumOp : TTNN_ElementwiseBinaryOp<"maximum"> {
}];
}

def TTNN_MinimumOp : TTNN_ElementwiseBinaryOp<"minimum"> {
let summary = "Eltwise minimum OP.";
let description = [{
Calculates minimum of input tensors' values element-wise and stores result
in output tensor.

Example:
%lhs: [[3, 2, 7], [1, 4, 4]]
%rhs: [[1, 4, 2], [1, 2, 3]]
"ttnn.minimum"(%lhs, %rhs, %out) -> %out: [[1, 2, 2], [1, 2, 3]]
}];
}

def TTNN_MultiplyOp : TTNN_ElementwiseBinaryOp<"multiply"> {
let summary = "Eltwise multiply.";
let description = [{
Eltwise multiply operation.
}];
}

def TTNN_SubtractOp : TTNN_ElementwiseBinaryOp<"subtract"> {
let summary = "Eltwise subtract.";
let description = [{
Eltwise subtract operation.
}];
}

class TTNN_ReductionOp<string mnemonic, list<Trait> traits = []> : TTNN_Op<mnemonic, traits> {
let summary = "Reduction op.";
let description = [{
1 change: 1 addition & 0 deletions include/ttmlir/Target/TTNN/program.fbs
@@ -82,6 +82,7 @@ enum EltwiseOpType: uint32 {
LogicalOr = 21,
LogicalNot = 22,
Cbrt = 23,
Minimum = 24,
}

table EltwiseOp {
12 changes: 7 additions & 5 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
@@ -879,15 +879,17 @@ void addElementwiseBinaryOpsConversionPatterns(MLIRContext *ctx,

patterns.add<StableHLOToTTIROpDefaultConversionPattern<
mlir::stablehlo::AddOp, mlir::tt::ttir::AddOp>>(typeConverter, ctx);
patterns.add<StableHLOToTTIROpDefaultConversionPattern<
mlir::stablehlo::SubtractOp, mlir::tt::ttir::SubtractOp>>(typeConverter,
ctx);
patterns.add<StableHLOToTTIROpDefaultConversionPattern<
mlir::stablehlo::MulOp, mlir::tt::ttir::MultiplyOp>>(typeConverter, ctx);
patterns.add<StableHLOToTTIROpDefaultConversionPattern<
mlir::stablehlo::DivOp, mlir::tt::ttir::DivOp>>(typeConverter, ctx);
patterns.add<StableHLOToTTIROpDefaultConversionPattern<
mlir::stablehlo::MaxOp, mlir::tt::ttir::MaximumOp>>(typeConverter, ctx);
patterns.add<StableHLOToTTIROpDefaultConversionPattern<
mlir::stablehlo::MinOp, mlir::tt::ttir::MinimumOp>>(typeConverter, ctx);
patterns.add<StableHLOToTTIROpDefaultConversionPattern<
mlir::stablehlo::MulOp, mlir::tt::ttir::MultiplyOp>>(typeConverter, ctx);
patterns.add<StableHLOToTTIROpDefaultConversionPattern<
mlir::stablehlo::SubtractOp, mlir::tt::ttir::SubtractOp>>(typeConverter,
ctx);
}

void addReduceOpsConversionPatterns(MLIRContext *ctx,
1 change: 1 addition & 0 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -820,6 +820,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
ElementwiseOpConversionPattern<ttir::LessEqualOp, ttnn::LessEqualOp>,
ElementwiseOpConversionPattern<ttir::LessThanOp, ttnn::LessThanOp>,
ElementwiseOpConversionPattern<ttir::MaximumOp, ttnn::MaximumOp>,
ElementwiseOpConversionPattern<ttir::MinimumOp, ttnn::MinimumOp>,
ElementwiseOpConversionPattern<ttir::NegOp, ttnn::NegOp>,
ElementwiseOpConversionPattern<ttir::ReluOp, ttnn::ReluOp>,
ElementwiseOpConversionPattern<ttir::SqrtOp, ttnn::SqrtOp>,
1 change: 1 addition & 0 deletions lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
@@ -639,6 +639,7 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
DefaultOpConversionPattern<ttnn::LessEqualOp>,
DefaultOpConversionPattern<ttnn::LessThanOp>,
DefaultOpConversionPattern<ttnn::MaximumOp>,
DefaultOpConversionPattern<ttnn::MinimumOp>,
DefaultOpConversionPattern<ttnn::DivOp>>(typeConverter, ctx);

// Tensor manipulation ops
6 changes: 6 additions & 0 deletions lib/Target/TTNN/TTNNToFlatbuffer.cpp
@@ -325,6 +325,8 @@ createEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) {
type = ::tt::target::ttnn::EltwiseOpType::LessThan;
} else if constexpr (std::is_same_v<EltwiseOp, MaximumOp>) {
type = ::tt::target::ttnn::EltwiseOpType::Maximum;
} else if constexpr (std::is_same_v<EltwiseOp, MinimumOp>) {
type = ::tt::target::ttnn::EltwiseOpType::Minimum;
} else if constexpr (std::is_same_v<EltwiseOp, ReluOp>) {
type = ::tt::target::ttnn::EltwiseOpType::Relu;
} else if constexpr (std::is_same_v<EltwiseOp, SqrtOp>) {
@@ -568,6 +570,10 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
return createOperation(cache, createEltwiseOp(cache, maximumOp),
debugString);
}
if (auto minimumOp = dyn_cast<MinimumOp>(op); minimumOp) {
return createOperation(cache, createEltwiseOp(cache, minimumOp),
debugString);
}
if (auto reluOp = dyn_cast<ReluOp>(op); reluOp) {
return createOperation(cache, createEltwiseOp(cache, reluOp), debugString);
}
4 changes: 4 additions & 0 deletions runtime/lib/ttnn/operations/eltwise/binary.cpp
@@ -122,6 +122,10 @@ void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) {
runEltwiseBinaryCompositeOP(op, tensorPool, ::ttnn::maximum);
break;
}
case ::tt::target::ttnn::EltwiseOpType::Minimum: {
runEltwiseBinaryCompositeOP(op, tensorPool, ::ttnn::minimum);
break;
}
default:
throw std::invalid_argument("Unsupported Eltwise Binary operation");
}
15 changes: 15 additions & 0 deletions test/ttmlir/Conversion/StableHLOToTTIR/binary/minimum_op.mlir
@@ -0,0 +1,15 @@
// REQUIRES: stablehlo
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s
module @jit_eltwise_minimum attributes {} {
func.func public @test_minimum(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
// CHECK: %[[C:.*]] = tensor.empty()
// CHECK-SAME: [[TENSOR:tensor<13x21x3xf32>]]
// CHECK: %[[C:.*]] = "ttir.minimum"
// CHECK-SAME: [[TENSOR]]
// CHECK-SAME: [[TENSOR]]
// CHECK-SAME: [[TENSOR]]
// CHECK-SAME: -> [[TENSOR]]
%0 = stablehlo.minimum %arg0, %arg1 : tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
}
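
For readers less familiar with FileCheck, the CHECK lines in this test assert output of roughly the following shape. This is a sketch: the actual SSA value names and operation attributes produced by the pipeline will differ.

  func.func public @test_minimum(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
    %0 = tensor.empty() : tensor<13x21x3xf32>
    %1 = "ttir.minimum"(%arg0, %arg1, %0) ... : (tensor<13x21x3xf32>, tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
    return %1 : tensor<13x21x3xf32>
  }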
11 changes: 11 additions & 0 deletions (new test file; path not shown)
@@ -0,0 +1,11 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module attributes {} {
func.func @forward(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> {
// CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
%0 = tensor.empty() : tensor<64x128xf32>
// CHECK: %[[C:.*]] = "ttnn.minimum"[[C:.*]]
%1 = "ttir.minimum"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
return %1 : tensor<64x128xf32>
}
}
13 changes: 13 additions & 0 deletions test/ttmlir/Silicon/TTNN/simple_eltwise.mlir
@@ -20,6 +20,19 @@ func.func @div(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> {
return %1 : tensor<64x128xf32>
}

func.func @minimum(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> {
// CHECK: %[[C:.*]] = "ttnn.empty"
// CHECK-SAME: [[TENSOR:tensor<64x128xf32,]]
%0 = tensor.empty() : tensor<64x128xf32>
// CHECK: %[[C:.*]] = "ttnn.minimum"
// CHECK-SAME: [[TENSOR]]
// CHECK-SAME: [[TENSOR]]
// CHECK-SAME: [[TENSOR]]
// CHECK-SAME: -> [[TENSOR]]
%1 = "ttir.minimum"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
return %1 : tensor<64x128xf32>
}

func.func @multiply(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> {
// CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
%0 = tensor.empty() : tensor<64x128xf32>
