diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
index f17bbd018..99cae3ed2 100644
--- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
+++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -217,6 +217,13 @@ def TTIR_SumOp : TTIR_ReductionOp<"sum"> {
   }];
 }
 
+def TTIR_MeanOp : TTIR_ReductionOp<"mean"> {
+  let summary = "Mean reduction op.";
+  let description = [{
+    Mean reduction op.
+  }];
+}
+
 def TTIR_SoftmaxOp : TTIR_DPSOp<"softmax"> {
   let summary = "Softmax operation.";
   let description = [{
diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
index 4f060c67d..3457dca6d 100644
--- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
+++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -124,6 +124,13 @@ def TTNN_SumOp : TTNN_ReductionOp<"sum"> {
   }];
 }
 
+def TTNN_MeanOp : TTNN_ReductionOp<"mean"> {
+  let summary = "Mean reduction op.";
+  let description = [{
+    Mean reduction op.
+  }];
+}
+
 def TTNN_ReluOp : TTNN_ElementwiseOp<"relu"> {
   let summary = "Eltwise ReLU.";
   let description = [{
diff --git a/include/ttmlir/Target/TTNN/program.fbs b/include/ttmlir/Target/TTNN/program.fbs
index a2251c1cb..b9ab19242 100644
--- a/include/ttmlir/Target/TTNN/program.fbs
+++ b/include/ttmlir/Target/TTNN/program.fbs
@@ -40,6 +40,7 @@ table EltwiseOp {
 
 enum ReductionOpType: uint32 {
   Sum = 0,
+  Mean = 1,
 }
 
 table ReductionOp {
diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
index 9501ca7f2..831c00c7c 100644
--- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
+++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -144,6 +144,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
            ElementwiseBinaryOpConversionPattern,
            ElementwiseBinaryOpConversionPattern,
            ReductionOpConversionPattern<ttir::SumOp, ttnn::SumOp>,
+           ReductionOpConversionPattern<ttir::MeanOp, ttnn::MeanOp>,
            SoftmaxOpConversionPattern,
            MatmulOpConversionPattern
            >(typeConverter, ctx);
diff --git a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
index d0fe6c9d1..ddeb81633 100644
--- a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
+++ b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
@@ -119,6 +119,7 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
   // Reduction ops
   //
   patterns.add<DefaultOpConversionPattern<ttnn::SumOp>>(typeConverter, ctx);
+  patterns.add<DefaultOpConversionPattern<ttnn::MeanOp>>(typeConverter, ctx);
 }
 
 } // namespace mlir::tt
diff --git a/lib/Dialect/TTIR/Transforms/Passes.cpp b/lib/Dialect/TTIR/Transforms/Passes.cpp
index 836108ca4..b3e5d754e 100644
--- a/lib/Dialect/TTIR/Transforms/Passes.cpp
+++ b/lib/Dialect/TTIR/Transforms/Passes.cpp
@@ -535,6 +535,7 @@ class TTIRLayout : public impl::TTIRLayoutBase<TTIRLayout> {
         TTIRLayoutOperandsRewriter,
         TTIRLayoutOperandsRewriter, TTIRLayoutOperandsRewriter,
         TTIRLayoutOperandsRewriter<SumOp>,
+        TTIRLayoutOperandsRewriter<MeanOp>,
        TTIRLayoutOperandsRewriter, TTIRLayoutOperandsRewriter,
         TTIRLayoutFuncReturnRewriter>(
         &getContext());
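The TTNNToFlatbuffer.cpp change below extends a templated createReductionOp helper that picks the flatbuffer enum tag from the C++ op type at compile time. A minimal standalone sketch of that if-constexpr dispatch, assuming stand-in op structs rather than the real TTNN dialect classes (the real code falls back to llvm_unreachable; a static_assert is used here so the sketch has no LLVM dependency):

#include <cstdint>
#include <type_traits>

// Stand-ins for the TTNN dialect ops; the real types come from TTNNOps.td.
struct SumOp {};
struct MeanOp {};

// Mirrors the flatbuffer enum in program.fbs.
enum class ReductionOpType : uint32_t { Sum = 0, Mean = 1 };

// Resolve the serialized tag from the op type at compile time; adding a new
// reduction op means adding one more `else if constexpr` branch.
template <typename ReductionOp>
constexpr ReductionOpType reductionOpType() {
  if constexpr (std::is_same_v<ReductionOp, SumOp>) {
    return ReductionOpType::Sum;
  } else {
    static_assert(std::is_same_v<ReductionOp, MeanOp>, "unhandled ReductionOp");
    return ReductionOpType::Mean;
  }
}

static_assert(reductionOpType<MeanOp>() == ReductionOpType::Mean);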
diff --git a/lib/Target/TTNN/TTNNToFlatbuffer.cpp b/lib/Target/TTNN/TTNNToFlatbuffer.cpp
index 5660d8607..9d7376cb1 100644
--- a/lib/Target/TTNN/TTNNToFlatbuffer.cpp
+++ b/lib/Target/TTNN/TTNNToFlatbuffer.cpp
@@ -142,6 +142,8 @@ createReductionOp(FlatbufferObjectCache &cache, ReductionOp op) {
   ::tt::target::ttnn::ReductionOpType type;
   if constexpr (std::is_same_v<ReductionOp, SumOp>) {
     type = ::tt::target::ttnn::ReductionOpType::Sum;
+  } else if constexpr (std::is_same_v<ReductionOp, MeanOp>) {
+    type = ::tt::target::ttnn::ReductionOpType::Mean;
   } else {
     llvm_unreachable("unhandled ReductionOp");
   }
@@ -209,6 +211,10 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
   if (auto sumOp = dyn_cast<SumOp>(op); sumOp) {
     return createOperation(cache, createReductionOp(cache, sumOp), debugString);
   }
+  if (auto meanOp = dyn_cast<MeanOp>(op); meanOp) {
+    return createOperation(cache, createReductionOp(cache, meanOp),
+                           debugString);
+  }
   if (auto softmaxOp = dyn_cast<SoftmaxOp>(op); softmaxOp) {
     return createOperation(cache, createSoftmaxOp(cache, softmaxOp),
                            debugString);
diff --git a/runtime/lib/ttnn/program.cpp b/runtime/lib/ttnn/program.cpp
index 0ee4cc88e..59ce43f54 100644
--- a/runtime/lib/ttnn/program.cpp
+++ b/runtime/lib/ttnn/program.cpp
@@ -247,6 +247,20 @@ run(::tt::target::ttnn::ReductionOp const *op, ::ttnn::Device &device,
     liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back());
     break;
   }
+  case ::tt::target::ttnn::ReductionOpType::Mean: {
+    auto &in = *liveTensors.at(op->in()->global_id());
+
+    const auto *dim_arg_fb_ptr = op->dim_arg();
+    std::optional<std::vector<int>> dim_arg =
+        dim_arg_fb_ptr ? std::make_optional(std::vector<int>(
+                             dim_arg_fb_ptr->begin(), dim_arg_fb_ptr->end()))
+                       : std::nullopt;
+
+    tensorPool.push_back(::ttnn::mean(in, dim_arg, op->keep_dim()));
+
+    liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back());
+    break;
+  }
   }
 }
diff --git a/test/ttmlir/Dialect/TTNN/simple_mean.mlir b/test/ttmlir/Dialect/TTNN/simple_mean.mlir
new file mode 100644
index 000000000..fe2f586af
--- /dev/null
+++ b/test/ttmlir/Dialect/TTNN/simple_mean.mlir
@@ -0,0 +1,15 @@
+// RUN: ttmlir-opt --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
+#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
+module attributes {tt.system_desc = #tt.system_desc<[{arch = <wormhole_b0>, grid = 8x8, l1_size = 1048576, num_dram_channels = 12, dram_channel_size = 1048576, noc_l1_address_align_bytes = 16, pcie_address_align_bytes = 32, noc_dram_address_align_bytes = 32}], [0], [<pcie|host_mmio>], [<0, 0, 0, 0>]>} {
+  func.func @forward(%arg0: tensor<512x1024xbf16>) -> tensor<512x32xbf16> {
+    // CHECK: %[[C:.*]] = "ttnn.open_device"[[C:.*]]
+    // CHECK: %[[C:.*]] = "ttnn.full"[[C:.*]]
+    %0 = tensor.empty() : tensor<512x32xbf16>
+    // CHECK: %[[C:.*]] = "ttnn.to_memory_config"[[C:.*]]
+    // CHECK: %[[C:.*]] = "ttnn.mean"[[C:.*]]
+    %1 = "ttir.mean"(%arg0, %0) <{dim_arg = [-1: i32], keep_dim = true, operand_constraints = [#any_device, #any_device]}> : (tensor<512x1024xbf16>, tensor<512x32xbf16>) -> tensor<512x32xbf16>
+    // CHECK: %[[C:.*]] = "ttnn.to_memory_config"[[C:.*]]
+    // CHECK: "ttnn.close_device"[[C:.*]]
+    return %1 : tensor<512x32xbf16>
+  }
+}
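For reference on the runtime hunk above: ::ttnn::mean takes an optional list of dimensions, so a null dim_arg vector in the flatbuffer maps to std::nullopt (reduce over all dimensions) rather than an empty vector. A self-contained sketch of that conversion, assuming a hypothetical helper name toDimArg and an int element type (the real change inlines this logic in the Mean case):

#include <optional>
#include <vector>

// Hypothetical helper mirroring the dim_arg handling in program.cpp above.
// Accepts any container exposing begin()/end(), e.g. a flatbuffers vector.
template <typename FbVector>
std::optional<std::vector<int>> toDimArg(const FbVector *fbVec) {
  if (fbVec == nullptr) {
    return std::nullopt; // no dims serialized: reduce over all dimensions
  }
  return std::vector<int>(fbVec->begin(), fbVec->end());
}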