Skip to content

Commit

Permalink
Avoid multiple calls to applyElementWise
Browse files Browse the repository at this point in the history
* Merge lambdas that clamp to the upper and lower bound into a single
  one performing both
* Add tests with clamp boundaries which cannot be represented in the
  type of the value to be clamped
  • Loading branch information
TinaAMD committed Apr 21, 2023
1 parent 6497e4a commit 18c4a2e
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 41 deletions.
62 changes: 21 additions & 41 deletions mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantClamp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
#include "mlir/Pass/Pass.h"
#include <llvm/ADT/APFloat.h>
#include <llvm/ADT/APInt.h>
#include <llvm/Support/Debug.h>
#include <mlir/IR/BuiltinAttributes.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/Support/LogicalResult.h>
Expand Down Expand Up @@ -47,31 +46,24 @@ struct TosaFoldConstantClamp : public OpRewritePattern<ClampOp> {
auto comparisonWidth =
std::max(inputValues.getElementType().getIntOrFloatBitWidth(),
lowerBound.getBitWidth());
// Sign-extend the upper and lower bound
auto extUpperBound = upperBound.sext(comparisonWidth);
auto extLowerBound = lowerBound.sext(comparisonWidth);

// Determine the result type
auto resultingIntType = cast<IntegerType>(resultType.getElementType());

// Ensure that the value is larger than the lower bound
auto clampLower = [&lowerBound, &comparisonWidth](const APInt &val,
IntegerType type) {
auto clampedLower = llvm::APIntOps::smax(
val.sext(comparisonWidth), lowerBound.sext(comparisonWidth));
// Make sure the output value has the correct type
assert(type.getWidth() >= clampedLower.getSignificantBits());
return clampedLower.trunc(type.getWidth());
// Lambda to perform the clamp in a single pass over the input values.
// Despite its name, clampUpper applies both bounds: the value is sign-extended
// to comparisonWidth, limited from above with smin against extUpperBound,
// then limited from below with smax against extLowerBound, and finally
// truncated to the bit width of the given result type.
auto clampUpper = [&extLowerBound, &extUpperBound,
&comparisonWidth](const APInt &val, IntegerType type) {
auto clampedUpper =
llvm::APIntOps::smin(val.sext(comparisonWidth), extUpperBound);
auto fullyClamped = llvm::APIntOps::smax(clampedUpper, extLowerBound);
// The clamped value has to fit into the result type before it is truncated.
assert(type.getWidth() >= fullyClamped.getSignificantBits());
return fullyClamped.trunc(type.getWidth());
};
auto newTensor = applyElementWise<APInt, APInt, IntegerType>(
inputValues, clampLower, resultingIntType);

// Next, make sure the upper bound is adhered to
auto clampUpper = [&upperBound, &comparisonWidth](const APInt &val,
IntegerType type) {
auto clampedUpper = llvm::APIntOps::smin(
val.sext(comparisonWidth), upperBound.sext(comparisonWidth));
assert(type.getWidth() >= clampedUpper.getSignificantBits());
return clampedUpper.trunc(type.getWidth());
};
newTensor = applyElementWise<APInt, APInt, IntegerType>(
newTensor, clampUpper, resultingIntType);
inputValues, clampUpper, resultingIntType);

return newTensor;
}
Expand All @@ -91,34 +83,22 @@ struct TosaFoldConstantClamp : public OpRewritePattern<ClampOp> {

auto resultingFloatType = cast<FloatType>(resultType.getElementType());

// Ensure that the value is larger than the lower bound
auto clampLower = [&lowerBound, &comparisonSem](APFloat val,
FloatType type) {
// Ensure that the value is no smaller than the lower bound and no larger than
// the upper bound; NaN inputs are returned as NaN
auto clampLower = [&lowerBound, &upperBound,
&comparisonSem](APFloat val, FloatType type) {
if (val.isNaN()) {
return APFloat::getNaN(type.getFloatSemantics());
}
changeSemanticsLossless(val, comparisonSem);
auto clampedLower = val < lowerBound ? lowerBound : val;
changeSemanticsLossless(clampedLower, &type.getFloatSemantics());
return clampedLower;
auto clampedUpper = val < upperBound ? val : upperBound;
auto fullyClamped = clampedUpper < lowerBound ? lowerBound : clampedUpper;
changeSemanticsLossless(fullyClamped, &type.getFloatSemantics());
return fullyClamped;
};
auto newTensor = applyElementWise<APFloat, APFloat, FloatType>(
inputValues, clampLower, resultingFloatType);

// Next, make sure the upper bound is adhered to
auto clampUpper = [&upperBound, &comparisonSem](APFloat val,
FloatType type) {
if (val.isNaN()) {
return APFloat::getNaN(type.getFloatSemantics());
}
changeSemanticsLossless(val, comparisonSem);
auto clampedUpper = val < upperBound ? val : upperBound;
changeSemanticsLossless(clampedUpper, &type.getFloatSemantics());
return clampedUpper;
};
newTensor = applyElementWise<APFloat, APFloat, FloatType>(
newTensor, clampUpper, resultingFloatType);

return newTensor;
}

Expand Down
25 changes: 25 additions & 0 deletions mlir/test/Dialect/Tosa/constant-clamp-opt.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,17 @@ func.func @clamp_fold_integer_equal_lower_upper() -> tensor<3xi8> {
return %1 : tensor<3xi8>
}

// Folding test: the upper clamp bound (max_int = 9223372036854775807 : i64)
// cannot be represented in the i8 element type of the clamped tensor, so only
// the lower bound (min_int = 4) changes values. The input [9, 0, -5] is
// expected to fold into the constant [9, 4, 4] with no tosa.clamp left over.
// NOTE(review): the RES capture below uses an empty pattern, so it presumably
// binds an empty string -- confirm whether a pattern such as `.*` was intended.
// CHECK-LABEL: @clamp_fold_integer_maximum_larger_than_result_type
func.func @clamp_fold_integer_maximum_larger_than_result_type() -> tensor<3xi8> {
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}9, 4, 4{{.*}}tensor<3xi8>
// CHECK-NOT: tosa.clamp
// CHECK: return [[RES]]
%0 = "tosa.const"() {value = dense<[9, 0, -5]> : tensor<3xi8>} : () -> tensor<3xi8>
%1 = "tosa.clamp"(%0) {max_fp = 0.00 : f32, max_int = 9223372036854775807 : i64, min_fp = 0.0 : f32, min_int = 4 : i64}
: (tensor<3xi8>) -> tensor<3xi8>
return %1 : tensor<3xi8>
}

