From 89f16ead143abed81b825f20e34e4e77c6d117ad Mon Sep 17 00:00:00 2001
From: Tiago Trevisan Jost
Date: Fri, 5 Jan 2024 14:22:55 +0000
Subject: [PATCH] feat: add min to clamp and max to clamp canonicalizations.

Rewrite tosa.minimum / tosa.maximum whose second operand is a splat
constant `c` into a single tosa.clamp:

  minimum(x, splat(c)) -> clamp(x) with max = c and an open lower bound
  maximum(x, splat(c)) -> clamp(x) with min = c and an open upper bound

Both the integer and the floating-point attribute pairs of the clamp are
always populated, independent of the element type of `x`. The open integer
bound is the int64 limit in both patterns so that it matches the i64 clamp
attributes; the open floating-point bound is the lowest/largest finite f32.

NOTE(review): because the open fp bound is finite, inputs of +/-inf (and
possibly NaN) may be treated differently by the clamp than by the original
minimum/maximum -- confirm this is acceptable for the targeted backends.
---
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td  |   4 +
 .../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 106 ++++++++++++++++++
 mlir/test/Dialect/Tosa/canonicalize.mlir      |  32 ++++++
 3 files changed, 142 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 5823f5447522c..3331ca4cb8643 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -733,6 +733,8 @@ def Tosa_MaximumOp : Tosa_ElemWiseBinaryOp<"maximum", [Commutative]> {
   let results = (outs
     Tosa_Tensor:$output
   );
+
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -754,6 +756,8 @@ def Tosa_MinimumOp : Tosa_ElemWiseBinaryOp<"minimum", [Commutative]> {
   let results = (outs
     Tosa_Tensor:$output
   );
+
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 62d363b1c6349..636e0deb18a0e 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -433,6 +433,112 @@ void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ConcatSliceOptimization>(context);
 }
 
+/// Canonicalizes `tosa.minimum(x, splat(c))` into `tosa.clamp` with `c` as the
+/// upper bound and the lower bound left open.
+///
+/// tosa.clamp carries both an integer and a floating-point bound pair, so both
+/// pairs are populated here regardless of the element type of `x`.
+struct MinToClampOptimization : public OpRewritePattern<tosa::MinimumOp> {
+  using OpRewritePattern<tosa::MinimumOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::MinimumOp op,
+                                PatternRewriter &rewriter) const override {
+
+    // Only a splat constant on the second operand maps onto scalar bounds.
+    DenseElementsAttr constant;
+    if (!matchPattern(op.getInput2(), m_Constant(&constant)) ||
+        !constant.isSplat())
+      return failure();
+
+    Value input = op.getInput1();
+    auto elementTy = llvm::cast<ShapedType>(input.getType()).getElementType();
+
+    // Open lower bound. The integer limit is 64-bit wide to match the i64
+    // clamp attribute and the open upper bound used by MaxToClampOptimization.
+    // NOTE(review): the fp bound is the lowest finite f32, not -inf; confirm
+    // that -inf/NaN inputs do not need to be preserved.
+    int64_t minInt = std::numeric_limits<int64_t>::min();
+    float minFp = std::numeric_limits<float>::lowest();
+
+    int64_t maxInt;
+    float maxFp;
+    if (isa<FloatType>(elementTy)) {
+      auto constMin = constant.getSplatValue<APFloat>();
+      maxFp = constMin.convertToFloat();
+      maxInt = constMin.convertToFloat();
+    } else {
+      auto constMin = constant.getSplatValue<APInt>();
+      maxFp = constMin.getSExtValue();
+      maxInt = constMin.getSExtValue();
+    }
+
+    rewriter.replaceOpWithNewOp<tosa::ClampOp>(
+        op, op.getType(), input, rewriter.getI64IntegerAttr(minInt),
+        rewriter.getI64IntegerAttr(maxInt), rewriter.getF32FloatAttr(minFp),
+        rewriter.getF32FloatAttr(maxFp));
+
+    return success();
+  }
+};
+
+void MinimumOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                            MLIRContext *context) {
+  results.add<MinToClampOptimization>(context);
+}
+
+/// Canonicalizes `tosa.maximum(x, splat(c))` into `tosa.clamp` with `c` as the
+/// lower bound and the upper bound left open.
+///
+/// tosa.clamp carries both an integer and a floating-point bound pair, so both
+/// pairs are populated here regardless of the element type of `x`.
+struct MaxToClampOptimization : public OpRewritePattern<tosa::MaximumOp> {
+  using OpRewritePattern<tosa::MaximumOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::MaximumOp op,
+                                PatternRewriter &rewriter) const override {
+
+    // Only a splat constant on the second operand maps onto scalar bounds.
+    DenseElementsAttr constant;
+    if (!matchPattern(op.getInput2(), m_Constant(&constant)) ||
+        !constant.isSplat())
+      return failure();
+
+    Value input = op.getInput1();
+    auto elementTy = llvm::cast<ShapedType>(input.getType()).getElementType();
+
+    // Open upper bound, mirroring the open lower bound used by
+    // MinToClampOptimization.
+    // NOTE(review): the fp bound is the largest finite f32, not +inf; confirm
+    // that +inf/NaN inputs do not need to be preserved.
+    int64_t maxInt = std::numeric_limits<int64_t>::max();
+    float maxFp = std::numeric_limits<float>::max();
+
+    int64_t minInt;
+    float minFp;
+    if (isa<FloatType>(elementTy)) {
+      auto constMax = constant.getSplatValue<APFloat>();
+      minFp = constMax.convertToFloat();
+      minInt = constMax.convertToFloat();
+    } else {
+      auto constMax = constant.getSplatValue<APInt>();
+      minFp = constMax.getSExtValue();
+      minInt = constMax.getSExtValue();
+    }
+
+    rewriter.replaceOpWithNewOp<tosa::ClampOp>(
+        op, op.getType(), input, rewriter.getI64IntegerAttr(minInt),
+        rewriter.getI64IntegerAttr(maxInt), rewriter.getF32FloatAttr(minFp),
+        rewriter.getF32FloatAttr(maxFp));
+
+    return success();
+  }
+};
+
+void MaximumOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                            MLIRContext *context) {
+  results.add<MaxToClampOptimization>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // Operator Folders.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index c3fccb06663e3..31227698e09bf 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -112,6 +112,38 @@ func.func @clamp_twice_is_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> {
   return %1 : tensor<4xi8>
 }
 
