Skip to content

Commit

Permalink
Add E2E implementation of reduce prod op along with StableHLO conversion (#1792)
Browse files Browse the repository at this point in the history
  • Loading branch information
mmanzoorTT authored Jan 23, 2025
1 parent 59d6008 commit d5b4d1a
Show file tree
Hide file tree
Showing 28 changed files with 584 additions and 5 deletions.
21 changes: 21 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -725,6 +725,27 @@ def TTIR_MinOp : TTIR_ReductionOp<"min"> {
}];
}

// Reduce-product op for the TTIR dialect. Operands and attributes (input,
// dim_arg, keep_dim) come from the shared TTIR_ReductionOp base class,
// mirroring the neighboring TTIR_MinOp definition.
def TTIR_ProdOp : TTIR_ReductionOp<"prod"> {
let summary = "Product reduction op.";
let description = [{
This op computes the product of all elements of the tensor (full product)
or along a specific dimension.

Example:
input: [[1, 2, 3],
[4, 5, 6]]

// Computing along dim 0
output: [4, 10, 18]

// Computing along dim 1
output: [6, 120]

// Computing full product
output: 720
}];
}

def TTIR_EmbeddingOp : TTIR_DPSOp<"embedding"> {
let summary = "Embedding op.";
let description = [{
Expand Down
31 changes: 31 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,37 @@ def TTNN_MinOp : TTNN_ReductionOp<"min"> {
}];
}

// Reduce-product op for the TTNN dialect. Unlike TTIR_ProdOp (which uses the
// reduction base class with an optional dim_arg array), this op flattens the
// reduction choice into an explicit all_dimensions flag plus a single
// required dim_arg, matching what the conversion pattern in TTIRToTTNN.cpp
// produces.
def TTNN_ProdOp : TTNN_Op<"prod"> {
let summary = "Product reduction op.";
let description = [{
This op computes the product of all elements of the tensor (full product)
or along a specific dimension.

Example:
input: [[1, 2, 3],
[4, 5, 6]]

// Computing along dim 0
output: [4, 10, 18]

// Computing along dim 1
output: [6, 120]

// Computing full product
output: 720
}];

// all_dimensions: full reduction over every dim when true.
// dim_arg: the single dimension to reduce along when all_dimensions is false.
// keep_dim: presumably retains the reduced dim with size 1 — TODO confirm.
// memory_config: optional output memory configuration.
let arguments = (ins AnyRankedTensor:$input,
BoolAttr:$all_dimensions,
BoolAttr:$keep_dim,
I64Attr:$dim_arg,
OptionalAttr<TTNN_MemoryConfigAttr>:$memory_config);

let results = (outs AnyRankedTensor:$result);

// Custom verifier implemented in TTNNOps.cpp (verifyReduceProdOp).
let hasVerifier = 1;
}

def TTNN_EmbeddingOp : TTNN_NamedDPSOp<"embedding"> {
let summary = "Embedding op.";
let description = [{
Expand Down
10 changes: 10 additions & 0 deletions include/ttmlir/Target/TTNN/program.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,15 @@ table ReductionOp {
keep_dim: bool;
}

// Serialized form of the TTNN reduce-product op (TTNN_ProdOp in TTNNOps.td):
// input/output tensor refs plus the op's attributes.
table ReductionProdOp {
// Input tensor reference.
in: tt.target.TensorRef;
// Output tensor reference.
out: tt.target.TensorRef;
// True for a full reduction over all dimensions.
all_dimensions: bool;
// Dimension to reduce along when all_dimensions is false.
dim_arg: int64;
// Whether the reduced dimension is kept (size 1) — TODO confirm semantics.
keep_dim: bool;
// Optional output memory configuration (0/absent when not set).
memcfg: tt.target.MemoryConfigDesc;
}

table EmbeddingOp {
input: tt.target.TensorRef;
weight: tt.target.TensorRef;
Expand Down Expand Up @@ -369,6 +378,7 @@ union OpType {
LinearOp,
MatmulOp,
ReductionOp,
ReductionProdOp,
EmbeddingOp,
EmbeddingBackwardOp,
RepeatInterleaveOp,
Expand Down
4 changes: 4 additions & 0 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ class StableHLOToTTIRReduceOpConversionPattern
return matchAndRewriteInternal<mlir::tt::ttir::MinOp>(srcOp, adaptor,
rewriter);
}
if (mlir::isa<mlir::stablehlo::MulOp>(innerOp)) {
return matchAndRewriteInternal<mlir::tt::ttir::ProdOp>(srcOp, adaptor,
rewriter);
}

return failure();
}
Expand Down
35 changes: 35 additions & 0 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,40 @@ class ReductionOpConversionPattern : public OpConversionPattern<TTIROpTy> {
}
};

// Converts ttir.prod into ttnn.prod.
//
// ttnn.prod only models a single reduction dimension or a full reduction over
// all dimensions, so a multi-dimension (but not all-dimension) reduction is
// rejected with a match failure.
class ReductionProdOpConversionPattern
    : public OpConversionPattern<ttir::ProdOp> {
public:
  using OpConversionPattern<ttir::ProdOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ttir::ProdOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    int64_t inputRank = op.getInput().getType().getRank();
    auto dimArg = op.getDimArg();
    // Treat an absent *or empty* dim_arg as "reduce over all dimensions".
    // The empty-array guard also prevents the out-of-bounds read of
    // dimArg->getValue()[0] below that the previous `dimArg ? ...` check
    // allowed when dim_arg was present but empty.
    bool hasDims = dimArg && !dimArg->empty();
    int64_t size = hasDims ? static_cast<int64_t>(dimArg->size()) : inputRank;

    // [TODO](mmanzoor) Decompose ttnn.prod op into multiple ttnn.prod to handle
    // reduction along multiple dimensions.
    // https://github.com/tenstorrent/tt-mlir/issues/1861
    if ((size > 1) && (size < inputRank)) {
      return rewriter.notifyMatchFailure(
          op, "tt-metal only supports reduce(prod) along one dimension or all "
              "dimensions.");
    }

    // Reducing over as many dims as the rank is a full reduction.
    bool allDimensions = (size == inputRank);
    // Single reduction dimension; defaults to 0 for the all-dimensions case.
    int64_t dimension =
        hasDims ? mlir::cast<mlir::IntegerAttr>(dimArg->getValue()[0]).getInt()
                : 0;

    rewriter.replaceOpWithNewOp<ttnn::ProdOp>(
        op, this->getTypeConverter()->convertType(op.getType()),
        adaptor.getInput(), allDimensions, adaptor.getKeepDim(), dimension,
        /*memoryConfig*/ nullptr);
    return success();
  }
};

class EmbeddingOpConversionPattern
: public OpConversionPattern<ttir::EmbeddingOp> {
public:
Expand Down Expand Up @@ -1332,6 +1366,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
ReductionOpConversionPattern<ttir::MeanOp, ttnn::MeanOp>,
ReductionOpConversionPattern<ttir::MaxOp, ttnn::MaxOp>,
ReductionOpConversionPattern<ttir::MinOp, ttnn::MinOp>,
ReductionProdOpConversionPattern,
ElementwiseUnaryWithFloatParameterOpConversionPattern<ttir::LeakyReluOp, ttnn::LeakyReluOp>,
BroadcastOpConversionPattern,
EmbeddingOpConversionPattern,
Expand Down
3 changes: 2 additions & 1 deletion lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -849,7 +849,8 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
patterns.add<DefaultOpConversionPattern<ttnn::SumOp>,
DefaultOpConversionPattern<ttnn::MeanOp>,
DefaultOpConversionPattern<ttnn::MaxOp>,
DefaultOpConversionPattern<ttnn::MinOp>>(typeConverter, ctx);
DefaultOpConversionPattern<ttnn::MinOp>,
DefaultOpConversionPattern<ttnn::ProdOp>>(typeConverter, ctx);

// Conv ops
//
Expand Down
16 changes: 16 additions & 0 deletions lib/Dialect/TTIR/IR/TTIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2276,3 +2276,19 @@ void mlir::tt::ttir::MinOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,
// MinOp verification: delegates to the shared reduce-op verifier, checking
// the input tensor type against the optional dim_arg attribute.
::mlir::LogicalResult mlir::tt::ttir::MinOp::verify() {
  return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
}

//===----------------------------------------------------------------------===//
// Reduce ProdOp
//===----------------------------------------------------------------------===//

// ProdOp kernel builder: emits the generic-region body for a "prod"
// reduction via the shared createReduceOp helper.
void mlir::tt::ttir::ProdOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,
                                                ::mlir::Block *block) {
  // NOLINTNEXTLINE
  createReduceOp(opBuilder, block, getLoc(), "prod");
}

// ProdOp verification: reuses the shared reduce-op verifier, same as the
// other TTIR reduction ops (e.g. MinOp above).
::mlir::LogicalResult mlir::tt::ttir::ProdOp::verify() {
  return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
}
32 changes: 32 additions & 0 deletions lib/Dialect/TTNN/IR/TTNNOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1546,6 +1546,28 @@ verifyReduceOp(mlir::Operation *reduceOp, mlir::RankedTensorType inputType,
return mlir::success();
}

// Verifier for the TTNN reduce-product op.
//
// Rejects inputs of rank greater than 4, and rejects full (all-dimension)
// reductions whose element type is not bfloat16 — both restrictions stem
// from what the tt-metal backend currently supports.
static mlir::LogicalResult verifyReduceProdOp(mlir::Operation *reduceOp,
                                              mlir::RankedTensorType inputType,
                                              bool allDimensions) {
  if (inputType.getRank() > 4) {
    return reduceOp->emitOpError(
        "Input tensor rank is greater than 4 for reduce(product).");
  }

  // [TODO](mmanzoor) Add workaround to typecast the input tensor to bfloat16
  // then typecast the output again to match the requirements.
  // https://github.com/tenstorrent/tt-mlir/issues/1864
  if (allDimensions && !inputType.getElementType().isBF16()) {
    return reduceOp->emitOpError("TTNN only supports Reduce(prod) along all "
                                 "dimensions for bfloat16 datatype.");
  }

  return mlir::success();
}

//===----------------------------------------------------------------------===//
// MaxOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1582,4 +1604,14 @@ ::mlir::LogicalResult MinOp::verify() {
return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
}

//===----------------------------------------------------------------------===//
// Reduce ProdOp
//===----------------------------------------------------------------------===//

// ProdOp verification: checks rank and element-type restrictions via
// verifyReduceProdOp (defined above in this file).
::mlir::LogicalResult ProdOp::verify() {
  return verifyReduceProdOp(getOperation(), getInput().getType(),
                            getAllDimensions());
}

} // namespace mlir::tt::ttnn
25 changes: 23 additions & 2 deletions lib/Target/TTNN/TTNNToFlatbuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -787,11 +787,28 @@ createReductionOp(FlatbufferObjectCache &cache, ReductionOp op) {
cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getInput()));
auto output = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer,
kHostAllocatedAddress, kHostAllocatedSize);
auto dim_arg =
auto dimArg =
arrayAttrToFlatbuffer<mlir::IntegerAttr, int>(cache, op.getDimArg());

return ::tt::target::ttnn::CreateReductionOp(*cache.fbb, type, in, output,
dim_arg, op.getKeepDim());
dimArg, op.getKeepDim());
}

// Serializes a TTNN reduce-product op into a flatbuffer ReductionProdOp
// table: input/output tensor refs, the op's attributes, and the optional
// memory config.
template <typename ReductionOp>
::flatbuffers::Offset<::tt::target::ttnn::ReductionProdOp>
createReductionProdOp(FlatbufferObjectCache &cache, ReductionOp op) {
  // Input tensor ref, resolved through any DPS op chain.
  auto input =
      cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getInput()));
  // Output tensor is recorded as host-allocated.
  auto result = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer,
                                  kHostAllocatedAddress, kHostAllocatedSize);
  // Optional memory config; offset 0 encodes "absent".
  auto memcfg = op.getMemoryConfig()
                    ? cache.getOrCreate(*op.getMemoryConfig(),
                                        memoryConfigToFlatbuffer)
                    : 0;

  return ::tt::target::ttnn::CreateReductionProdOp(
      *cache.fbb, input, result, op.getAllDimensions(), op.getDimArg(),
      op.getKeepDim(), memcfg);
}

::flatbuffers::Offset<::tt::target::ttnn::TransposeOp>
Expand Down Expand Up @@ -1173,6 +1190,10 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
return createOperation(cache, createReductionOp(cache, minOp), debugString,
locInfo);
}
if (auto prodOp = dyn_cast<ProdOp>(op); prodOp) {
return createOperation(cache, createReductionProdOp(cache, prodOp),
debugString, locInfo);
}
if (auto embeddingOp = dyn_cast<EmbeddingOp>(op); embeddingOp) {
return createOperation(cache, createEmbeddingOp(cache, embeddingOp),
debugString, locInfo);
Expand Down
1 change: 1 addition & 0 deletions runtime/include/tt/runtime/detail/ttnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include "ttnn/operations/normalization/softmax/softmax.hpp"
#include "ttnn/operations/pool/generic/generic_pools.hpp"
#include "ttnn/operations/reduction/generic/generic_reductions.hpp"
#include "ttnn/operations/reduction/prod/prod.hpp"
#include "ttnn/tensor/host_buffer/functions.hpp"
#include "ttnn/tensor/host_buffer/owned_buffer.hpp"
#include "ttnn/tensor/shape/shape.hpp"
Expand Down
1 change: 1 addition & 0 deletions runtime/lib/ttnn/operations/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ set(TTNN_OPS_SRCS
# ANCHOR_END: adding_an_op_matmul_runtime_cmake
${CMAKE_CURRENT_SOURCE_DIR}/normalization/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/pool/maxpool2d.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduction/prod.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduction/reduction.cpp
${CMAKE_CURRENT_SOURCE_DIR}/context/get_device.cpp
)
Expand Down
35 changes: 35 additions & 0 deletions runtime/lib/ttnn/operations/reduction/prod.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "operations/reduction/prod.h"
#include "tt/runtime/detail/logger.h"
#include "tt/runtime/detail/ttnn.h"
#include "tt/runtime/ttnn/operations/utils.h"
#include "tt/runtime/ttnn/utils.h"

namespace tt::runtime::ttnn::operations::reduction {
// Executes a deserialized ReductionProdOp by calling ::ttnn::prod on the
// pooled input tensor and storing the result under the op's output id.
static void runReductionProdOp(::tt::target::ttnn::ReductionProdOp const *op,
                               ProgramTensorPool &tensorPool) {
  const ::ttnn::Tensor &input = tensorPool.at(op->in()->global_id());
  DEBUG_ASSERT(input.is_allocated());

  // Rebuild the optional memory config from the flatbuffer, when present.
  std::optional<::tt::tt_metal::MemoryConfig> memoryConfig = std::nullopt;
  if (op->memcfg()) {
    memoryConfig = utils::createMemoryConfig(op->memcfg(), op->out());
  }

  ::ttnn::Tensor result =
      ::ttnn::prod(input, op->all_dimensions(), op->dim_arg(), op->keep_dim(),
                   memoryConfig /* memory_config_arg */);

  tensorPool.insert_or_assign(op->out()->global_id(), result);
}

// Runtime entry point for ReductionProdOp: pulls the tensor pool out of the
// program context and delegates to the static implementation above.
void run(const ::tt::target::ttnn::ReductionProdOp *op,
         ProgramContext &context) {
  ProgramTensorPool &tensorPool = context.getTensorPool();
  runReductionProdOp(op, tensorPool);
}
} // namespace tt::runtime::ttnn::operations::reduction
16 changes: 16 additions & 0 deletions runtime/lib/ttnn/operations/reduction/prod.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef RUNTIME_LIB_TTNN_OPERATIONS_REDUCTION_PROD_H
#define RUNTIME_LIB_TTNN_OPERATIONS_REDUCTION_PROD_H

#include "tt/runtime/ttnn/types.h"
#include "ttmlir/Target/TTNN/program_generated.h"

namespace tt::runtime::ttnn::operations::reduction {
// Executes a serialized reduce-product op against the given program context.
void run(const ::tt::target::ttnn::ReductionProdOp *op,
         ProgramContext &context);
} // namespace tt::runtime::ttnn::operations::reduction

#endif
4 changes: 4 additions & 0 deletions runtime/lib/ttnn/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include "operations/matmul/matmul.h"
#include "operations/normalization/softmax.h"
#include "operations/pool/maxpool2d.h"
#include "operations/reduction/prod.h"
#include "operations/reduction/reduction.h"
#include "tt/runtime/detail/debug.h"
#include "tt/runtime/detail/logger.h"
Expand Down Expand Up @@ -189,6 +190,9 @@ void ProgramExecutor::runOperation(const ::tt::target::ttnn::Operation *op) {
return operations::matmul::run(op->type_as_MatmulOp(), context);
}
// ANCHOR_END: adding_an_op_matmul_runtime_program
case ::tt::target::ttnn::OpType::ReductionProdOp: {
return operations::reduction::run(op->type_as_ReductionProdOp(), context);
}
case ::tt::target::ttnn::OpType::ReductionOp: {
return operations::reduction::run(op->type_as_ReductionOp(), context);
}
Expand Down
Loading

0 comments on commit d5b4d1a

Please sign in to comment.