// Float clamp

// CHECK-LABEL: @clamp_fold_float
Expand Down Expand Up @@ -64,3 +75,17 @@ func.func @clamp_fold_float_infinity_upper() -> tensor<5xf32> {
: (tensor<5xf32>) -> tensor<5xf32>
return %1 : tensor<5xf32>
}

// Folding test: the upper clamp bound (max_fp = 3.4028234e+38 : f32) is, per
// the test name, larger than what the f16 element type can represent, so only
// the lower bound (min_fp = -0.5) changes values. The first input (18.32) is
// expected to stay as it is (only the leading digits 1.83 are pinned by the
// pattern), and the second input (-0.98747) is expected to be raised to -0.5.
// NOTE(review): the RES capture below uses an empty pattern, so it presumably
// binds an empty string -- confirm whether a pattern such as `.*` was intended.
// CHECK-LABEL: @clamp_fold_float_maximum_larger_than_result_type
func.func @clamp_fold_float_maximum_larger_than_result_type() -> tensor<2xf16> {
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}1.83{{[0-9]*}}e+01, -5.{{0*}}e-01
// CHECK-NOT: tosa.clamp
// CHECK: return [[RES]]
%0 = "tosa.const"() {value =
dense<[18.32, -0.98747]> :
tensor<2xf16>
} : () -> tensor<2xf16>
%1 = "tosa.clamp"(%0) {max_fp = 3.4028234e+38 : f32, max_int = 1594 : i64, min_fp = -0.5 : f32, min_int = -17 : i64}
: (tensor<2xf16>) -> tensor<2xf16>
return %1 : tensor<2xf16>
}

0 comments on commit 18c4a2e

Please sign in to comment.