Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "[MLIR][Arith] add fastMathAttr on arith::extf and arith::truncf" #95344

Merged
merged 1 commit into from
Jun 13, 2024

Conversation

crazydemo
Copy link
Contributor

Reverts #93443

@llvmbot
Copy link
Member

llvmbot commented Jun 13, 2024

@llvm/pr-subscribers-mlir-arith
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-math

Author: Ivy Zhang (crazydemo)

Changes

Reverts llvm/llvm-project#93443


Patch is 22.49 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/95344.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Arith/IR/ArithOps.td (+2-14)
  • (modified) mlir/lib/Dialect/Arith/IR/ArithOps.cpp (-14)
  • (modified) mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp (+3-8)
  • (modified) mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp (+3-8)
  • (modified) mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir (+17-17)
  • (modified) mlir/test/Dialect/Arith/canonicalize.mlir (-137)
  • (modified) mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir (+15-15)
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index c4471f9bc5af2..06fbdb7f2c4cb 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1199,7 +1199,7 @@ def Arith_ExtSIOp : Arith_IToICastOp<"extsi"> {
 // ExtFOp
 //===----------------------------------------------------------------------===//
 
-def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods<ArithFastMathInterface>]> {
+def Arith_ExtFOp : Arith_FToFCastOp<"extf"> {
   let summary = "cast from floating-point to wider floating-point";
   let description = [{
     Cast a floating-point value to a larger floating-point-typed value.
@@ -1208,13 +1208,6 @@ def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods<ArithFast
   }];
   let hasVerifier = 1;
   let hasFolder = 1;
-
-  let arguments = (ins FloatLike:$in, DefaultValuedAttr<
-                         Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath);
-  let results = (outs FloatLike:$out);
-
-  let assemblyFormat = [{ $in (`fastmath` `` $fastmath^)?
-                          attr-dict `:` type($in) `to` type($out) }];
 }
 
 //===----------------------------------------------------------------------===//
@@ -1253,11 +1246,8 @@ def Arith_TruncFOp :
     Arith_Op<"truncf",
       [Pure, SameOperandsAndResultShape, SameInputOutputTensorDims,
        DeclareOpInterfaceMethods<ArithRoundingModeInterface>,
-       DeclareOpInterfaceMethods<ArithFastMathInterface>,
        DeclareOpInterfaceMethods<CastOpInterface>]>,
     Arguments<(ins FloatLike:$in,
-                   DefaultValuedAttr<
-                      Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath,
                    OptionalAttr<Arith_RoundingModeAttr>:$roundingmode)>,
     Results<(outs FloatLike:$out)> {
   let summary = "cast from floating-point to narrower floating-point";
@@ -1277,9 +1267,7 @@ def Arith_TruncFOp :
 
   let hasFolder = 1;
   let hasVerifier = 1;
-  let assemblyFormat = [{ $in ($roundingmode^)?
-                          (`fastmath` `` $fastmath^)?
-                          attr-dict `:` type($in) `to` type($out) }];
+  let assemblyFormat = "$in ($roundingmode^)? attr-dict `:` type($in) `to` type($out)";
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 291f6e5424ba5..2f6647a2a27b1 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1390,20 +1390,6 @@ LogicalResult arith::ExtSIOp::verify() {
 /// Fold extension of float constants when there is no information loss due the
 /// difference in fp semantics.
 OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
-  if (auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
-    if (truncFOp.getOperand().getType() == getType()) {
-      arith::FastMathFlags truncFMF = truncFOp.getFastmath();
-      bool isTruncContract =
-          bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract);
-      arith::FastMathFlags extFMF = getFastmath();
-      bool isExtContract =
-          bitEnumContainsAll(extFMF, arith::FastMathFlags::contract);
-      if (isTruncContract && isExtContract) {
-        return truncFOp.getOperand();
-      }
-    }
-  }
-
   auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
   const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
   return constFoldCastOp<FloatAttr, FloatAttr>(
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index 8e1cb474feee7..4a50da3513f99 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -94,11 +94,8 @@ void EmulateFloatPattern::rewrite(Operation *op, ArrayRef<Value> operands,
   SmallVector<Value> newResults(expandedOp->getResults());
   for (auto [res, oldType, newType] : llvm::zip_equal(
            MutableArrayRef{newResults}, op->getResultTypes(), resultTypes)) {
-    if (oldType != newType) {
-      auto truncFOp = rewriter.create<arith::TruncFOp>(loc, oldType, res);
-      truncFOp.setFastmath(arith::FastMathFlags::contract);
-      res = truncFOp.getResult();
-    }
+    if (oldType != newType)
+      res = rewriter.create<arith::TruncFOp>(loc, oldType, res);
   }
   rewriter.replaceOp(op, newResults);
 }
@@ -117,9 +114,7 @@ void mlir::arith::populateEmulateUnsupportedFloatsConversions(
   });
   converter.addTargetMaterialization(
       [](OpBuilder &b, Type target, ValueRange input, Location loc) {
-        auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
-        extFOp.setFastmath(arith::FastMathFlags::contract);
-        return extFOp;
+        return b.create<arith::ExtFOp>(loc, target, input);
       });
 }
 
diff --git a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
index 3d99f3033cf56..5998133b7eab8 100644
--- a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
@@ -57,9 +57,7 @@ void mlir::math::populateLegalizeToF32TypeConverter(
   });
   typeConverter.addTargetMaterialization(
       [](OpBuilder &b, Type target, ValueRange input, Location loc) {
-        auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
-        extFOp.setFastmath(arith::FastMathFlags::contract);
-        return extFOp;
+        return b.create<arith::ExtFOp>(loc, target, input);
       });
 }
 