+// CHECK-LABEL: @clamp_minimum_i32
+func.func @clamp_minimum_i32(%arg0: tensor<4xi32>) -> tensor<4xi32> {
+  // CHECK: "tosa.clamp"(%arg0) <{max_fp = 6.000000e+00 : f32, max_int = 6 : i64, min_fp = -3.40282347E+38 : f32, min_int = -9223372036854775808 : i64}
+  %0 = "tosa.const"() <{value = dense<6> : tensor<1xi32>}> : () -> tensor<1xi32>
+  %1 = "tosa.minimum"(%arg0, %0) : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32>
+  return %1 : tensor<4xi32>
+}
+
+// CHECK-LABEL: @clamp_minimum_f32
+func.func @clamp_minimum_f32(%arg0: tensor<4xf32>) -> tensor<4xf32> {
+  // CHECK: "tosa.clamp"(%arg0) <{max_fp = 6.000000e+00 : f32, max_int = 6 : i64, min_fp = -3.40282347E+38 : f32, min_int = -9223372036854775808 : i64}
+  %0 = "tosa.const"() <{value = dense<6.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %1 = "tosa.minimum"(%arg0, %0) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
+  return %1 : tensor<4xf32>
+}
+
+// CHECK-LABEL: @clamp_maximum_i32
+func.func @clamp_maximum_i32(%arg0: tensor<4xi32>) -> tensor<4xi32> {
+  // CHECK: "tosa.clamp"(%arg0) <{max_fp = 3.40282347E+38 : f32, max_int = 9223372036854775807 : i64, min_fp = -6.000000e+00 : f32, min_int = -6 : i64}
+  %0 = "tosa.const"() <{value = dense<-6> : tensor<1xi32>}> : () -> tensor<1xi32>
+  %1 = "tosa.maximum"(%arg0, %0) : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32>
+  return %1 : tensor<4xi32>
+}
+
+// CHECK-LABEL: @clamp_maximum_f32
+func.func @clamp_maximum_f32(%arg0: tensor<4xf32>) -> tensor<4xf32> {
+  // CHECK: "tosa.clamp"(%arg0) <{max_fp = 3.40282347E+38 : f32, max_int = 9223372036854775807 : i64, min_fp = -6.000000e+00 : f32, min_int = -6 : i64}
+  %0 = "tosa.const"() <{value = dense<-6.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %1 = "tosa.maximum"(%arg0, %0) : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
+  return %1 : tensor<4xf32>
+}
+
 // CHECK-LABEL: @concat_fold_zero
 func.func @concat_fold_zero(%arg0: tensor<?x0xf32>, %arg1: tensor<?x1xf32>, %arg2: tensor<?x2xf32>) -> tensor<?x3xf32> {
   // CHECK: "tosa.concat"(%arg1, %arg2) <{axis = 1 : i64}>