Skip to content

Commit

Permalink
Implement reduction mean op end to end (#285)
Browse files Browse the repository at this point in the history
  • Loading branch information
dgolubovicTT authored Aug 5, 2024
1 parent 4684283 commit f818088
Show file tree
Hide file tree
Showing 9 changed files with 53 additions and 0 deletions.
7 changes: 7 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,13 @@ def TTIR_SumOp : TTIR_ReductionOp<"sum"> {
}];
}

def TTIR_MeanOp : TTIR_ReductionOp<"mean"> {
let summary = "Mean reduction op.";
let description = [{
Mean reduction op.
}];
}

def TTIR_SoftmaxOp : TTIR_DPSOp<"softmax"> {
let summary = "Softmax operation.";
let description = [{
Expand Down
7 changes: 7 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,13 @@ def TTNN_SumOp : TTNN_ReductionOp<"sum"> {
}];
}

def TTNN_MeanOp : TTNN_ReductionOp<"mean"> {
let summary = "Mean reduction op.";
let description = [{
Mean reduction op.
}];
}

def TTNN_ReluOp : TTNN_ElementwiseOp<"relu"> {
let summary = "Eltwise ReLU.";
let description = [{
Expand Down
1 change: 1 addition & 0 deletions include/ttmlir/Target/TTNN/program.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ table EltwiseOp {

enum ReductionOpType: uint32 {
Sum = 0,
Mean = 1,
}

table ReductionOp {
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
ElementwiseBinaryOpConversionPattern<ttir::GreaterEqualOp, ttnn::GreaterEqualOp>,
ElementwiseBinaryOpConversionPattern<ttir::ReluOp, ttnn::ReluOp>,
ReductionOpConversionPattern<ttir::SumOp, ttnn::SumOp>,
ReductionOpConversionPattern<ttir::MeanOp, ttnn::MeanOp>,
SoftmaxOpConversionPattern,
MatmulOpConversionPattern
>(typeConverter, ctx);
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
// Reduction ops
//
patterns.add<DefaultOpConversionPattern<ttnn::SumOp>>(typeConverter, ctx);
patterns.add<DefaultOpConversionPattern<ttnn::MeanOp>>(typeConverter, ctx);
}

} // namespace mlir::tt
1 change: 1 addition & 0 deletions lib/Dialect/TTIR/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,7 @@ class TTIRLayout : public impl::TTIRLayoutBase<TTIRLayout> {
TTIRLayoutOperandsRewriter<SubtractOp>,
TTIRLayoutOperandsRewriter<GreaterEqualOp>,
TTIRLayoutOperandsRewriter<ReluOp>, TTIRLayoutOperandsRewriter<SumOp>,
TTIRLayoutOperandsRewriter<MeanOp>,
TTIRLayoutOperandsRewriter<SoftmaxOp>,
TTIRLayoutOperandsRewriter<MatmulOp>, TTIRLayoutFuncReturnRewriter>(
&getContext());
Expand Down
6 changes: 6 additions & 0 deletions lib/Target/TTNN/TTNNToFlatbuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@ createReductionOp(FlatbufferObjectCache &cache, ReductionOp op) {
::tt::target::ttnn::ReductionOpType type;
if constexpr (std::is_same_v<ReductionOp, SumOp>) {
type = ::tt::target::ttnn::ReductionOpType::Sum;
} else if constexpr (std::is_same_v<ReductionOp, MeanOp>) {
type = ::tt::target::ttnn::ReductionOpType::Mean;
} else {
llvm_unreachable("unhandled ReductionOp");
}
Expand Down Expand Up @@ -209,6 +211,10 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
if (auto sumOp = dyn_cast<SumOp>(op); sumOp) {
return createOperation(cache, createReductionOp(cache, sumOp), debugString);
}
if (auto meanOp = dyn_cast<MeanOp>(op); meanOp) {
return createOperation(cache, createReductionOp(cache, meanOp),
debugString);
}
if (auto softmaxOp = dyn_cast<SoftmaxOp>(op); softmaxOp) {
return createOperation(cache, createSoftmaxOp(cache, softmaxOp),
debugString);
Expand Down
14 changes: 14 additions & 0 deletions runtime/lib/ttnn/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,20 @@ run(::tt::target::ttnn::ReductionOp const *op, ::ttnn::Device &device,
liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back());
break;
}
case ::tt::target::ttnn::ReductionOpType::Mean: {
auto &in = *liveTensors.at(op->in()->global_id());

const auto *dim_arg_fb_ptr = op->dim_arg();
std::optional<vector<int>> dim_arg =
dim_arg_fb_ptr ? std::make_optional(std::vector<int>(
dim_arg_fb_ptr->begin(), dim_arg_fb_ptr->end()))
: std::nullopt;

tensorPool.push_back(::ttnn::mean(in, dim_arg, op->keep_dim()));

liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back());
break;
}
}
}

Expand Down
15 changes: 15 additions & 0 deletions test/ttmlir/Dialect/TTNN/simple_mean.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// RUN: ttmlir-opt --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
#any_device = #tt.operand_constraint<dram|l1|tile|any_device|any_device_tile>
module attributes {tt.system_desc = #tt.system_desc<[{arch = <wormhole_b0>, grid = 8x8, l1_size = 1048576, num_dram_channels = 12, dram_channel_size = 1048576, noc_l1_address_align_bytes = 16, pcie_address_align_bytes = 32, noc_dram_address_align_bytes = 32}], [0], [<pcie|host_mmio>], [<0, 0, 0, 0>]>} {
func.func @forward(%arg0: tensor<512x1024xbf16>) -> tensor<512x32xbf16> {
// CHECK: %[[C:.*]] = "ttnn.open_device"[[C:.*]]
// CHECK: %[[C:.*]] = "ttnn.full"[[C:.*]]
%0 = tensor.empty() : tensor<512x32xbf16>
// CHECK: %[[C:.*]] = "ttnn.to_memory_config"[[C:.*]]
// CHECK: %[[C:.*]] = "ttnn.mean"[[C:.*]]
%1 = "ttir.mean"(%arg0, %0) <{dim_arg = [-1: i32], keep_dim = true, operand_constraints = [#any_device, #any_device]}> : (tensor<512x1024xbf16>, tensor<512x32xbf16>) -> tensor<512x32xbf16>
// CHECK: %[[C:.*]] = "ttnn.to_memory_config"[[C:.*]]
// CHECK: "ttnn.close_device"[[C:.*]]
return %1 : tensor<512x32xbf16>
}
}

0 comments on commit f818088

Please sign in to comment.