@@ -86,11 +84,8 @@ LogicalResult LegalizeToF32RewritePattern::matchAndRewrite(
   SmallVector<Value> results = (*legalized)->getResults();
   for (auto [result, newType, origType] : llvm::zip_equal(
            results, (*legalized)->getResultTypes(), op->getResultTypes())) {
-    if (newType != origType) {
-      auto truncFOp = rewriter.create<arith::TruncFOp>(loc, origType, result);
-      truncFOp.setFastmath(arith::FastMathFlags::contract);
-      result = truncFOp.getResult();
-    }
+    if (newType != origType)
+      result = rewriter.create<arith::TruncFOp>(loc, origType, result);
   }
   rewriter.replaceOp(op, results);
   return success();
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index cacdd801871fb..56ae930e6d627 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -162,11 +162,11 @@ func.func @uitofp(%arg0 : i32, %arg1 : i64) {
 // Checking conversion of integer types to floating point.
 // CHECK-LABEL: @fpext
 func.func @fpext(%arg0 : f16, %arg1 : f32) {
-// CHECK-NEXT: = llvm.fpext {{.*}} {fastmath = #arith.fastmath<none>} : f16 to f32
+// CHECK-NEXT: = llvm.fpext {{.*}} : f16 to f32
   %0 = arith.extf %arg0: f16 to f32
-// CHECK-NEXT: = llvm.fpext {{.*}} {fastmath = #arith.fastmath<none>} : f16 to f64
+// CHECK-NEXT: = llvm.fpext {{.*}} : f16 to f64
   %1 = arith.extf %arg0: f16 to f64
-// CHECK-NEXT: = llvm.fpext {{.*}} {fastmath = #arith.fastmath<none>} : f32 to f64
+// CHECK-NEXT: = llvm.fpext {{.*}} : f32 to f64
   %2 = arith.extf %arg1: f32 to f64
   return
 }
@@ -174,11 +174,11 @@ func.func @fpext(%arg0 : f16, %arg1 : f32) {
 // Checking conversion of integer types to floating point.
 // CHECK-LABEL: @fpext
 func.func @fpext_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>) {
-// CHECK-NEXT: = llvm.fpext {{.*}} {fastmath = #arith.fastmath<none>} : vector<2xf16> to vector<2xf32>
+// CHECK-NEXT: = llvm.fpext {{.*}} : vector<2xf16> to vector<2xf32>
   %0 = arith.extf %arg0: vector<2xf16> to vector<2xf32>
-// CHECK-NEXT: = llvm.fpext {{.*}} {fastmath = #arith.fastmath<none>} : vector<2xf16> to vector<2xf64>
+// CHECK-NEXT: = llvm.fpext {{.*}} : vector<2xf16> to vector<2xf64>
   %1 = arith.extf %arg0: vector<2xf16> to vector<2xf64>
-// CHECK-NEXT: = llvm.fpext {{.*}} {fastmath = #arith.fastmath<none>} : vector<2xf32> to vector<2xf64>
+// CHECK-NEXT: = llvm.fpext {{.*}} : vector<2xf32> to vector<2xf64>
   %2 = arith.extf %arg1: vector<2xf32> to vector<2xf64>
   return
 }
@@ -268,11 +268,11 @@ func.func @uitofp_vector(%arg0 : vector<2xi16>, %arg1 : vector<2xi32>, %arg2 : v
 // Checking conversion of integer types to floating point.
 // CHECK-LABEL: @fptrunc
 func.func @fptrunc(%arg0 : f32, %arg1 : f64) {
-// CHECK-NEXT: = llvm.fptrunc {{.*}} {fastmath = #arith.fastmath<none>} : f32 to f16
+// CHECK-NEXT: = llvm.fptrunc {{.*}} : f32 to f16
   %0 = arith.truncf %arg0: f32 to f16
-// CHECK-NEXT: = llvm.fptrunc {{.*}} {fastmath = #arith.fastmath<none>} : f64 to f16
+// CHECK-NEXT: = llvm.fptrunc {{.*}} : f64 to f16
   %1 = arith.truncf %arg1: f64 to f16
-// CHECK-NEXT: = llvm.fptrunc {{.*}} {fastmath = #arith.fastmath<none>} : f64 to f32
+// CHECK-NEXT: = llvm.fptrunc {{.*}} : f64 to f32
   %2 = arith.truncf %arg1: f64 to f32
   return
 }
@@ -280,26 +280,26 @@ func.func @fptrunc(%arg0 : f32, %arg1 : f64) {
 // Checking conversion of integer types to floating point.
 // CHECK-LABEL: @fptrunc
 func.func @fptrunc_vector(%arg0 : vector<2xf32>, %arg1 : vector<2xf64>) {
-// CHECK-NEXT: = llvm.fptrunc {{.*}} {fastmath = #arith.fastmath<none>} : vector<2xf32> to vector<2xf16>
+// CHECK-NEXT: = llvm.fptrunc {{.*}} : vector<2xf32> to vector<2xf16>
   %0 = arith.truncf %arg0: vector<2xf32> to vector<2xf16>
-// CHECK-NEXT: = llvm.fptrunc {{.*}} {fastmath = #arith.fastmath<none>} : vector<2xf64> to vector<2xf16>
+// CHECK-NEXT: = llvm.fptrunc {{.*}} : vector<2xf64> to vector<2xf16>
   %1 = arith.truncf %arg1: vector<2xf64> to vector<2xf16>
-// CHECK-NEXT: = llvm.fptrunc {{.*}} {fastmath = #arith.fastmath<none>} : vector<2xf64> to vector<2xf32>
+// CHECK-NEXT: = llvm.fptrunc {{.*}} : vector<2xf64> to vector<2xf32>
   %2 = arith.truncf %arg1: vector<2xf64> to vector<2xf32>
   return
 }
 
 // CHECK-LABEL: experimental_constrained_fptrunc
 func.func @experimental_constrained_fptrunc(%arg0 : f64) {
-// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} tonearest ignore {fastmath = #arith.fastmath<none>} : f64 to f32
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} tonearest ignore : f64 to f32
   %0 = arith.truncf %arg0 to_nearest_even : f64 to f32
-// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} downward ignore {fastmath = #arith.fastmath<none>} : f64 to f32
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} downward ignore : f64 to f32
   %1 = arith.truncf %arg0 downward : f64 to f32
-// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} upward ignore {fastmath = #arith.fastmath<none>} : f64 to f32
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} upward ignore : f64 to f32
   %2 = arith.truncf %arg0 upward : f64 to f32
-// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} towardzero ignore {fastmath = #arith.fastmath<none>} : f64 to f32
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} towardzero ignore : f64 to f32
   %3 = arith.truncf %arg0 toward_zero : f64 to f32
-// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} tonearestaway ignore {fastmath = #arith.fastmath<none>} : f64 to f32
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} tonearestaway ignore : f64 to f32
   %4 = arith.truncf %arg0 to_nearest_away : f64 to f32
   return
 }
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 4fe7cfb689be8..e4f95bb0545a2 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -3031,143 +3031,6 @@ func.func @mulsi_extended_i0() -> (i0, i0) {
   return %mulsi_extended#0, %mulsi_extended#1 : i0, i0
 }
 
