Skip to content

Commit

Permalink
Introduce a TTIR decomposition pass (#969)
Browse files Browse the repository at this point in the history
  • Loading branch information
jserbedzijaTT authored Oct 30, 2024
1 parent 22a06f2 commit 38a4a46
Show file tree
Hide file tree
Showing 17 changed files with 556 additions and 6 deletions.
1 change: 1 addition & 0 deletions include/ttmlir/Conversion/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "ttmlir/Conversion/ArithToStableHLO/ArithToStableHLO.h"
#include "ttmlir/Conversion/StableHLOToTTIR/StableHLOToTTIR.h"
#endif
#include "ttmlir/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.h"
#include "ttmlir/Conversion/TTIRToTTMetal/TTIRToTTMetal.h"
#include "ttmlir/Conversion/TTIRToTTNN/TTIRToTTNN.h"
#include "ttmlir/Conversion/TTKernelToEmitC/TTKernelToEmitC.h"
Expand Down
6 changes: 6 additions & 0 deletions include/ttmlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ def ConvertTosaToTTIR : Pass<"convert-tosa-to-ttir", "::mlir::ModuleOp"> {
let dependentDialects = ["mlir::tt::ttir::TTIRDialect"];
}

// Module-level pass that rewrites composite TTIR ops into sequences of
// simpler TTIR ops (e.g. ttir.index -> ttir.slice; see
// lib/Conversion/TTIRToTTIRDecomposition).
def TTIRToTTIRDecomposition: Pass<"ttir-to-ttir-decomposition", "::mlir::ModuleOp"> {
  let summary = "Decomposes TTIR operations into simpler TTIR operations.";
  // Factory declared in ttmlir/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.h.
  let constructor = "createTTIRToTTIRDecompositionPass()";
  let dependentDialects = ["mlir::tt::ttir::TTIRDialect"];
}

def ConvertTTIRToTTNN: Pass<"convert-ttir-to-ttnn", "::mlir::ModuleOp"> {
let summary = "Convert TTIR dialect to TTNN dialect.";
let constructor = "createConvertTTIRToTTNNPass()";
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_CONVERSION_TTIRTOTTIRDECOMPOSITION_TTIRTOTTIRDECOMPOSITION_H
#define TTMLIR_CONVERSION_TTIRTOTTIRDECOMPOSITION_TTIRTOTTIRDECOMPOSITION_H

#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

#include <memory>

namespace mlir::tt {

/// Populates `patterns` with the rewrite patterns that decompose composite
/// TTIR ops into simpler TTIR ops (e.g. ttir.index -> ttir.slice).
/// `typeConverter` is expected to be an identity conversion — the
/// decomposition does not change types.
void populateTTIRToTTIRDecompositionPatterns(MLIRContext *ctx,
                                             RewritePatternSet &patterns,
                                             TypeConverter &typeConverter);

/// Creates a ModuleOp pass that applies the decomposition patterns above.
std::unique_ptr<OperationPass<ModuleOp>> createTTIRToTTIRDecompositionPass();

} // namespace mlir::tt

#endif // TTMLIR_CONVERSION_TTIRTOTTIRDECOMPOSITION_TTIRTOTTIRDECOMPOSITION_H
30 changes: 28 additions & 2 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -653,8 +653,9 @@ def TTIR_ReshapeOp: TTIR_DPSOp<"reshape"> {
def TTIR_SliceOp: TTIR_DPSOp<"slice"> {
let summary = "Slice op.";
let description = [{
Extract a portion of a tensor based on the specified start (`begins`), stop (`ends`), and step
indices for each dimension.
Extract a sub-tensor (slice) from the input tensor across one or more dimensions.
The `begins`, `ends`, and `step` attributes specify the start, stop, and step indices
for each dimension of the tensor.
}];

let arguments = (ins AnyRankedTensor:$input,
Expand All @@ -673,6 +674,31 @@ def TTIR_SliceOp: TTIR_DPSOp<"slice"> {
let hasVerifier = 1;
}

// Single-dimension indexing op. Unlike ttir.slice (per-dimension arrays),
// ttir.index takes scalar begin/end/step attributes that apply only to `dim`.
// The TTIRToTTIRDecomposition pass lowers this op to ttir.slice.
def TTIR_IndexOp: TTIR_DPSOp<"index"> {
  let summary = "Index op.";
  let description = [{
    Extract a sub-tensor (slice) from the input tensor along a specified dimension.
    The `begin`, `end`, and `step` attributes define the start, stop, and step indices for the
    selected dimension (`dim`) of the tensor.
  }];

  let arguments = (ins AnyRankedTensor:$input,
                       AnyRankedTensor:$output,
                       I32Attr:$dim,
                       I32Attr:$begin,
                       I32Attr:$end,
                       I32Attr:$step,
                       TT_OperandConstraintArrayAttr:$operand_constraints);

  let results = (outs AnyRankedTensor:$result);

  // DPS interface: the `output` operand is the destination-passing-style init.
  let extraClassDeclaration = [{
    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
  }];

  // Custom verifier implemented in lib/Dialect/TTIR/IR/TTIROps.cpp.
  let hasVerifier = 1;
}

def TTIR_SqueezeOp : TTIR_DPSOp<"squeeze"> {
let summary = "Squeeze op.";
let description = [{
Expand Down
2 changes: 2 additions & 0 deletions lib/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
add_library(TTMLIRConversions INTERFACE)

add_subdirectory(TosaToTTIR)
add_subdirectory(TTIRToTTIRDecomposition)
add_subdirectory(TTNNToEmitC)
add_subdirectory(TTIRToTTNN)
add_subdirectory(TTIRToTTMetal)
Expand All @@ -14,6 +15,7 @@ include_directories(${TTMLIR_SOURCE_DIR}/include)

set(link_libs
TTMLIRTosaToTTIR
TTMLIRTTIRToTTIRDecomposition
TTMLIRTTNNToEmitC
TTMLIRTTIRToTTNN
TTMLIRTTIRToTTMetal
Expand Down
14 changes: 14 additions & 0 deletions lib/Conversion/TTIRToTTIRDecomposition/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
add_mlir_library(TTMLIRTTIRToTTIRDecomposition
  TTIRToTTIRDecomposition.cpp
  TTIRToTTIRDecompositionPass.cpp

  # Fixed: the public headers for this library live under
  # Conversion/TTIRToTTIRDecomposition (see the include guard in
  # TTIRToTTIRDecomposition.h), not Conversion/TTIRToTTIR.
  ADDITIONAL_HEADER_DIRS
  ${PROJECT_SOURCE_DIR}/include/ttmlir/Conversion/TTIRToTTIRDecomposition

  DEPENDS
  TTMLIRConversionPassIncGen

  LINK_LIBS PUBLIC
  MLIRIR
  MLIRPass
  )
83 changes: 83 additions & 0 deletions lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "ttmlir/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.h"

#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTIR/IR/TTIROps.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h"
#include "ttmlir/Dialect/TTNN/Utils/Utils.h"

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/LogicalResult.h"
#include <llvm/Support/raw_ostream.h>
#include <mlir/Support/LLVM.h>

using namespace mlir;
using namespace mlir::tt;

namespace mlir::tt {

// Decompose IndexOp into SliceOp
//
// ttir.index selects along a single dimension `dim` using scalar `begin`,
// `end`, and `step` attributes. ttir.slice expects one entry per dimension,
// so this pattern expands the scalars into full-rank arrays, filling every
// other dimension with the identity selection (0, dimSize, 1).
//
struct IndexToSliceConversionPattern
    : public OpConversionPattern<ttir::IndexOp> {
  using OpConversionPattern<ttir::IndexOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ttir::IndexOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto rankedInput =
        ::mlir::dyn_cast<mlir::RankedTensorType>(adaptor.getInput().getType());
    if (!rankedInput || !rankedInput.hasRank()) {
      return failure();
    }

    const int64_t rank = rankedInput.getRank();
    llvm::SmallVector<mlir::Attribute, 4> beginAttrs, endAttrs, stepAttrs;

    for (int64_t d = 0; d < rank; ++d) {
      // Default to selecting the whole dimension; override for the indexed
      // dimension with the op's scalar attributes.
      int32_t beginVal = 0;
      int32_t endVal = rankedInput.getDimSize(d);
      int32_t stepVal = 1;
      if (d == op.getDim()) {
        beginVal = adaptor.getBegin();
        endVal = adaptor.getEnd();
        stepVal = adaptor.getStep();
      }
      beginAttrs.push_back(rewriter.getI32IntegerAttr(beginVal));
      endAttrs.push_back(rewriter.getI32IntegerAttr(endVal));
      stepAttrs.push_back(rewriter.getI32IntegerAttr(stepVal));
    }

    rewriter.replaceOpWithNewOp<ttir::SliceOp>(
        op, op.getType(), adaptor.getInput(), adaptor.getOutput(),
        rewriter.getArrayAttr(beginAttrs), rewriter.getArrayAttr(endAttrs),
        rewriter.getArrayAttr(stepAttrs), adaptor.getOperandConstraints());
    return success();
  }
};

// Registers every TTIR decomposition pattern on `patterns`. Currently the
// only decomposition is ttir.index -> ttir.slice.
void populateTTIRToTTIRDecompositionPatterns(MLIRContext *ctx,
                                             RewritePatternSet &patterns,
                                             TypeConverter &typeConverter) {
  patterns.add<IndexToSliceConversionPattern>(typeConverter, ctx);
}

} // namespace mlir::tt
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "ttmlir/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.h"

#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "ttmlir/Dialect/TTIR/IR/TTIR.h"
#include "ttmlir/Dialect/TTIR/IR/TTIROps.h"
#include "ttmlir/Dialect/TTNN/IR/TTNN.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/Tensor/IR/Tensor.h>

using namespace mlir;
using namespace mlir::tt;

namespace mlir::tt::ttir {

// Pulls in the TableGen-generated pass base class
// (ttir::impl::TTIRToTTIRDecompositionBase) declared in Passes.td.
#define GEN_PASS_DEF_TTIRTOTTIRDECOMPOSITION
#include "ttmlir/Conversion/Passes.h.inc"

} // namespace mlir::tt::ttir

namespace {

// ModuleOp pass that runs the TTIR decomposition patterns via the dialect
// conversion framework until no illegal ops (currently only ttir.index)
// remain.
struct TTIRToTTIRDecompositionPass
    : public ttir::impl::TTIRToTTIRDecompositionBase<
          TTIRToTTIRDecompositionPass> {
  void runOnOperation() final {
    MLIRContext &ctx = getContext();

    // TTIR stays legal as a whole; only the ops we decompose are illegal,
    // which forces the conversion driver to rewrite them.
    mlir::ConversionTarget target(ctx);
    target.addLegalDialect<ttir::TTIRDialect>();
    target.addIllegalOp<ttir::IndexOp>();

    // The decomposition never changes types, so use an identity converter.
    TypeConverter identityConverter;
    identityConverter.addConversion([](Type type) { return type; });

    RewritePatternSet patterns(&ctx);
    populateTTIRToTTIRDecompositionPatterns(&ctx, patterns,
                                            identityConverter);

    // Partial conversion: ops not marked illegal are left untouched.
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns)))) {
      signalPassFailure();
    }
  }
};

} // namespace

namespace mlir::tt {

// Factory referenced by the `constructor` field of the pass definition in
// Passes.td; also the public entry point declared in
// TTIRToTTIRDecomposition.h.
std::unique_ptr<OperationPass<ModuleOp>> createTTIRToTTIRDecompositionPass() {
  return std::make_unique<TTIRToTTIRDecompositionPass>();
}

} // namespace mlir::tt
3 changes: 2 additions & 1 deletion lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,8 @@ class ToLayoutOpConversionPattern
private:
bool shouldForceRowMajor(ttir::ToLayoutOp op) const {
for (mlir::Operation *user : op.getResult().getUsers()) {
if (isa<ttir::Conv2dOp>(user) || isa<ttir::MaxPool2dOp>(user)) {
if (isa<ttir::Conv2dOp>(user) || isa<ttir::MaxPool2dOp>(user) ||
isa<ttir::SliceOp>(user)) {
return true;
}
}
Expand Down
110 changes: 110 additions & 0 deletions lib/Dialect/TTIR/IR/TTIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,116 @@ ::mlir::LogicalResult mlir::tt::ttir::SliceOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// IndexOp
//===----------------------------------------------------------------------===//

// IndexOp verification
//
// Checks, in order: input rank >= 1, matching element types and ranks between
// input and output, `dim` in bounds, `begin`/`end` in range (negative values
// are interpreted Python-style, relative to the end of the dimension),
// `step` consistency with the begin/end ordering, and finally that the
// output's size along `dim` matches the number of selected elements.
::mlir::LogicalResult mlir::tt::ttir::IndexOp::verify() {
  ::mlir::RankedTensorType inputType = getInput().getType();
  ::llvm::ArrayRef<int64_t> inputShape = inputType.getShape();
  ::mlir::RankedTensorType outputType = getOutput().getType();
  int32_t dim = getDim();
  int32_t begin = getBegin();
  int32_t end = getEnd();
  int32_t step = getStep();

  // Verify that the input is at least 1D tensor
  if (inputType.getRank() < 1) {
    return emitOpError("Input must be at least a 1D tensor");
  }

  // Validate that the output tensor has the same element type as the input
  // tensor
  if (inputType.getElementType() != outputType.getElementType()) {
    return emitOpError(
        "Output tensor must have the same element type as the input tensor");
  }

  // Verify the output tensor rank
  if (inputType.getRank() != outputType.getRank()) {
    return emitOpError(
        "Output tensor must have the same rank as the input tensor");
  }

  // Verify that the dim attribute is within the bounds of the input tensor
  if (dim < 0 || dim >= inputType.getRank()) {
    return emitOpError() << "Invalid dimension index " << dim
                         << ". Input tensor rank is " << inputType.getRank();
  }

  // Verify begin, end, step and the output tensor dimensions
  int64_t dimSize = inputShape[dim];

  // Adjust negative begin and end
  // Python-style negative indexing: -1 refers to the last element, etc.
  int32_t adjustedBegin = (begin < 0) ? (begin + dimSize) : begin;
  int32_t adjustedEnd = (end < 0) ? (end + dimSize) : end;

  // Pre-render the input shape as "(d0, d1, ...)" for use in the error
  // messages below.
  std::ostringstream inputShapeStream;
  inputShapeStream << "(";
  for (size_t i = 0; i < inputShape.size(); ++i) {
    inputShapeStream << inputShape[i];
    if (i != inputShape.size() - 1) {
      inputShapeStream << ", ";
    }
  }
  inputShapeStream << ")";
  std::string inputShapeStr = inputShapeStream.str();

  // `begin` must land inside the dimension: valid range is [-dimSize, dimSize).
  if (adjustedBegin < 0 || adjustedBegin >= dimSize) {
    return emitOpError() << "Invalid begin index for dimension "
                         << std::to_string(dim) << ". Expected value in range ["
                         << std::to_string(-dimSize) << ", " << dimSize
                         << "), got " << begin
                         << ". Input shape: " << inputShapeStr;
  }
  // `end` is exclusive, so dimSize itself is allowed: range is [-dimSize, dimSize].
  if (adjustedEnd < 0 || adjustedEnd > dimSize) {
    return emitOpError() << "Invalid end index for dimension "
                         << std::to_string(dim) << ". Expected value in range ["
                         << std::to_string(-dimSize) << ", " << dimSize
                         << "], got " << end
                         << ". Input shape: " << inputShapeStr;
  }

  // For negative user-supplied values, report both the adjusted and the
  // original value, e.g. "3 (-2)".
  auto formatValueMessage = [](int value, int adjustedValue) {
    return value < 0 ? std::to_string(adjustedValue) + " (" +
                           std::to_string(value) + ")"
                     : std::to_string(value);
  };
  std::string beginValueMessage = formatValueMessage(begin, adjustedBegin);
  std::string endValueMessage = formatValueMessage(end, adjustedEnd);

  // The sign of `step` dictates the required ordering of begin vs end.
  if (step == 0) {
    return emitOpError("Step value for dimension " + std::to_string(dim) +
                       " cannot be zero");
  } else if (step > 0 && adjustedBegin > adjustedEnd) {
    return emitOpError() << "For positive step, begin index must be less "
                            "than or equal to end index for dimension "
                         << dim << ". Got begin: " << beginValueMessage
                         << ", end: " << endValueMessage << ", step: " << step
                         << ", input shape: " << inputShapeStr;
  } else if (step < 0 && adjustedBegin < adjustedEnd) {
    return emitOpError() << "For negative step, begin index must be greater "
                            "than or equal to end index for dimension "
                         << dim << ". Got begin: " << beginValueMessage
                         << ", end: " << endValueMessage << ", step: " << step
                         << ", input shape: " << inputShapeStr;
  }

  // Calculate the expected size of the output dimension
  // ceil(|end - begin| / |step|), computed with integer arithmetic.
  int32_t expectedDimSize =
      (std::abs(adjustedEnd - adjustedBegin) + std::abs(step) - 1) /
      std::abs(step);
  if (outputType.getDimSize(dim) != expectedDimSize) {
    return emitOpError() << "Mismatch in dimension " << std::to_string(dim)
                         << " of the output tensor: expected size "
                         << expectedDimSize << ", but got "
                         << outputType.getDimSize(dim);
  }

  return success();
}

//===----------------------------------------------------------------------===//
// SqueezeOp
//===----------------------------------------------------------------------===//
Expand Down
Loading

0 comments on commit 38a4a46

Please sign in to comment.