-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Introduce a TTIR decomposition pass (#969)
- Loading branch information
1 parent
22a06f2
commit 38a4a46
Showing
17 changed files
with
556 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
21 changes: 21 additions & 0 deletions
21
include/ttmlir/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#ifndef TTMLIR_CONVERSION_TTIRTOTTIRDECOMPOSITION_TTIRTOTTIRDECOMPOSITION_H | ||
#define TTMLIR_CONVERSION_TTIRTOTTIRDECOMPOSITION_TTIRTOTTIRDECOMPOSITION_H | ||
|
||
#include "mlir/Pass/Pass.h" | ||
#include "mlir/Transforms/DialectConversion.h" | ||
|
||
namespace mlir::tt { | ||
|
||
void populateTTIRToTTIRDecompositionPatterns(MLIRContext *ctx, | ||
RewritePatternSet &patterns, | ||
TypeConverter &typeConverter); | ||
|
||
std::unique_ptr<OperationPass<ModuleOp>> createTTIRToTTIRDecompositionPass(); | ||
|
||
} // namespace mlir::tt | ||
|
||
#endif // TTMLIR_CONVERSION_TTIRTOTTIRDECOMPOSITION_TTIRTOTTIRDECOMPOSITION_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
add_mlir_library(TTMLIRTTIRToTTIRDecomposition | ||
TTIRToTTIRDecomposition.cpp | ||
TTIRToTTIRDecompositionPass.cpp | ||
|
||
ADDITIONAL_HEADER_DIRS | ||
${PROJECT_SOURCE_DIR}/include/ttmlir/Conversion/TTIRToTTIR | ||
|
||
DEPENDS | ||
TTMLIRConversionPassIncGen | ||
|
||
LINK_LIBS PUBLIC | ||
MLIRIR | ||
MLIRPass | ||
) |
83 changes: 83 additions & 0 deletions
83
lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include "ttmlir/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.h" | ||
|
||
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" | ||
#include "ttmlir/Dialect/TTIR/IR/TTIROps.h" | ||
#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h" | ||
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" | ||
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h" | ||
#include "ttmlir/Dialect/TTNN/Utils/Utils.h" | ||
|
||
#include "mlir/Dialect/MemRef/IR/MemRef.h" | ||
#include "mlir/Dialect/Tensor/IR/Tensor.h" | ||
#include "mlir/Dialect/Traits.h" | ||
#include "mlir/IR/BuiltinTypes.h" | ||
#include "mlir/IR/Types.h" | ||
#include "mlir/IR/ValueRange.h" | ||
#include "mlir/Support/LogicalResult.h" | ||
#include "mlir/Transforms/DialectConversion.h" | ||
#include "llvm/Support/Casting.h" | ||
#include "llvm/Support/ErrorHandling.h" | ||
#include "llvm/Support/LogicalResult.h" | ||
#include <llvm/Support/raw_ostream.h> | ||
#include <mlir/Support/LLVM.h> | ||
|
||
using namespace mlir; | ||
using namespace mlir::tt; | ||
|
||
namespace mlir::tt { | ||
|
||
// Decompose IndexOp into SliceOp | ||
// | ||
// This transformation adjusts IndexOp attributes so that `begin`, `end`, and | ||
// `step` become arrays, where each array element corresponds to a dimension of | ||
// the input tensor. For dimensions other than the sliced dimension, default | ||
// values are used. | ||
// | ||
struct IndexToSliceConversionPattern | ||
: public OpConversionPattern<ttir::IndexOp> { | ||
using OpConversionPattern<ttir::IndexOp>::OpConversionPattern; | ||
|
||
LogicalResult | ||
matchAndRewrite(ttir::IndexOp op, OpAdaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const override { | ||
auto inputType = | ||
::mlir::dyn_cast<mlir::RankedTensorType>(adaptor.getInput().getType()); | ||
if (!inputType || !inputType.hasRank()) | ||
return failure(); | ||
|
||
int64_t rank = inputType.getRank(); | ||
llvm::SmallVector<mlir::Attribute, 4> begins, ends, steps; | ||
|
||
for (int64_t i = 0; i < rank; ++i) { | ||
if (i == op.getDim()) { | ||
begins.push_back(rewriter.getI32IntegerAttr(adaptor.getBegin())); | ||
ends.push_back(rewriter.getI32IntegerAttr(adaptor.getEnd())); | ||
steps.push_back(rewriter.getI32IntegerAttr(adaptor.getStep())); | ||
} else { | ||
begins.push_back(rewriter.getI32IntegerAttr(0)); | ||
ends.push_back(rewriter.getI32IntegerAttr(inputType.getDimSize(i))); | ||
steps.push_back(rewriter.getI32IntegerAttr(1)); | ||
} | ||
} | ||
|
||
auto newOp = rewriter.create<ttir::SliceOp>( | ||
op.getLoc(), op.getType(), adaptor.getInput(), adaptor.getOutput(), | ||
rewriter.getArrayAttr(begins), rewriter.getArrayAttr(ends), | ||
rewriter.getArrayAttr(steps), adaptor.getOperandConstraints()); | ||
|
||
rewriter.replaceOp(op, newOp.getResult()); | ||
return success(); | ||
} | ||
}; | ||
|
||
void populateTTIRToTTIRDecompositionPatterns(MLIRContext *ctx, | ||
RewritePatternSet &patterns, | ||
TypeConverter &typeConverter) { | ||
patterns.add<IndexToSliceConversionPattern>(typeConverter, ctx); | ||
} | ||
|
||
} // namespace mlir::tt |
67 changes: 67 additions & 0 deletions
67
lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecompositionPass.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include "ttmlir/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.h" | ||
|
||
#include "mlir/Dialect/Func/Transforms/FuncConversions.h" | ||
#include "mlir/IR/BuiltinDialect.h" | ||
#include "mlir/IR/PatternMatch.h" | ||
#include "mlir/Pass/Pass.h" | ||
#include "mlir/Support/LogicalResult.h" | ||
#include "mlir/Transforms/DialectConversion.h" | ||
#include "ttmlir/Dialect/TTIR/IR/TTIR.h" | ||
#include "ttmlir/Dialect/TTIR/IR/TTIROps.h" | ||
#include "ttmlir/Dialect/TTNN/IR/TTNN.h" | ||
#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h" | ||
#include <mlir/Dialect/Func/IR/FuncOps.h> | ||
#include <mlir/Dialect/Tensor/IR/Tensor.h> | ||
|
||
using namespace mlir; | ||
using namespace mlir::tt; | ||
|
||
namespace mlir::tt::ttir { | ||
|
||
#define GEN_PASS_DEF_TTIRTOTTIRDECOMPOSITION | ||
#include "ttmlir/Conversion/Passes.h.inc" | ||
|
||
} // namespace mlir::tt::ttir | ||
|
||
namespace { | ||
|
||
struct TTIRToTTIRDecompositionPass | ||
: public ttir::impl::TTIRToTTIRDecompositionBase< | ||
TTIRToTTIRDecompositionPass> { | ||
void runOnOperation() final { | ||
mlir::ConversionTarget target(getContext()); | ||
target.addLegalDialect<ttir::TTIRDialect>(); | ||
|
||
target.addIllegalOp<ttir::IndexOp>(); | ||
|
||
TypeConverter typeConverter; | ||
// All types map 1:1. | ||
typeConverter.addConversion([](Type type) { return type; }); | ||
|
||
RewritePatternSet patterns(&getContext()); | ||
populateTTIRToTTIRDecompositionPatterns(&getContext(), patterns, | ||
typeConverter); | ||
|
||
// Apply partial conversion | ||
// | ||
if (failed(applyPartialConversion(getOperation(), target, | ||
std::move(patterns)))) { | ||
signalPassFailure(); | ||
return; | ||
} | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
namespace mlir::tt { | ||
|
||
std::unique_ptr<OperationPass<ModuleOp>> createTTIRToTTIRDecompositionPass() { | ||
return std::make_unique<TTIRToTTIRDecompositionPass>(); | ||
} | ||
|
||
} // namespace mlir::tt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.