-
Notifications
You must be signed in to change notification settings - Fork 12.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Revert "[MLIR][Arith] add fastMathAttr on arith::extf and arith::truncf" #95344
Revert "[MLIR][Arith] add fastMathAttr on arith::extf and arith::truncf" #95344
Conversation
@llvm/pr-subscribers-mlir-arith @llvm/pr-subscribers-mlir-math Author: Ivy Zhang (crazydemo) ChangesReverts llvm/llvm-project#93443 Patch is 22.49 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/95344.diff 7 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index c4471f9bc5af2..06fbdb7f2c4cb 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1199,7 +1199,7 @@ def Arith_ExtSIOp : Arith_IToICastOp<"extsi"> {
// ExtFOp
//===----------------------------------------------------------------------===//
-def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods<ArithFastMathInterface>]> {
+def Arith_ExtFOp : Arith_FToFCastOp<"extf"> {
let summary = "cast from floating-point to wider floating-point";
let description = [{
Cast a floating-point value to a larger floating-point-typed value.
@@ -1208,13 +1208,6 @@ def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods<ArithFast
}];
let hasVerifier = 1;
let hasFolder = 1;
-
- let arguments = (ins FloatLike:$in, DefaultValuedAttr<
- Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath);
- let results = (outs FloatLike:$out);
-
- let assemblyFormat = [{ $in (`fastmath` `` $fastmath^)?
- attr-dict `:` type($in) `to` type($out) }];
}
//===----------------------------------------------------------------------===//
@@ -1253,11 +1246,8 @@ def Arith_TruncFOp :
Arith_Op<"truncf",
[Pure, SameOperandsAndResultShape, SameInputOutputTensorDims,
DeclareOpInterfaceMethods<ArithRoundingModeInterface>,
- DeclareOpInterfaceMethods<ArithFastMathInterface>,
DeclareOpInterfaceMethods<CastOpInterface>]>,
Arguments<(ins FloatLike:$in,
- DefaultValuedAttr<
- Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath,
OptionalAttr<Arith_RoundingModeAttr>:$roundingmode)>,
Results<(outs FloatLike:$out)> {
let summary = "cast from floating-point to narrower floating-point";
@@ -1277,9 +1267,7 @@ def Arith_TruncFOp :
let hasFolder = 1;
let hasVerifier = 1;
- let assemblyFormat = [{ $in ($roundingmode^)?
- (`fastmath` `` $fastmath^)?
- attr-dict `:` type($in) `to` type($out) }];
+ let assemblyFormat = "$in ($roundingmode^)? attr-dict `:` type($in) `to` type($out)";
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 291f6e5424ba5..2f6647a2a27b1 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1390,20 +1390,6 @@ LogicalResult arith::ExtSIOp::verify() {
/// Fold extension of float constants when there is no information loss due the
/// difference in fp semantics.
OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
- if (auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
- if (truncFOp.getOperand().getType() == getType()) {
- arith::FastMathFlags truncFMF = truncFOp.getFastmath();
- bool isTruncContract =
- bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract);
- arith::FastMathFlags extFMF = getFastmath();
- bool isExtContract =
- bitEnumContainsAll(extFMF, arith::FastMathFlags::contract);
- if (isTruncContract && isExtContract) {
- return truncFOp.getOperand();
- }
- }
- }
-
auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
return constFoldCastOp<FloatAttr, FloatAttr>(
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index 8e1cb474feee7..4a50da3513f99 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -94,11 +94,8 @@ void EmulateFloatPattern::rewrite(Operation *op, ArrayRef<Value> operands,
SmallVector<Value> newResults(expandedOp->getResults());
for (auto [res, oldType, newType] : llvm::zip_equal(
MutableArrayRef{newResults}, op->getResultTypes(), resultTypes)) {
- if (oldType != newType) {
- auto truncFOp = rewriter.create<arith::TruncFOp>(loc, oldType, res);
- truncFOp.setFastmath(arith::FastMathFlags::contract);
- res = truncFOp.getResult();
- }
+ if (oldType != newType)
+ res = rewriter.create<arith::TruncFOp>(loc, oldType, res);
}
rewriter.replaceOp(op, newResults);
}
@@ -117,9 +114,7 @@ void mlir::arith::populateEmulateUnsupportedFloatsConversions(
});
converter.addTargetMaterialization(
[](OpBuilder &b, Type target, ValueRange input, Location loc) {
- auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
- extFOp.setFastmath(arith::FastMathFlags::contract);
- return extFOp;
+ return b.create<arith::ExtFOp>(loc, target, input);
});
}
diff --git a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
index 3d99f3033cf56..5998133b7eab8 100644
--- a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
@@ -57,9 +57,7 @@ void mlir::math::populateLegalizeToF32TypeConverter(
});
typeConverter.addTargetMaterialization(
[](OpBuilder &b, Type target, ValueRange input, Location loc) {
- auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
- extFOp.setFastmath(arith::FastMathFlags::contract);
- return extFOp;
+ return b.create<arith::ExtFOp>(loc, target, input);
});
}
@@ -86,11 +84,8 @@ LogicalResult LegalizeToF32RewritePattern::matchAndRewrite(
SmallVector<Value> results = (*legalized)->getResults();
for (auto [result, newType, origType] : llvm::zip_equal(
results, (*legalized)->getResultTypes(), op->getResultTypes())) {
- if (newType != origType) {
- auto truncFOp = rewriter.create<arith::TruncFOp>(loc, origType, result);
- truncFOp.setFastmath(arith::FastMathFlags::contract);
- result = truncFOp.getResult();
- }
+ if (newType != origType)
+ result = rewriter.create<arith::TruncFOp>(loc, origType, result);
}
rewriter.replaceOp(op, results);
return success();
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index cacdd801871fb..56ae930e6d627 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -162,11 +162,11 @@ func.func @uitofp(%arg0 : i32, %arg1 : i64) {
// Checking conversion of integer types to floating point.
// CHECK-LABEL: @fpext
func.func @fpext(%arg0 : f16, %arg1 : f32) {
-// CHECK-NEXT: = llvm.fpext {{.*}} {fastmath = #arith.fastmath<none>} : f16 to f32
+// CHECK-NEXT: = llvm.fpext {{.*}} : f16 to f32
%0 = arith.extf %arg0: f16 to f32
-// CHECK-NEXT: = llvm.fpext {{.*}} {fastmath = #arith.fastmath<none>} : f16 to f64
+// CHECK-NEXT: = llvm.fpext {{.*}} : f16 to f64
%1 = arith.extf %arg0: f16 to f64
-// CHECK-NEXT: = llvm.fpext {{.*}} {fastmath = #arith.fastmath<none>} : f32 to f64
+// CHECK-NEXT: = llvm.fpext {{.*}} : f32 to f64
%2 = arith.extf %arg1: f32 to f64
return
}
@@ -174,11 +174,11 @@ func.func @fpext(%arg0 : f16, %arg1 : f32) {
// Checking conversion of integer types to floating point.
// CHECK-LABEL: @fpext
func.func @fpext_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>) {
-// CHECK-NEXT: = llvm.fpext {{.*}} {fastmath = #arith.fastmath<none>} : vector<2xf16> to vector<2xf32>
+// CHECK-NEXT: = llvm.fpext {{.*}} : vector<2xf16> to vector<2xf32>
%0 = arith.extf %arg0: vector<2xf16> to vector<2xf32>
-// CHECK-NEXT: = llvm.fpext {{.*}} {fastmath = #arith.fastmath<none>} : vector<2xf16> to vector<2xf64>
+// CHECK-NEXT: = llvm.fpext {{.*}} : vector<2xf16> to vector<2xf64>
%1 = arith.extf %arg0: vector<2xf16> to vector<2xf64>
-// CHECK-NEXT: = llvm.fpext {{.*}} {fastmath = #arith.fastmath<none>} : vector<2xf32> to vector<2xf64>
+// CHECK-NEXT: = llvm.fpext {{.*}} : vector<2xf32> to vector<2xf64>
%2 = arith.extf %arg1: vector<2xf32> to vector<2xf64>
return
}
@@ -268,11 +268,11 @@ func.func @uitofp_vector(%arg0 : vector<2xi16>, %arg1 : vector<2xi32>, %arg2 : v
// Checking conversion of integer types to floating point.
// CHECK-LABEL: @fptrunc
func.func @fptrunc(%arg0 : f32, %arg1 : f64) {
-// CHECK-NEXT: = llvm.fptrunc {{.*}} {fastmath = #arith.fastmath<none>} : f32 to f16
+// CHECK-NEXT: = llvm.fptrunc {{.*}} : f32 to f16
%0 = arith.truncf %arg0: f32 to f16
-// CHECK-NEXT: = llvm.fptrunc {{.*}} {fastmath = #arith.fastmath<none>} : f64 to f16
+// CHECK-NEXT: = llvm.fptrunc {{.*}} : f64 to f16
%1 = arith.truncf %arg1: f64 to f16
-// CHECK-NEXT: = llvm.fptrunc {{.*}} {fastmath = #arith.fastmath<none>} : f64 to f32
+// CHECK-NEXT: = llvm.fptrunc {{.*}} : f64 to f32
%2 = arith.truncf %arg1: f64 to f32
return
}
@@ -280,26 +280,26 @@ func.func @fptrunc(%arg0 : f32, %arg1 : f64) {
// Checking conversion of integer types to floating point.
// CHECK-LABEL: @fptrunc
func.func @fptrunc_vector(%arg0 : vector<2xf32>, %arg1 : vector<2xf64>) {
-// CHECK-NEXT: = llvm.fptrunc {{.*}} {fastmath = #arith.fastmath<none>} : vector<2xf32> to vector<2xf16>
+// CHECK-NEXT: = llvm.fptrunc {{.*}} : vector<2xf32> to vector<2xf16>
%0 = arith.truncf %arg0: vector<2xf32> to vector<2xf16>
-// CHECK-NEXT: = llvm.fptrunc {{.*}} {fastmath = #arith.fastmath<none>} : vector<2xf64> to vector<2xf16>
+// CHECK-NEXT: = llvm.fptrunc {{.*}} : vector<2xf64> to vector<2xf16>
%1 = arith.truncf %arg1: vector<2xf64> to vector<2xf16>
-// CHECK-NEXT: = llvm.fptrunc {{.*}} {fastmath = #arith.fastmath<none>} : vector<2xf64> to vector<2xf32>
+// CHECK-NEXT: = llvm.fptrunc {{.*}} : vector<2xf64> to vector<2xf32>
%2 = arith.truncf %arg1: vector<2xf64> to vector<2xf32>
return
}
// CHECK-LABEL: experimental_constrained_fptrunc
func.func @experimental_constrained_fptrunc(%arg0 : f64) {
-// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} tonearest ignore {fastmath = #arith.fastmath<none>} : f64 to f32
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} tonearest ignore : f64 to f32
%0 = arith.truncf %arg0 to_nearest_even : f64 to f32
-// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} downward ignore {fastmath = #arith.fastmath<none>} : f64 to f32
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} downward ignore : f64 to f32
%1 = arith.truncf %arg0 downward : f64 to f32
-// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} upward ignore {fastmath = #arith.fastmath<none>} : f64 to f32
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} upward ignore : f64 to f32
%2 = arith.truncf %arg0 upward : f64 to f32
-// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} towardzero ignore {fastmath = #arith.fastmath<none>} : f64 to f32
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} towardzero ignore : f64 to f32
%3 = arith.truncf %arg0 toward_zero : f64 to f32
-// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} tonearestaway ignore {fastmath = #arith.fastmath<none>} : f64 to f32
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} tonearestaway ignore : f64 to f32
%4 = arith.truncf %arg0 to_nearest_away : f64 to f32
return
}
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 4fe7cfb689be8..e4f95bb0545a2 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -3031,143 +3031,6 @@ func.func @mulsi_extended_i0() -> (i0, i0) {
return %mulsi_extended#0, %mulsi_extended#1 : i0, i0
}
-// CHECK-LABEL: @sequences_fastmath_contract
-// CHECK-SAME: ([[ARG0:%.+]]: bf16)
-// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
-// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
-// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
-// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
-// CHECK: return [[TRUNCF]] : bf16
-func.func @sequences_fastmath_contract(%arg0: bf16) -> bf16 {
- %0 = arith.extf %arg0 fastmath<contract> : bf16 to f32
- %1 = math.absf %0 : f32
- %2 = arith.truncf %1 fastmath<contract> : f32 to bf16
- %3 = arith.extf %2 fastmath<contract> : bf16 to f32
- %4 = math.sin %3 : f32
- %5 = arith.truncf %4 fastmath<contract> : f32 to bf16
- return %5 : bf16
-}
-
-// CHECK-LABEL: @sequences_no_fastmath
-// CHECK-SAME: ([[ARG0:%.+]]: bf16)
-// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
-// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
-// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[ABSF]]
-// CHECK: [[EXTF1:%.+]] = arith.extf [[TRUNCF1]]
-// CHECK: [[SIN:%.+]] = math.sin [[EXTF1]]
-// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
-// CHECK: return [[TRUNCF]] : bf16
-func.func @sequences_no_fastmath(%arg0: bf16) -> bf16 {
- %0 = arith.extf %arg0 : bf16 to f32
- %1 = math.absf %0 : f32
- %2 = arith.truncf %1 : f32 to bf16
- %3 = arith.extf %2 : bf16 to f32
- %4 = math.sin %3 : f32
- %5 = arith.truncf %4 : f32 to bf16
- return %5 : bf16
-}
-
-// CHECK-LABEL: @eliminate_cast_to_f16
-// CHECK: return [[arg0:%.+]] : f32
-func.func @eliminate_cast_to_f16(%arg0: f32) -> f32 {
- %0 = arith.truncf %arg0 fastmath<contract> : f32 to f16
- %1 = arith.extf %0 fastmath<contract> : f16 to f32
- return %1 : f32
-}
-
-// CHECK-LABEL: @eliminate_cast_to_bf16
-// CHECK: return [[arg0:%.+]] : f32
-func.func @eliminate_cast_to_bf16(%arg0: f32) -> f32 {
- %0 = arith.truncf %arg0 fastmath<contract> : f32 to bf16
- %1 = arith.extf %0 fastmath<contract> : bf16 to f32
- return %1 : f32
-}
-
-// CHECK-LABEL: @bf16_sin_vector
-// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>)
-// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
-// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
-// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
-// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
-// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16>
-func.func @bf16_sin_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
- %0 = arith.extf %arg0 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
- %1 = math.absf %0 : vector<32x32x32xf32>
- %2 = arith.truncf %1 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
- %3 = arith.extf %2 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
- %4 = math.sin %3 : vector<32x32x32xf32>
- %5 = arith.truncf %4 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
- return %5 : vector<32x32x32xbf16>
-}
-
-// CHECK-LABEL: @f16_sin_vector
-// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xf16>)
-// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
-// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
-// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
-// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
-// CHECK: return [[TRUNCF]] : vector<32x32x32xf16>
-func.func @f16_sin_vector(%arg0: vector<32x32x32xf16>) -> vector<32x32x32xf16> {
- %0 = arith.extf %arg0 fastmath<contract> : vector<32x32x32xf16> to vector<32x32x32xf32>
- %1 = math.absf %0 : vector<32x32x32xf32>
- %2 = arith.truncf %1 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xf16>
- %3 = arith.extf %2 fastmath<contract> : vector<32x32x32xf16> to vector<32x32x32xf32>
- %4 = math.sin %3 : vector<32x32x32xf32>
- %5 = arith.truncf %4 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xf16>
- return %5 : vector<32x32x32xf16>
-}
-
-// CHECK-LABEL: @bf16_branch_vector
-// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>)
-// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
-// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
-// CHECK-DAG: [[SIN:%.+]] = math.sin [[ABSF]]
-// CHECK-DAG: [[COS:%.+]] = math.cos [[ABSF]]
-// CHECK: [[ADDF:%.+]] = arith.addf [[SIN]], [[COS]]
-// CHECK: [[TRUNCF:%.+]] = arith.truncf [[ADDF]]
-// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16>
-func.func @bf16_branch_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
- %0 = arith.extf %arg0 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
- %1 = math.absf %0 : vector<32x32x32xf32>
- %2 = arith.truncf %1 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
- %3 = arith.extf %2 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
- %4 = math.sin %3 : vector<32x32x32xf32>
- %5 = arith.truncf %4 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
- %6 = arith.extf %5 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
- %7 = math.cos %3 : vector<32x32x32xf32>
- %8 = arith.truncf %7 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
- %9 = arith.extf %8 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
- %10 = arith.addf %6, %9 : vector<32x32x32xf32>
- %11 = arith.truncf %10 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
- return %11 : vector<32x32x32xbf16>
-}
-
-// CHECK-LABEL: @bf16_fma
-// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>, [[ARG1:%.+]]: vector<32x32x32xbf16>, [[ARG2:%.+]]: vector<32x32x32xbf16>)
-// CHECK: [[EXTF0:%.+]] = arith.extf [[ARG0]]
-// CHECK: [[ABSF:%.+]] = math.absf [[EXTF0]]
-// CHECK-DAG: [[SIN:%.+]] = math.sin [[ABSF]]
-// CHECK: [[TRUNCF0:%.+]] = arith.truncf [[SIN]]
-// CHECK-DAG: [[FMA:%.+]] = math.fma [[TRUNCF0]], [[ARG1]], [[ARG2]]
-// CHECK: [[EXTF1:%.+]] = arith.extf [[FMA]]
-// CHECK: [[ADDF:%.+]] = arith.addf [[EXTF1]], [[SIN]]
-// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[ADDF]]
-// CHECK: return [[TRUNCF1]] : vector<32x32x32xbf16>
-func.func @bf16_fma(%arg0: vector<32x32x32xbf16>, %arg1: vector<32x32x32xbf16>, %arg2: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
- %0 = arith.extf %arg0 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
- %1 = math.absf %0 : vector<32x32x32xf32>
- %2 = arith.truncf %1 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
- %3 = arith.extf %2 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
- %4 = math.sin %3 : vector<32x32x32xf32>
- %5 = arith.truncf %4 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
- %6 = arith.extf %5 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
- %7 = math.fma %5, %arg1, %arg2 : vector<32x32x32xbf16>
- %8 = arith.extf %7 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
- %9 = arith.addf %8, %6 : vector<32x32x32xf32>
- %10 = arith.truncf %9 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
- return %10 : vector<32x32x32xbf16>
-}
-
{-#
dialect_resources: {
builtin: {
diff --git a/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir b/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir
index 99790cc45d490..a69ef131d8d47 100644
--- a/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir
+++ b/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir
@@ -4,10 +4,10 @@ func.func @basic_expansion(%x: bf16) -> bf16 {
// CHECK-LABEL: @basic_expansion
// CHECK-SAME: [[X:%.+]]: bf16
// CHECK-DAG: [[C:%.+]] = arith.constant {{.*}} : bf16
-// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] fastmath<contract> : bf16 to f32
-// CHECK-DAG: [[C_EXP:%.+]] = arith.extf [[C]] fastmath<contract> : bf16 to f32
+// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] : bf16 to f32
+// CHECK-DAG: [[C_EXP:%.+]] = arith.extf [[C]] : bf16 to f32
// CHECK: [[Y_EXP:%.+]] = arith.addf [[X_EXP]], [[C_EXP]] : f32
-// CHECK: [[Y:%.+]] = arith.truncf [[Y_EXP]] fastmath<contract> : f32 to bf16
+// CHECK: [[Y:%.+]] = arith.truncf [[Y_EXP]] : f32 to bf16
// CHECK: return [[Y]]
%c = arith.constant 1.0 : bf16
%y = arith.addf %x, %c : bf16
@@ -19,15 +19,15 @@ func.func @basic_expansion(%x: bf16) -> bf16 {
func.func @chained(%x: bf16, %y: bf16, %z: bf16) -> i1 {
// CHECK-LABEL: @chained
// CHECK-SAME: [[X:%.+]]: bf16, [[Y:%.+]]: bf16, [[Z:%.+]]: bf16
-// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] fastmath<contract> : bf16 to f32
-// CHECK-DAG: [[Y_EXP:%.+]] = arith.extf [[Y]] fastmath<contract> : bf16 to f32
-// CHECK-DAG: [[Z_EXP:%.+]] = arith.extf [[Z]] fastmath<contract> : bf16 to f32
+// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] : bf16 to f32
+// CHECK-DAG: [[Y_EXP:%.+]] = arith.extf [[Y]] : bf16 to f32
+// CHECK-DAG: [[Z_EXP:%.+]] = arith.extf [[Z]] : bf16 to f32
// CHECK: ...
[truncated]
|
To fix post-merge build error, suggest to set default attr to optioanl attr.
|
Thanks for fixing the build! One minor nit for the future: please include the revert reason in the PR description (just a one line to say "buildbots are broken" will do). |
Re the re-land ... maybe |
Thanks for your help! The new PR is here. |
Reverts #93443