Add E2E implementation of reduce minimum op along with StableHLO conversion.
mmanzoorTT committed Jan 16, 2025
1 parent 757ef0b commit 20bbd5c
Showing 14 changed files with 359 additions and 2 deletions.
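
The change threads the new minimum reduction through the full stack: TTIR and TTNN op definitions, the StableHLO-to-TTIR and TTIR-to-TTNN conversion patterns, the TTNN-to-EmitC patterns, the reduction keep-dim/all-dims workaround rewrites, flatbuffer serialization, runtime dispatch, and lit tests for both the StableHLO conversion and the TTNN backend pipeline.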
7 changes: 7 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -704,6 +704,13 @@ def TTIR_MaxOp : TTIR_ReductionOp<"max"> {
}];
}

def TTIR_MinOp : TTIR_ReductionOp<"min"> {
let summary = "Min reduction op.";
let description = [{
Min reduction op.
}];
}

def TTIR_EmbeddingOp : TTIR_DPSOp<"embedding"> {
let summary = "Embedding op.";
let description = [{
7 changes: 7 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -609,6 +609,13 @@ def TTNN_MaxOp : TTNN_ReductionOp<"max"> {
}];
}

def TTNN_MinOp : TTNN_ReductionOp<"min"> {
let summary = "Min reduction op.";
let description = [{
Min reduction op.
}];
}

def TTNN_EmbeddingOp : TTNN_NamedDPSOp<"embedding"> {
let summary = "Embedding op.";
let description = [{
1 change: 1 addition & 0 deletions include/ttmlir/Target/TTNN/program.fbs
@@ -168,6 +168,7 @@ enum ReductionOpType: uint32 {
Sum,
Mean,
Max,
Min,
}

table ReductionOp {
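The new enum value tags min reductions in the serialized ReductionOp table that the runtime dispatch below switches on.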
4 changes: 4 additions & 0 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
@@ -84,6 +84,10 @@ class StableHLOToTTIRReduceOpConversionPattern
return matchAndRewriteInternal<mlir::tt::ttir::MaxOp>(srcOp, adaptor,
rewriter);
}
if (mlir::isa<mlir::stablehlo::MinOp>(innerOp)) {
return matchAndRewriteInternal<mlir::tt::ttir::MinOp>(srcOp, adaptor,
rewriter);
}

return failure();
}
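Concretely, a stablehlo.reduce whose body applies stablehlo.minimum now lowers to ttir.min. A sketch of the mapping, with shapes taken from the reduce_min_op.mlir lit test added below; the lowered form shown is what that test's FileCheck lines expect:

    // StableHLO input:
    %0 = stablehlo.reduce(%arg0 init: %cst_0) applies stablehlo.minimum across dimensions = [1] : (tensor<128x10x32x4xf32>, tensor<f32>) -> tensor<128x32x4xf32>

    // Lowered TTIR (the pattern materializes a DPS output via tensor.empty):
    %empty = tensor.empty() : tensor<128x32x4xf32>
    %1 = "ttir.min"(%arg0, %empty) <{dim_arg = [1 : i32], keep_dim = false}> : (tensor<128x10x32x4xf32>, tensor<128x32x4xf32>) -> tensor<128x32x4xf32>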
1 change: 1 addition & 0 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -1317,6 +1317,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
ReductionOpConversionPattern<ttir::SumOp, ttnn::SumOp>,
ReductionOpConversionPattern<ttir::MeanOp, ttnn::MeanOp>,
ReductionOpConversionPattern<ttir::MaxOp, ttnn::MaxOp>,
ReductionOpConversionPattern<ttir::MinOp, ttnn::MinOp>,
ElementwiseUnaryWithFloatParameterOpConversionPattern<ttir::LeakyReluOp, ttnn::LeakyReluOp>,
BroadcastOpConversionPattern,
EmbeddingOpConversionPattern,
3 changes: 2 additions & 1 deletion lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
@@ -846,7 +846,8 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
//
patterns.add<DefaultOpConversionPattern<ttnn::SumOp>,
DefaultOpConversionPattern<ttnn::MeanOp>,
DefaultOpConversionPattern<ttnn::MaxOp>>(typeConverter, ctx);
DefaultOpConversionPattern<ttnn::MaxOp>,
DefaultOpConversionPattern<ttnn::MinOp>>(typeConverter, ctx);

// Conv ops
//
16 changes: 16 additions & 0 deletions lib/Dialect/TTIR/IR/TTIROps.cpp
@@ -2216,3 +2216,19 @@ void mlir::tt::ttir::SumOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,
::mlir::LogicalResult mlir::tt::ttir::SumOp::verify() {
return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
}

//===----------------------------------------------------------------------===//
// Reduce MinOp
//===----------------------------------------------------------------------===//

// MinOp kernel builder.
void mlir::tt::ttir::MinOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,
::mlir::Block *block) {
// NOLINTNEXTLINE
createReduceOp(opBuilder, block, getLoc(), "min");
}

// MinOp verification.
::mlir::LogicalResult mlir::tt::ttir::MinOp::verify() {
return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
}
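As with SumOp just above, both hooks delegate to the shared reduction helpers: buildGenericRegion emits the generic region through createReduceOp with the "min" kernel name, and verify reuses verifyReduceOp.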
9 changes: 9 additions & 0 deletions lib/Dialect/TTNN/IR/TTNNOps.cpp
@@ -1505,4 +1505,13 @@ ::mlir::LogicalResult SumOp::verify() {
return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
}

//===----------------------------------------------------------------------===//
// Reduce MinOp
//===----------------------------------------------------------------------===//

// MinOp verification.
::mlir::LogicalResult MinOp::verify() {
return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
}

} // namespace mlir::tt::ttnn
6 changes: 5 additions & 1 deletion lib/Dialect/TTNN/Transforms/Workarounds/TTNNWorkarounds.cpp
@@ -425,12 +425,16 @@ class TTNNWorkarounds : public impl::TTNNWorkaroundsBase<TTNNWorkarounds> {
ttnn::MaxOp>,
workarounds::decomposition::ReduceOpsKeepDimRewritePattern<
ttnn::MeanOp>,
workarounds::decomposition::ReduceOpsKeepDimRewritePattern<
ttnn::MinOp>,
workarounds::decomposition::ReduceOpsAllDimsRewritePattern<
ttnn::SumOp>,
workarounds::decomposition::ReduceOpsAllDimsRewritePattern<
ttnn::MaxOp>,
workarounds::decomposition::ReduceOpsAllDimsRewritePattern<
ttnn::MeanOp>>(&getContext());
ttnn::MeanOp>,
workarounds::decomposition::ReduceOpsAllDimsRewritePattern<
ttnn::MinOp>>(&getContext());

runRewritePatterns(std::move(patterns),
GreedyRewriteConfig::kNoLimit /*maxIterations*/);
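The two patterns registered here normalize reductions before lowering: ReduceOpsKeepDimRewritePattern rewrites a keep_dim = false reduction into a keep_dim = true reduction followed by a reshape, and ReduceOpsAllDimsRewritePattern drops dim_arg when every dimension is reduced. A sketch for ttnn.min, with TTNN tensor layout encodings elided and shapes taken from simple_reduce_min.mlir below:

    // Before:
    %1 = "ttnn.min"(%arg0) <{dim_arg = [1 : i32], keep_dim = false}> : (tensor<128x10x32x4xf32>) -> tensor<128x32x4xf32>

    // After:
    %1 = "ttnn.min"(%arg0) <{dim_arg = [1 : i32], keep_dim = true}> : (tensor<128x10x32x4xf32>) -> tensor<128x1x32x4xf32>
    %2 = "ttnn.reshape"(%1) <{shape = [128 : i32, 32 : i32, 4 : i32]}> : (tensor<128x1x32x4xf32>) -> tensor<128x32x4xf32>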
6 changes: 6 additions & 0 deletions lib/Target/TTNN/TTNNToFlatbuffer.cpp
@@ -776,6 +776,8 @@ createReductionOp(FlatbufferObjectCache &cache, ReductionOp op) {
type = ::tt::target::ttnn::ReductionOpType::Mean;
} else if constexpr (std::is_same_v<ReductionOp, MaxOp>) {
type = ::tt::target::ttnn::ReductionOpType::Max;
} else if constexpr (std::is_same_v<ReductionOp, MinOp>) {
type = ::tt::target::ttnn::ReductionOpType::Min;
} else {
llvm_unreachable("unhandled ReductionOp");
}
@@ -1150,6 +1152,10 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
return createOperation(cache, createReductionOp(cache, maxOp), debugString,
locInfo);
}
if (auto minOp = dyn_cast<MinOp>(op); minOp) {
return createOperation(cache, createReductionOp(cache, minOp), debugString,
locInfo);
}
if (auto embeddingOp = dyn_cast<EmbeddingOp>(op); embeddingOp) {
return createOperation(cache, createEmbeddingOp(cache, embeddingOp),
debugString, locInfo);
4 changes: 4 additions & 0 deletions runtime/lib/ttnn/operations/reduction/reduction.cpp
@@ -50,6 +50,10 @@ void run(const ::tt::target::ttnn::ReductionOp *op, ProgramContext &context) {
runReductionOp(op, tensorPool, ::ttnn::max);
break;
}
case ::tt::target::ttnn::ReductionOpType::Min: {
runReductionOp(op, tensorPool, ::ttnn::min);
break;
}
}
}
} // namespace tt::runtime::ttnn::operations::reduction
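The runtime dispatch mirrors the flatbuffer enum: each ReductionOpType case forwards the corresponding TTNN library callable to the shared runReductionOp helper, so Min hands off ::ttnn::min just as Max hands off ::ttnn::max.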
113 changes: 113 additions & 0 deletions test/ttmlir/Conversion/StableHLOToTTIR/reduction/reduce_min_op.mlir
@@ -0,0 +1,113 @@
// REQUIRES: stablehlo
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s
module @jit_reduce_minimum attributes {} {
func.func public @test_reduce_minimum_4to3dim(%arg0: tensor<128x10x32x4xf32>, %cst_0: tensor<f32>) -> tensor<128x32x4xf32> {
// CHECK: tensor.empty
// CHECK: "ttir.min"
// CHECK-SAME: dim_arg = [1 : i32]
// CHECK-SAME: keep_dim = false
// CHECK-SAME: tensor<128x10x32x4xf32>
// CHECK-SAME: -> tensor<128x32x4xf32>
%0 = stablehlo.reduce(%arg0 init: %cst_0) applies stablehlo.minimum across dimensions = [1] : (tensor<128x10x32x4xf32>, tensor<f32>) -> tensor<128x32x4xf32>
return %0 : tensor<128x32x4xf32>
}

func.func public @test_reduce_minimum_4to2dim(%arg0: tensor<128x10x32x4xf32>, %cst_0: tensor<f32>) -> tensor<128x32xf32> {
// CHECK: tensor.empty
// CHECK: "ttir.min"
// CHECK-SAME: dim_arg = [1 : i32, 3 : i32]
// CHECK-SAME: keep_dim = false
// CHECK-SAME: tensor<128x10x32x4xf32>
// CHECK-SAME: -> tensor<128x32xf32>
%0 = stablehlo.reduce(%arg0 init: %cst_0) applies stablehlo.minimum across dimensions = [1, 3] : (tensor<128x10x32x4xf32>, tensor<f32>) -> tensor<128x32xf32>
return %0 : tensor<128x32xf32>
}

func.func public @test_reduce_minimum_4to1dim(%arg0: tensor<128x10x32x4xf32>, %cst_0: tensor<f32>) -> tensor<128xf32> {
// CHECK: tensor.empty
// CHECK: "ttir.min"
// CHECK-SAME: dim_arg = [1 : i32, 2 : i32, 3 : i32]
// CHECK-SAME: keep_dim = false
// CHECK-SAME: tensor<128x10x32x4xf32>
// CHECK-SAME: -> tensor<128xf32>
%0 = stablehlo.reduce(%arg0 init: %cst_0) applies stablehlo.minimum across dimensions = [1, 2, 3] : (tensor<128x10x32x4xf32>, tensor<f32>) -> tensor<128xf32>
return %0 : tensor<128xf32>
}

func.func public @test_reduce_minimum_4to0dim(%arg0: tensor<128x10x32x4xf32>, %cst_0: tensor<f32>) -> tensor<f32> {
// CHECK: tensor.empty
// CHECK: "ttir.min"
// CHECK-SAME: dim_arg = [0 : i32, 1 : i32, 2 : i32, 3 : i32]
// CHECK-SAME: keep_dim = false
// CHECK-SAME: tensor<128x10x32x4xf32>
// CHECK-SAME: -> tensor<1xf32>
%0 = stablehlo.reduce(%arg0 init: %cst_0) applies stablehlo.minimum across dimensions = [0, 1, 2, 3] : (tensor<128x10x32x4xf32>, tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
}

func.func public @test_reduce_minimum_3to2dim(%arg0: tensor<128x10x4xf32>, %cst_0: tensor<f32>) -> tensor<128x4xf32> {
// CHECK: tensor.empty
// CHECK: "ttir.min"
// CHECK-SAME: dim_arg = [1 : i32]
// CHECK-SAME: keep_dim = false
// CHECK-SAME: tensor<128x10x4xf32>
// CHECK-SAME: -> tensor<128x4xf32>
%0 = stablehlo.reduce(%arg0 init: %cst_0) applies stablehlo.minimum across dimensions = [1] : (tensor<128x10x4xf32>, tensor<f32>) -> tensor<128x4xf32>
return %0 : tensor<128x4xf32>
}

func.func public @test_reduce_minimum_3to1dim(%arg0: tensor<128x10x4xf32>, %cst_0: tensor<f32>) -> tensor<128xf32> {
// CHECK: tensor.empty
// CHECK: "ttir.min"
// CHECK-SAME: dim_arg = [1 : i32, 2 : i32]
// CHECK-SAME: keep_dim = false
// CHECK-SAME: tensor<128x10x4xf32>
// CHECK-SAME: -> tensor<128xf32>
%0 = stablehlo.reduce(%arg0 init: %cst_0) applies stablehlo.minimum across dimensions = [1, 2] : (tensor<128x10x4xf32>, tensor<f32>) -> tensor<128xf32>
return %0 : tensor<128xf32>
}

func.func public @test_reduce_minimum_3to0dim(%arg0: tensor<128x10x4xf32>, %cst_0: tensor<f32>) -> tensor<f32> {
// CHECK: tensor.empty
// CHECK: "ttir.min"
// CHECK-SAME: dim_arg = [0 : i32, 1 : i32, 2 : i32]
// CHECK-SAME: keep_dim = false
// CHECK-SAME: tensor<128x10x4xf32>
// CHECK-SAME: -> tensor<1xf32>
%0 = stablehlo.reduce(%arg0 init: %cst_0) applies stablehlo.minimum across dimensions = [0, 1, 2] : (tensor<128x10x4xf32>, tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
}

func.func public @test_reduce_minimum_2to1dim(%arg0: tensor<128x10xf32>, %cst_0: tensor<f32>) -> tensor<128xf32> {
// CHECK: tensor.empty
// CHECK: "ttir.min"
// CHECK-SAME: dim_arg = [1 : i32]
// CHECK-SAME: keep_dim = false
// CHECK-SAME: tensor<128x10xf32>
// CHECK-SAME: -> tensor<128xf32>
%0 = stablehlo.reduce(%arg0 init: %cst_0) applies stablehlo.minimum across dimensions = [1] : (tensor<128x10xf32>, tensor<f32>) -> tensor<128xf32>
return %0 : tensor<128xf32>
}

func.func public @test_reduce_minimum_2to0dim(%arg0: tensor<128x10xf32>, %cst_0: tensor<f32>) -> tensor<f32> {
// CHECK: tensor.empty
// CHECK: "ttir.min"
// CHECK-SAME: dim_arg = [0 : i32, 1 : i32]
// CHECK-SAME: keep_dim = false
// CHECK-SAME: tensor<128x10xf32>
// CHECK-SAME: -> tensor<1xf32>
%0 = stablehlo.reduce(%arg0 init: %cst_0) applies stablehlo.minimum across dimensions = [0, 1] : (tensor<128x10xf32>, tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
}

func.func public @test_reduce_minimum_1to0dim(%arg0: tensor<128xf32>, %cst_0: tensor<f32>) -> tensor<f32> {
// CHECK: tensor.empty
// CHECK: "ttir.min"
// CHECK-SAME: dim_arg = [0 : i32]
// CHECK-SAME: keep_dim = false
// CHECK-SAME: tensor<128xf32>
// CHECK-SAME: -> tensor<1xf32>
%0 = stablehlo.reduce(%arg0 init: %cst_0) applies stablehlo.minimum across dimensions = [0] : (tensor<128xf32>, tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
}
}
80 changes: 80 additions & 0 deletions test/ttmlir/Dialect/TTNN/reduction/simple_reduce_min.mlir
@@ -0,0 +1,80 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s

module attributes {} {
func.func public @test_reduce_min_4to3dim(%arg0: tensor<128x10x32x4xf32>) -> tensor<128x32x4xf32> {
// CHECK-LABEL: func.func public @test_reduce_min_4to3dim
%0 = tensor.empty() : tensor<128x32x4xf32>
// CHECK: %[[MIN:[0-9]+]] = "ttnn.min"
// CHECK-SAME: dim_arg = [1 : i32]
// CHECK-SAME: keep_dim = true
// CHECK-SAME: (tensor<128x10x32x4xf32,
// CHECK-SAME: -> tensor<128x1x32x4xf32,
// CHECK: "ttnn.reshape"(%[[MIN]])
// CHECK-SAME: shape = [128 : i32, 32 : i32, 4 : i32]
// CHECK-SAME: tensor<128x1x32x4xf32,
// CHECK-SAME: -> tensor<128x32x4xf32
%1 = "ttir.min"(%arg0, %0) <{dim_arg = [1: i32], keep_dim = false}> : (tensor<128x10x32x4xf32>, tensor<128x32x4xf32>) -> tensor<128x32x4xf32>
return %1 : tensor<128x32x4xf32>
}

func.func public @test_reduce_min_4to0dim(%arg0: tensor<128x10x32x4xbf16>) -> tensor<1xbf16> {
// CHECK-LABEL: func.func public @test_reduce_min_4to0dim
%0 = tensor.empty() : tensor<1xbf16>
// CHECK-NOT: dim_arg = [1 : i32]
// CHECK: %[[MIN:[0-9]+]] = "ttnn.min"
// CHECK-SAME: keep_dim = true
// CHECK-SAME: (tensor<128x10x32x4xbf16,
// CHECK-SAME: -> tensor<1x1x1x1xbf16,
// CHECK: "ttnn.reshape"(%[[MIN]])
// CHECK-SAME: shape = [1 : i32]
// CHECK-SAME: tensor<1x1x1x1xbf16,
// CHECK-SAME: -> tensor<1xbf16
%1 = "ttir.min"(%arg0, %0) <{dim_arg = [0 : i32, 1 : i32, 2 : i32, 3 : i32], keep_dim = false}> : (tensor<128x10x32x4xbf16>, tensor<1xbf16>) -> tensor<1xbf16>
return %1 : tensor<1xbf16>
}

func.func public @test_reduce_min_3to2dim(%arg0: tensor<128x10x4xf32>) -> tensor<128x4xf32> {
// CHECK-LABEL: func.func public @test_reduce_min_3to2dim
%0 = tensor.empty() : tensor<128x4xf32>
// CHECK: %[[MIN:[0-9]+]] = "ttnn.min"
// CHECK-SAME: dim_arg = [1 : i32]
// CHECK-SAME: keep_dim = true
// CHECK-SAME: (tensor<128x10x4xf32,
// CHECK-SAME: -> tensor<128x1x4xf32,
// CHECK: "ttnn.reshape"(%[[MIN]])
// CHECK-SAME: shape = [128 : i32, 4 : i32]
// CHECK-SAME: tensor<128x1x4xf32,
// CHECK-SAME: -> tensor<128x4xf32
%1 = "ttir.min"(%arg0, %0) <{dim_arg = [1: i32], keep_dim = false}> : (tensor<128x10x4xf32>, tensor<128x4xf32>) -> tensor<128x4xf32>
return %1 : tensor<128x4xf32>
}

func.func public @test_reduce_min_3to0dim(%arg0: tensor<128x10x4xbf16>) -> tensor<1xbf16> {
// CHECK-LABEL: func.func public @test_reduce_min_3to0dim
%0 = tensor.empty() : tensor<1xbf16>
// CHECK-NOT: dim_arg = [1 : i32]
// CHECK: %[[MIN:[0-9]+]] = "ttnn.min"
// CHECK-SAME: keep_dim = true
// CHECK-SAME: (tensor<128x10x4xbf16,
// CHECK-SAME: -> tensor<1x1x1xbf16,
// CHECK: "ttnn.reshape"(%[[MIN]])
// CHECK-SAME: shape = [1 : i32]
// CHECK-SAME: tensor<1x1x1xbf16,
// CHECK-SAME: -> tensor<1xbf16
%1 = "ttir.min"(%arg0, %0) <{dim_arg = [0 : i32, 1 : i32, 2 : i32], keep_dim = false}> : (tensor<128x10x4xbf16>, tensor<1xbf16>) -> tensor<1xbf16>
return %1 : tensor<1xbf16>
}

func.func public @test_reduce_min_1to0dim(%arg0: tensor<128xbf16>) -> tensor<1xbf16> {
// CHECK-LABEL: func.func public @test_reduce_min_1to0dim
%0 = tensor.empty() : tensor<1xbf16>
// CHECK-NOT: dim_arg = [0 : i32]
// CHECK-NOT: ttnn.reshape
// CHECK: %[[MIN:[0-9]+]] = "ttnn.min"
// CHECK-SAME: keep_dim = true
// CHECK-SAME: (tensor<128xbf16,
// CHECK-SAME: -> tensor<1xbf16,
%1 = "ttir.min"(%arg0, %0) <{dim_arg = [0 : i32], keep_dim = false}> : (tensor<128xbf16>, tensor<1xbf16>) -> tensor<1xbf16>
return %1 : tensor<1xbf16>
}
}