-// CHECK-LABEL: @sequences_fastmath_contract
-// CHECK-SAME: ([[ARG0:%.+]]: bf16)
-// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
-// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
-// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
-// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
-// CHECK: return [[TRUNCF]] : bf16
-func.func @sequences_fastmath_contract(%arg0: bf16) -> bf16 {
-  %0 = arith.extf %arg0 fastmath<contract> : bf16 to f32
-  %1 = math.absf %0 : f32
-  %2 = arith.truncf %1 fastmath<contract> : f32 to bf16
-  %3 = arith.extf %2 fastmath<contract> : bf16 to f32
-  %4 = math.sin %3 : f32
-  %5 = arith.truncf %4 fastmath<contract> : f32 to bf16
-  return %5 : bf16
-}
-
-// CHECK-LABEL: @sequences_no_fastmath
-// CHECK-SAME: ([[ARG0:%.+]]: bf16)
-// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
-// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
-// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[ABSF]]
-// CHECK: [[EXTF1:%.+]] = arith.extf [[TRUNCF1]]
-// CHECK: [[SIN:%.+]] = math.sin [[EXTF1]]
-// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
-// CHECK: return [[TRUNCF]] : bf16
-func.func @sequences_no_fastmath(%arg0: bf16) -> bf16 {
-  %0 = arith.extf %arg0 : bf16 to f32
-  %1 = math.absf %0 : f32
-  %2 = arith.truncf %1 : f32 to bf16
-  %3 = arith.extf %2 : bf16 to f32
-  %4 = math.sin %3 : f32
-  %5 = arith.truncf %4 : f32 to bf16
-  return %5 : bf16
-}
-
-// CHECK-LABEL: @eliminate_cast_to_f16
-// CHECK: return [[arg0:%.+]] : f32
-func.func @eliminate_cast_to_f16(%arg0: f32) -> f32 {
-  %0 = arith.truncf %arg0 fastmath<contract> : f32 to f16
-  %1 = arith.extf %0 fastmath<contract> : f16 to f32
-  return %1 : f32
-}
-
-// CHECK-LABEL: @eliminate_cast_to_bf16
-// CHECK: return [[arg0:%.+]] : f32
-func.func @eliminate_cast_to_bf16(%arg0: f32) -> f32 {
-  %0 = arith.truncf %arg0 fastmath<contract> : f32 to bf16
-  %1 = arith.extf %0 fastmath<contract> : bf16 to f32
-  return %1 : f32
-}
-
-// CHECK-LABEL: @bf16_sin_vector
-// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>)
-// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
-// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
-// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
-// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
-// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16>
-func.func @bf16_sin_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
-  %0 = arith.extf %arg0 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
-  %1 = math.absf %0 : vector<32x32x32xf32>
-  %2 = arith.truncf %1 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
-  %3 = arith.extf %2 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
-  %4 = math.sin %3 : vector<32x32x32xf32>
-  %5 = arith.truncf %4 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
-  return %5 : vector<32x32x32xbf16>
-}
-
-// CHECK-LABEL: @f16_sin_vector
-// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xf16>)
-// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
-// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
-// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
-// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
-// CHECK: return [[TRUNCF]] : vector<32x32x32xf16>
-func.func @f16_sin_vector(%arg0: vector<32x32x32xf16>) -> vector<32x32x32xf16> {
-  %0 = arith.extf %arg0 fastmath<contract> : vector<32x32x32xf16> to vector<32x32x32xf32>
-  %1 = math.absf %0 : vector<32x32x32xf32>
-  %2 = arith.truncf %1 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xf16>
-  %3 = arith.extf %2 fastmath<contract> : vector<32x32x32xf16> to vector<32x32x32xf32>
-  %4 = math.sin %3 : vector<32x32x32xf32>
-  %5 = arith.truncf %4 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xf16>
-  return %5 : vector<32x32x32xf16>
-}
-
-// CHECK-LABEL: @bf16_branch_vector
-// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>)
-// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
-// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
-// CHECK-DAG: [[SIN:%.+]] = math.sin [[ABSF]]
-// CHECK-DAG: [[COS:%.+]] = math.cos [[ABSF]]
-// CHECK: [[ADDF:%.+]] = arith.addf [[SIN]], [[COS]]
-// CHECK: [[TRUNCF:%.+]] = arith.truncf [[ADDF]]
-// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16>
-func.func @bf16_branch_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
-  %0 = arith.extf %arg0 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
-  %1 = math.absf %0 : vector<32x32x32xf32>
-  %2 = arith.truncf %1 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
-  %3 = arith.extf %2 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
-  %4 = math.sin %3 : vector<32x32x32xf32>
-  %5 = arith.truncf %4 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
-  %6 = arith.extf %5 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
-  %7 = math.cos %3 : vector<32x32x32xf32>
-  %8 = arith.truncf %7 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
-  %9 = arith.extf %8 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
-  %10 = arith.addf %6, %9 : vector<32x32x32xf32>
-  %11 = arith.truncf %10 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
-  return %11 : vector<32x32x32xbf16>
-}
-
-// CHECK-LABEL: @bf16_fma
-// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>, [[ARG1:%.+]]: vector<32x32x32xbf16>, [[ARG2:%.+]]: vector<32x32x32xbf16>)
-// CHECK: [[EXTF0:%.+]] = arith.extf [[ARG0]]
-// CHECK: [[ABSF:%.+]] = math.absf [[EXTF0]]
-// CHECK-DAG: [[SIN:%.+]] = math.sin [[ABSF]]
-// CHECK: [[TRUNCF0:%.+]] = arith.truncf [[SIN]]
-// CHECK-DAG: [[FMA:%.+]] = math.fma [[TRUNCF0]], [[ARG1]], [[ARG2]]
-// CHECK: [[EXTF1:%.+]] = arith.extf [[FMA]]
-// CHECK: [[ADDF:%.+]] = arith.addf [[EXTF1]], [[SIN]]
-// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[ADDF]]
-// CHECK: return [[TRUNCF1]] : vector<32x32x32xbf16>
-func.func @bf16_fma(%arg0: vector<32x32x32xbf16>, %arg1: vector<32x32x32xbf16>, %arg2: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
-  %0 = arith.extf %arg0 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
-  %1 = math.absf %0 : vector<32x32x32xf32>
-  %2 = arith.truncf %1 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
-  %3 = arith.extf %2 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
-  %4 = math.sin %3 : vector<32x32x32xf32>
-  %5 = arith.truncf %4 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
-  %6 = arith.extf %5 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
-  %7 = math.fma %5, %arg1, %arg2 : vector<32x32x32xbf16>
-  %8 = arith.extf %7 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
-  %9 = arith.addf %8, %6 : vector<32x32x32xf32>
-  %10 = arith.truncf %9 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
-  return %10 : vector<32x32x32xbf16>
-}
-
 {-#
   dialect_resources: {
     builtin: {
diff --git a/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir b/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir
index 99790cc45d490..a69ef131d8d47 100644
--- a/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir
+++ b/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir
@@ -4,10 +4,10 @@ func.func @basic_expansion(%x: bf16) -> bf16 {
 // CHECK-LABEL: @basic_expansion
 // CHECK-SAME: [[X:%.+]]: bf16
 // CHECK-DAG: [[C:%.+]] = arith.constant {{.*}} : bf16
-// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] fastmath<contract> : bf16 to f32
-// CHECK-DAG: [[C_EXP:%.+]] = arith.extf [[C]] fastmath<contract> : bf16 to f32
+// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] : bf16 to f32
+// CHECK-DAG: [[C_EXP:%.+]] = arith.extf [[C]] : bf16 to f32
 // CHECK: [[Y_EXP:%.+]] = arith.addf [[X_EXP]], [[C_EXP]] : f32
-// CHECK: [[Y:%.+]] = arith.truncf [[Y_EXP]] fastmath<contract> : f32 to bf16
+// CHECK: [[Y:%.+]] = arith.truncf [[Y_EXP]] : f32 to bf16
 // CHECK: return [[Y]]
   %c = arith.constant 1.0 : bf16
   %y = arith.addf %x, %c : bf16
@@ -19,15 +19,15 @@ func.func @basic_expansion(%x: bf16) -> bf16 {
 func.func @chained(%x: bf16, %y: bf16, %z: bf16) -> i1 {
 // CHECK-LABEL: @chained
 // CHECK-SAME: [[X:%.+]]: bf16, [[Y:%.+]]: bf16, [[Z:%.+]]: bf16
-// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] fastmath<contract> : bf16 to f32
-// CHECK-DAG: [[Y_EXP:%.+]] = arith.extf [[Y]] fastmath<contract> : bf16 to f32
-// CHECK-DAG: [[Z_EXP:%.+]] = arith.extf [[Z]] fastmath<contract> : bf16 to f32
+// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] : bf16 to f32
+// CHECK-DAG: [[Y_EXP:%.+]] = arith.extf [[Y]] : bf16 to f32
+// CHECK-DAG: [[Z_EXP:%.+]] = arith.extf [[Z]] : bf16 to f32
 // CHECK: ...
[truncated]

@crazydemo crazydemo requested review from kuhar and krzysz00 June 13, 2024 03:21
@crazydemo
Copy link
Contributor Author

To fix post-merge build error, suggest to set default attr to optioanl attr.

/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_cast.mlir --sparsifier="enable-runtime-library=true" | /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-cpu-runner -e main -entry-point-result=void -shared-libs=/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/lib/libmlir_c_runner_utils.so,/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/lib/libmlir_runner_utils.so | /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/FileCheck /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_cast.mlir
executed command: /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_cast.mlir --sparsifier=enable-runtime-library=true
executed command: /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-cpu-runner -e main -entry-point-result=void -shared-libs=/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/lib/libmlir_c_runner_utils.so,/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/lib/libmlir_runner_utils.so
.---command stderr------------
| loc("<stdin>":488:45): error: #"arith"<"fastmath<none>"> : 'none' attribute created with unregistered dialect. If this is intended, please call allowUnregisteredDialects() on the MLIRContext, or use -allow-unregistered-dialect with the MLIR opt tool used
| could not parse the input IR
`-----------------------------
error: command failed with exit status: 1
executed command: /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/FileCheck /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_cast.mlir
.---command stderr------------
| FileCheck error: '<stdin>' is empty.
| FileCheck command line:  /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/FileCheck /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_cast.mlir
`-----------------------------
error: command failed with exit status: 2

@crazydemo crazydemo merged commit f941908 into main Jun 13, 2024
8 of 10 checks passed
@crazydemo crazydemo deleted the revert-93443-zhangyan/independent_canonicalize branch June 13, 2024 03:23
@joker-eph
Copy link
Collaborator

Thanks for fixing the build!

One minor nit for the future: please include the revert reason in the PR description (just a one line to say "buildbots are broken" will do).

@krzysz00
Copy link
Contributor

Re the re-land ... maybe DefaultValuedOptionalAttr will be helpful?

@crazydemo
Copy link
Contributor Author

Re the re-land ... maybe DefaultValuedOptionalAttr will be helpful?

Thanks for your help! The new PR is here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants