Implement conversion of stablehlo.constant OP to TTIR dialect (#684)
* Implement conversion of stablehlo.constant OP to TTIR dialect

* Fix code review comments
mrakitaTT authored Sep 13, 2024
1 parent d9b8485 commit 5b997c1
Showing 7 changed files with 168 additions and 8 deletions.
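At a glance, the StableHLO-to-TTIR leg of this change converts stablehlo.constant into ttir.constant, carrying the same ElementsAttr over unchanged. A minimal sketch of the rewrite, based on the constant_op.mlir test added below (exact attribute printing is illustrative):

// Input (StableHLO)
%0 = stablehlo.constant dense<3.000000e-01> : tensor<64xf32>
// After --stablehlo-to-ttir-pipeline (TTIR)
%0 = "ttir.constant"() <{value = dense<3.000000e-01> : tensor<64xf32>}> : () -> tensor<64xf32>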
20 changes: 20 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -553,6 +553,26 @@ def TTIR_UnsqueezeOp : TTIR_DPSOp<"unsqueeze"> {
let hasVerifier = 1;
}

def TTIR_ConstantOp : TTIR_Op<"constant", [ConstantLike,
AllShapesMatch<["value", "result"]>]> {
let summary = "Constant op.";
let description = [{
Produces a tensor filled with the given constant value.

Examples:
%0 = "ttir.constant"() {value = dense<0> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
// %0: [[0, 0, 0], [0, 0, 0]]
%1 = "ttir.constant"() {value = dense<[0.2, 1.3]> : tensor<2xf32>} : () -> tensor<2xf32>
// %1: [0.2, 1.3]
}];

let arguments = (ins ElementsAttr:$value);

let results = (outs AnyRankedTensor:$result);

let hasFolder = 1;
}

// ANCHOR: adding_an_op_matmul_ttir
def TTIR_MatmulOp : TTIR_DPSOp<"matmul"> {
let summary = "Matrix multiply operation.";
26 changes: 26 additions & 0 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
@@ -239,6 +239,25 @@ class StableHLOToTTIRDotGeneralOpConversionPattern
}
};

class StableHLOToTTIRConstantOpConversionPattern
: public OpConversionPattern<mlir::stablehlo::ConstantOp> {

using OpConversionPattern<mlir::stablehlo::ConstantOp>::OpConversionPattern;

public:
LogicalResult
matchAndRewrite(mlir::stablehlo::ConstantOp srcOp,
mlir::stablehlo::ConstantOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto outputType = mlir::cast<RankedTensorType>(srcOp.getResult().getType());

rewriter.replaceOpWithNewOp<mlir::tt::ttir::ConstantOp>(srcOp, outputType,
srcOp.getValue());

return success();
}
};

void addElementwiseUnaryOpsConversionPatterns(MLIRContext *ctx,
RewritePatternSet &patterns,
TypeConverter &typeConverter) {
@@ -288,6 +307,12 @@ void addMatmulOpsConversionPatterns(MLIRContext *ctx,
ctx);
}

void addTensorCreationOpsConversionPatterns(MLIRContext *ctx,
RewritePatternSet &patterns,
TypeConverter &typeConverter) {
patterns.add<StableHLOToTTIRConstantOpConversionPattern>(typeConverter, ctx);
}

} // namespace

namespace mlir::tt {
@@ -300,6 +325,7 @@ void populateStableHLOToTTIRPatterns(MLIRContext *ctx,
addReduceOpsConversionPatterns(ctx, patterns, typeConverter);
addTransposeOpsConversionPatterns(ctx, patterns, typeConverter);
addMatmulOpsConversionPatterns(ctx, patterns, typeConverter);
addTensorCreationOpsConversionPatterns(ctx, patterns, typeConverter);
}

} // namespace mlir::tt
53 changes: 53 additions & 0 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -413,6 +413,58 @@ class UnsqueezeOpConversionPattern
}
};

class ConstantOpConversionPattern
: public OpConversionPattern<ttir::ConstantOp> {
public:
using OpConversionPattern<ttir::ConstantOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(ttir::ConstantOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
::mlir::ElementsAttr valueAttr = op.getValue();

LogicalResult legalityResult = checkBasicLegality(op, valueAttr, rewriter);
if (!legalityResult.succeeded()) {
return legalityResult;
}

if (valueAttr.isSplat()) {
Value device = getOrInsertDevice(rewriter, op);
float fillValue = valueAttr.getElementType().isInteger()
? static_cast<float>(valueAttr.getSplatValue<int>())
: valueAttr.getSplatValue<float>();
if (fillValue == 0) {
rewriter.replaceOpWithNewOp<ttnn::EmptyOp>(
op, this->getTypeConverter()->convertType(op.getType()), device);
} else {
::mlir::FloatAttr fillValueAttr = rewriter.getF32FloatAttr(fillValue);
rewriter.replaceOpWithNewOp<ttnn::FullOp>(
op, this->getTypeConverter()->convertType(op.getType()), device,
fillValueAttr);
}
} else {
return rewriter.notifyMatchFailure(
op, "TTNN doesn't currently support tensor creation from multiple "
"given values (issue #685)");
}

return success();
}

private:
LogicalResult checkBasicLegality(ttir::ConstantOp &op,
::mlir::ElementsAttr &valueAttr,
ConversionPatternRewriter &rewriter) const {
if (!valueAttr.getElementType().isIntOrFloat()) {
return rewriter.notifyMatchFailure(
op, "TTNN doesn't currently support tensor creation from values "
"which are not integer or floating point numbers");
}

return success();
}
};

} // namespace

// ANCHOR: adding_an_op_matmul_op_rewriter
@@ -574,6 +626,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
ReshapeOpConversionPattern,
SqueezeOpConversionPattern,
UnsqueezeOpConversionPattern,
ConstantOpConversionPattern,
MatmulOpConversionPattern,
Conv2dOpConversionPattern,
MaxPool2dOpConversionPattern
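Note the splat-only handling above: a zero splat lowers to ttnn.empty (nothing to fill), any other splat lowers to ttnn.full, and the fill value is always materialized as an f32 attribute (integer splats are static_cast to float), while non-splat constants are rejected until issue #685 lands. A rough sketch of the resulting IR, with the ttnn operand and attribute syntax assumed for illustration:

%0 = "ttir.constant"() <{value = dense<0> : tensor<64x128xi32>}> : () -> tensor<64x128xi32>
// zero splat becomes:
%0 = "ttnn.empty"(%device) : (!tt.device) -> tensor<64x128xi32>

%1 = "ttir.constant"() <{value = dense<1.000000e+00> : tensor<64x128xf32>}> : () -> tensor<64x128xf32>
// non-zero splat becomes:
%1 = "ttnn.full"(%device) <{fillValue = 1.000000e+00 : f32}> : (!tt.device) -> tensor<64x128xf32>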
8 changes: 6 additions & 2 deletions lib/Dialect/TTIR/IR/TTIROps.cpp
@@ -3,13 +3,13 @@
// SPDX-License-Identifier: Apache-2.0

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/BuiltinTypes.h"
#include <llvm/ADT/ArrayRef.h>

#include "ttmlir/Dialect/TTIR/IR/TTIR.h"
#include "ttmlir/Dialect/TTIR/IR/TTIROps.h"

#include "ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.cpp.inc"
+#include <llvm/ADT/ArrayRef.h>
+#include <mlir/IR/BuiltinTypes.h>

#define GET_OP_CLASSES
#include "ttmlir/Dialect/TTIR/IR/TTIROps.cpp.inc"
@@ -400,6 +400,10 @@ ::mlir::LogicalResult mlir::tt::ttir::UnsqueezeOp::verify() {
return success();
}

::mlir::OpFoldResult mlir::tt::ttir::ConstantOp::fold(FoldAdaptor adaptor) {
return getValueAttr();
}

// ANCHOR: adding_an_op_matmul_ttir_verify
::mlir::LogicalResult mlir::tt::ttir::MatmulOp::verify() {
::mlir::RankedTensorType inputAType = getA().getType();
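Because TTIR_ConstantOp is marked ConstantLike and its fold hook simply returns the value attribute, generic MLIR machinery (canonicalization, CSE, folding of consumers) can now treat it as a true constant without any dialect-specific patterns. A small sketch of the expected effect, assuming the canonicalizer is run:

%0 = "ttir.constant"() <{value = dense<1.000000e+00> : tensor<2xf32>}> : () -> tensor<2xf32>
%1 = "ttir.constant"() <{value = dense<1.000000e+00> : tensor<2xf32>}> : () -> tensor<2xf32>
// canonicalization uniques the identical constants; all uses point at one op:
%0 = "ttir.constant"() <{value = dense<1.000000e+00> : tensor<2xf32>}> : () -> tensor<2xf32>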
27 changes: 21 additions & 6 deletions lib/Dialect/TTIR/Transforms/Passes.cpp
Expand Up @@ -614,14 +614,29 @@ createToLayoutOp(PatternRewriter &rewriter, Location loc, Value input,
auto desiredLayout =
rewriter.getAttr<LayoutAttr>(ty, desiredMemorySpace, currLayout.getGrid(),
desiredElementType, desiredMemLayout);
-  auto output = rewriter.create<tensor::EmptyOp>(
-      loc, ty.getShape(), ty.getElementType(), desiredLayout);

-  tensor::EmptyOp exising_empty = input.getDefiningOp<tensor::EmptyOp>();
-  if (exising_empty) {
-    rewriter.replaceOp(exising_empty, output);
-    return output.getResult();
+  tensor::EmptyOp existingEmpty = input.getDefiningOp<tensor::EmptyOp>();
+  if (existingEmpty) {
+    return rewriter
+        .replaceOpWithNewOp<tensor::EmptyOp>(existingEmpty, ty.getShape(),
+                                             ty.getElementType(), desiredLayout)
+        .getResult();
  }

+  ttir::ConstantOp existingConstant = input.getDefiningOp<ttir::ConstantOp>();
+  if (existingConstant) {
+    return rewriter
+        .replaceOpWithNewOp<ttir::ConstantOp>(
+            existingConstant,
+            mlir::RankedTensorType::get(ty.getShape(), ty.getElementType(),
+                                        desiredLayout),
+            existingConstant.getValue())
+        .getResult();
+  }

+  tensor::EmptyOp output = rewriter.create<tensor::EmptyOp>(
+      loc, ty.getShape(), ty.getElementType(), desiredLayout);

return rewriter
.create<ttir::ToLayoutOp>(loc, output.getType(), input, output)
->getResult(0);
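The createToLayoutOp change gives constants the same shortcut empty tensors already had: rather than funneling the value through a ttir.to_layout op, the ttir.constant is recreated with the desired layout baked into its result type. A rough before/after sketch, where #layout stands in for the computed LayoutAttr and the to_layout spelling is assumed:

// without the shortcut, a layout change is materialized as:
%0 = "ttir.constant"() <{value = dense<3.000000e-01> : tensor<64xf32>}> : () -> tensor<64xf32>
%1 = tensor.empty() : tensor<64xf32, #layout>
%2 = "ttir.to_layout"(%0, %1) : (tensor<64xf32>, tensor<64xf32, #layout>) -> tensor<64xf32, #layout>
// with it, the constant absorbs the layout directly:
%0 = "ttir.constant"() <{value = dense<3.000000e-01> : tensor<64xf32>}> : () -> tensor<64xf32, #layout>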
15 changes: 15 additions & 0 deletions test/ttmlir/Conversion/StableHLOToTTIR/constant_op.mlir
@@ -0,0 +1,15 @@
// REQUIRES: stablehlo
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s
module @jit_constant attributes {} {
func.func public @test_splat() -> tensor<64xf32> {
%0 = stablehlo.constant dense<0.3> : tensor<64xf32>
// CHECK: %[[C:.*]] = "ttir.constant"[[C:.*]]
return %0 : tensor<64xf32>
}

func.func public @test_multiple() -> tensor<2x2xf32> {
%0 = stablehlo.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
// CHECK: %[[C:.*]] = "ttir.constant"[[C:.*]]
return %0 : tensor<2x2xf32>
}
}
27 changes: 27 additions & 0 deletions test/ttmlir/Dialect/TTNN/simple_constant.mlir
@@ -0,0 +1,27 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module attributes {} {
func.func @test_empty_int() -> tensor<64x128xi32> {
%0 = "ttir.constant"() <{value = dense<0> : tensor<64x128xi32>}> : () -> tensor<64x128xi32>
// CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
return %0 : tensor<64x128xi32>
}

func.func @test_empty_float() -> tensor<64x128xf32> {
%0 = "ttir.constant"() <{value = dense<0.000000e+00> : tensor<64x128xf32>}> : () -> tensor<64x128xf32>
// CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
return %0 : tensor<64x128xf32>
}

func.func @test_full_int() -> tensor<64x128xi32> {
%0 = "ttir.constant"() <{value = dense<1> : tensor<64x128xi32>}> : () -> tensor<64x128xi32>
// CHECK: %[[C:.*]] = "ttnn.full"[[C:.*]]
return %0 : tensor<64x128xi32>
}

func.func @test_full_float() -> tensor<64x128xf32> {
%0 = "ttir.constant"() <{value = dense<1.000000e+00> : tensor<64x128xf32>}> : () -> tensor<64x128xf32>
// CHECK: %[[C:.*]] = "ttnn.full"[[C:.*]]
return %0 : tensor<64x128xf32>
}
}
