diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantClamp.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantClamp.cpp index 6c7f3643849f1..4d20d51b4ac7c 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantClamp.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantClamp.cpp @@ -17,7 +17,6 @@ #include "mlir/Pass/Pass.h" #include #include -#include #include #include #include @@ -47,31 +46,24 @@ struct TosaFoldConstantClamp : public OpRewritePattern { auto comparisonWidth = std::max(inputValues.getElementType().getIntOrFloatBitWidth(), lowerBound.getBitWidth()); + // Sign-extend the upper and lower bound + auto extUpperBound = upperBound.sext(comparisonWidth); + auto extLowerBound = lowerBound.sext(comparisonWidth); + // Determine the result type auto resultingIntType = cast(resultType.getElementType()); - // Ensure that the value is larger than the lower bound - auto clampLower = [&lowerBound, &comparisonWidth](const APInt &val, - IntegerType type) { - auto clampedLower = llvm::APIntOps::smax( - val.sext(comparisonWidth), lowerBound.sext(comparisonWidth)); - // Make sure the output value has the correct type - assert(type.getWidth() >= clampedLower.getSignificantBits()); - return clampedLower.trunc(type.getWidth()); + // Lambda to perform the clamp + auto clampUpper = [&extLowerBound, &extUpperBound, + &comparisonWidth](const APInt &val, IntegerType type) { + auto clampedUpper = + llvm::APIntOps::smin(val.sext(comparisonWidth), extUpperBound); + auto fullyClamped = llvm::APIntOps::smax(clampedUpper, extLowerBound); + assert(type.getWidth() >= fullyClamped.getSignificantBits()); + return fullyClamped.trunc(type.getWidth()); }; auto newTensor = applyElementWise( - inputValues, clampLower, resultingIntType); - - // Next, make sure the upper bound is adhered to - auto clampUpper = [&upperBound, &comparisonWidth](const APInt &val, - IntegerType type) { - auto clampedUpper = llvm::APIntOps::smin( - val.sext(comparisonWidth), upperBound.sext(comparisonWidth)); - assert(type.getWidth() >= clampedUpper.getSignificantBits()); - return clampedUpper.trunc(type.getWidth()); - }; - newTensor = applyElementWise( - newTensor, clampUpper, resultingIntType); + inputValues, clampUpper, resultingIntType); return newTensor; } @@ -91,34 +83,22 @@ struct TosaFoldConstantClamp : public OpRewritePattern { auto resultingFloatType = cast(resultType.getElementType()); - // Ensure that the value is larger than the lower bound - auto clampLower = [&lowerBound, &comparisonSem](APFloat val, - FloatType type) { + // Ensure that the value is larger than the lower bound and smaller than the + // upper bound + auto clampLower = [&lowerBound, &upperBound, + &comparisonSem](APFloat val, FloatType type) { if (val.isNaN()) { return APFloat::getNaN(type.getFloatSemantics()); } changeSemanticsLossless(val, comparisonSem); - auto clampedLower = val < lowerBound ? lowerBound : val; - changeSemanticsLossless(clampedLower, &type.getFloatSemantics()); - return clampedLower; + auto clampedUpper = val < upperBound ? val : upperBound; + auto fullyClamped = clampedUpper < lowerBound ? lowerBound : clampedUpper; + changeSemanticsLossless(fullyClamped, &type.getFloatSemantics()); + return fullyClamped; }; auto newTensor = applyElementWise( inputValues, clampLower, resultingFloatType); - // Next, make sure the upper bound is adhered to - auto clampUpper = [&upperBound, &comparisonSem](APFloat val, - FloatType type) { - if (val.isNaN()) { - return APFloat::getNaN(type.getFloatSemantics()); - } - changeSemanticsLossless(val, comparisonSem); - auto clampedUpper = val < upperBound ? val : upperBound; - changeSemanticsLossless(clampedUpper, &type.getFloatSemantics()); - return clampedUpper; - }; - newTensor = applyElementWise( - newTensor, clampUpper, resultingFloatType); - return newTensor; } diff --git a/mlir/test/Dialect/Tosa/constant-clamp-opt.mlir b/mlir/test/Dialect/Tosa/constant-clamp-opt.mlir index 585370a767d3c..276e87405e695 100644 --- a/mlir/test/Dialect/Tosa/constant-clamp-opt.mlir +++ b/mlir/test/Dialect/Tosa/constant-clamp-opt.mlir @@ -24,6 +24,17 @@ func.func @clamp_fold_integer_equal_lower_upper() -> tensor<3xi8> { return %1 : tensor<3xi8> } +// CHECK-LABEL: @clamp_fold_integer_maximum_larger_than_result_type +func.func @clamp_fold_integer_maximum_larger_than_result_type() -> tensor<3xi8> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}9, 4, 4{{.*}}tensor<3xi8> + // CHECK-NOT: tosa.clamp + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = dense<[9, 0, -5]> : tensor<3xi8>} : () -> tensor<3xi8> + %1 = "tosa.clamp"(%0) {max_fp = 0.00 : f32, max_int = 9223372036854775807 : i64, min_fp = 0.0 : f32, min_int = 4 : i64} + : (tensor<3xi8>) -> tensor<3xi8> + return %1 : tensor<3xi8> +} + // Float clamp // CHECK-LABEL: @clamp_fold_float @@ -64,3 +75,17 @@ func.func @clamp_fold_float_infinity_upper() -> tensor<5xf32> { : (tensor<5xf32>) -> tensor<5xf32> return %1 : tensor<5xf32> } + +// CHECK-LABEL: @clamp_fold_float_maximum_larger_than_result_type +func.func @clamp_fold_float_maximum_larger_than_result_type() -> tensor<2xf16> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}1.83{{[0-9]*}}e+01, -5.{{0*}}e-01 + // CHECK-NOT: tosa.clamp + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[18.32, -0.98747]> : + tensor<2xf16> + } : () -> tensor<2xf16> + %1 = "tosa.clamp"(%0) {max_fp = 3.4028234e+38 : f32, max_int = 1594 : i64, min_fp = -0.5 : f32, min_int = -17 : i64} + : (tensor<2xf16>) -> tensor<2xf16> + return %1 : tensor<2xf16> +}