-
Notifications
You must be signed in to change notification settings - Fork 12.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix complex log1p accuracy with large abs values. #88260
Conversation
This ports openxla/xla#10503 by @pearu. In addition to the filecheck test here, the accuracy was tested with XLA's complex_unary_op_test and its MLIR emitters.
@llvm/pr-subscribers-mlir Author: Johannes Reifferscheid (jreiffers) Changes: This ports openxla/xla#10503 by @pearu. The new implementation matches mpmath's results for most inputs, see caveats in the linked pull request. In addition to the filecheck test here, the accuracy was tested with XLA's complex_unary_op_test and its MLIR emitters. Full diff: https://github.com/llvm/llvm-project/pull/88260.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 9c3c4d96a301ef..0aa1de5fa5d9a1 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -570,37 +570,39 @@ struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
ConversionPatternRewriter &rewriter) const override {
auto type = cast<ComplexType>(adaptor.getComplex().getType());
auto elementType = cast<FloatType>(type.getElementType());
- arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
+ arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
- Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
- Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
+ Value real = b.create<complex::ReOp>(adaptor.getComplex());
+ Value imag = b.create<complex::ImOp>(adaptor.getComplex());
Value half = b.create<arith::ConstantOp>(elementType,
b.getFloatAttr(elementType, 0.5));
Value one = b.create<arith::ConstantOp>(elementType,
b.getFloatAttr(elementType, 1));
- Value two = b.create<arith::ConstantOp>(elementType,
- b.getFloatAttr(elementType, 2));
-
- // log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1)
- // log((a+1)+bi) = .5*log(a*a + 2*a + 1 + b*b) + i*atan2(b, a+1)
- // log((a+1)+bi) = .5*log1p(a*a + 2*a + b*b) + i*atan2(b, a+1)
- Value sumSq = b.create<arith::MulFOp>(real, real, fmf.getValue());
- sumSq = b.create<arith::AddFOp>(
- sumSq, b.create<arith::MulFOp>(real, two, fmf.getValue()),
- fmf.getValue());
- sumSq = b.create<arith::AddFOp>(
- sumSq, b.create<arith::MulFOp>(imag, imag, fmf.getValue()),
- fmf.getValue());
- Value logSumSq =
- b.create<math::Log1pOp>(elementType, sumSq, fmf.getValue());
- Value resultReal = b.create<arith::MulFOp>(logSumSq, half, fmf.getValue());
-
- Value realPlusOne = b.create<arith::AddFOp>(real, one, fmf.getValue());
-
- Value resultImag =
- b.create<math::Atan2Op>(elementType, imag, realPlusOne, fmf.getValue());
+ Value realPlusOne = b.create<arith::AddFOp>(real, one, fmf);
+ Value absRealPlusOne = b.create<math::AbsFOp>(realPlusOne, fmf);
+ Value absImag = b.create<math::AbsFOp>(imag, fmf);
+
+ Value maxAbs = b.create<arith::MaximumFOp>(absRealPlusOne, absImag, fmf);
+ Value minAbs = b.create<arith::MinimumFOp>(absRealPlusOne, absImag, fmf);
+
+ Value maxAbsOfRealPlusOneAndImagMinusOne = b.create<arith::SelectOp>(
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, realPlusOne, absImag,
+ fmf),
+ real, b.create<arith::SubFOp>(maxAbs, one, fmf));
+ Value minMaxRatio = b.create<arith::DivFOp>(minAbs, maxAbs, fmf);
+ Value logOfMaxAbsOfRealPlusOneAndImag =
+ b.create<math::Log1pOp>(maxAbsOfRealPlusOneAndImagMinusOne, fmf);
+ Value logOfSqrtPart = b.create<math::Log1pOp>(
+ b.create<arith::MulFOp>(minMaxRatio, minMaxRatio, fmf), fmf);
+ Value r = b.create<arith::AddFOp>(
+ b.create<arith::MulFOp>(half, logOfSqrtPart, fmf),
+ logOfMaxAbsOfRealPlusOneAndImag, fmf);
+ Value resultReal = b.create<arith::SelectOp>(
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, r, r, fmf), minAbs,
+ r);
+ Value resultImag = b.create<math::Atan2Op>(imag, realPlusOne, fmf);
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
resultImag);
return success();
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index f5d9499eadda48..43918904a09f40 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -300,15 +300,22 @@ func.func @complex_log1p(%arg: complex<f32>) -> complex<f32> {
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
// CHECK: %[[ONE_HALF:.*]] = arith.constant 5.000000e-01 : f32
// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK: %[[TWO:.*]] = arith.constant 2.000000e+00 : f32
-// CHECK: %[[SQ_SUM_0:.*]] = arith.mulf %[[REAL]], %[[REAL]] : f32
-// CHECK: %[[TWO_REAL:.*]] = arith.mulf %[[REAL]], %[[TWO]] : f32
-// CHECK: %[[SQ_SUM_1:.*]] = arith.addf %[[SQ_SUM_0]], %[[TWO_REAL]] : f32
-// CHECK: %[[SQ_IMAG:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] : f32
-// CHECK: %[[SQ_SUM_2:.*]] = arith.addf %[[SQ_SUM_1]], %[[SQ_IMAG]] : f32
-// CHECK: %[[LOG_SQ_SUM:.*]] = math.log1p %[[SQ_SUM_2]] : f32
-// CHECK: %[[RESULT_REAL:.*]] = arith.mulf %[[LOG_SQ_SUM]], %[[ONE_HALF]] : f32
// CHECK: %[[REAL_PLUS_ONE:.*]] = arith.addf %[[REAL]], %[[ONE]] : f32
+// CHECK: %[[ABS_REAL_PLUS_ONE:.*]] = math.absf %[[REAL_PLUS_ONE]] : f32
+// CHECK: %[[ABS_IMAG:.*]] = math.absf %[[IMAG]] : f32
+// CHECK: %[[MAX:.*]] = arith.maximumf %[[ABS_REAL_PLUS_ONE]], %[[ABS_IMAG]] : f32
+// CHECK: %[[MIN:.*]] = arith.minimumf %[[ABS_REAL_PLUS_ONE]], %[[ABS_IMAG]] : f32
+// CHECK: %[[CMPF:.*]] = arith.cmpf ogt, %[[REAL_PLUS_ONE]], %[[ABS_IMAG]] : f32
+// CHECK: %[[MAX_MINUS_ONE:.*]] = arith.subf %[[MAX]], %cst_0 : f32
+// CHECK: %[[SELECT:.*]] = arith.select %[[CMPF]], %0, %[[MAX_MINUS_ONE]] : f32
+// CHECK: %[[MIN_MAX_RATIO:.*]] = arith.divf %[[MIN]], %[[MAX]] : f32
+// CHECK: %[[LOG_1:.*]] = math.log1p %[[SELECT]] : f32
+// CHECK: %[[RATIO_SQ:.*]] = arith.mulf %[[MIN_MAX_RATIO]], %[[MIN_MAX_RATIO]] : f32
+// CHECK: %[[LOG_SQ:.*]] = math.log1p %[[RATIO_SQ]] : f32
+// CHECK: %[[HALF_LOG_SQ:.*]] = arith.mulf %cst, %[[LOG_SQ]] : f32
+// CHECK: %[[R:.*]] = arith.addf %[[HALF_LOG_SQ]], %[[LOG_1]] : f32
+// CHECK: %[[ISNAN:.*]] = arith.cmpf uno, %[[R]], %[[R]] : f32
+// CHECK: %[[RESULT_REAL:.*]] = arith.select %[[ISNAN]], %[[MIN]], %[[R]] : f32
// CHECK: %[[RESULT_IMAG:.*]] = math.atan2 %[[IMAG]], %[[REAL_PLUS_ONE]] : f32
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
@@ -963,15 +970,22 @@ func.func @complex_log1p_with_fmf(%arg: complex<f32>) -> complex<f32> {
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
// CHECK: %[[ONE_HALF:.*]] = arith.constant 5.000000e-01 : f32
// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK: %[[TWO:.*]] = arith.constant 2.000000e+00 : f32
-// CHECK: %[[SQ_SUM_0:.*]] = arith.mulf %[[REAL]], %[[REAL]] fastmath<nnan,contract> : f32
-// CHECK: %[[TWO_REAL:.*]] = arith.mulf %[[REAL]], %[[TWO]] fastmath<nnan,contract> : f32
-// CHECK: %[[SQ_SUM_1:.*]] = arith.addf %[[SQ_SUM_0]], %[[TWO_REAL]] fastmath<nnan,contract> : f32
-// CHECK: %[[SQ_IMAG:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[SQ_SUM_2:.*]] = arith.addf %[[SQ_SUM_1]], %[[SQ_IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[LOG_SQ_SUM:.*]] = math.log1p %[[SQ_SUM_2]] fastmath<nnan,contract> : f32
-// CHECK: %[[RESULT_REAL:.*]] = arith.mulf %[[LOG_SQ_SUM]], %[[ONE_HALF]] fastmath<nnan,contract> : f32
-// CHECK: %[[REAL_PLUS_ONE:.*]] = arith.addf %[[REAL]], %[[ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_PLUS_ONE:.*]] = arith.addf %[[REAL]], %[[ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABS_REAL_PLUS_ONE:.*]] = math.absf %[[REAL_PLUS_ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[ABS_IMAG:.*]] = math.absf %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[MAX:.*]] = arith.maximumf %[[ABS_REAL_PLUS_ONE]], %[[ABS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[MIN:.*]] = arith.minimumf %[[ABS_REAL_PLUS_ONE]], %[[ABS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[CMPF:.*]] = arith.cmpf ogt, %[[REAL_PLUS_ONE]], %[[ABS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK: %[[MAX_MINUS_ONE:.*]] = arith.subf %[[MAX]], %cst_0 fastmath<nnan,contract> : f32
+// CHECK: %[[SELECT:.*]] = arith.select %[[CMPF]], %0, %[[MAX_MINUS_ONE]] : f32
+// CHECK: %[[MIN_MAX_RATIO:.*]] = arith.divf %[[MIN]], %[[MAX]] fastmath<nnan,contract> : f32
+// CHECK: %[[LOG_1:.*]] = math.log1p %[[SELECT]] fastmath<nnan,contract> : f32
+// CHECK: %[[RATIO_SQ:.*]] = arith.mulf %[[MIN_MAX_RATIO]], %[[MIN_MAX_RATIO]] fastmath<nnan,contract> : f32
+// CHECK: %[[LOG_SQ:.*]] = math.log1p %[[RATIO_SQ]] fastmath<nnan,contract> : f32
+// CHECK: %[[HALF_LOG_SQ:.*]] = arith.mulf %cst, %[[LOG_SQ]] fastmath<nnan,contract> : f32
+// CHECK: %[[R:.*]] = arith.addf %[[HALF_LOG_SQ]], %[[LOG_1]] fastmath<nnan,contract> : f32
+// CHECK: %[[ISNAN:.*]] = arith.cmpf uno, %[[R]], %[[R]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT_REAL:.*]] = arith.select %[[ISNAN]], %[[MIN]], %[[R]] : f32
// CHECK: %[[RESULT_IMAG:.*]] = math.atan2 %[[IMAG]], %[[REAL_PLUS_ONE]] fastmath<nnan,contract> : f32
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
|
This reverts commit 49ef12a.
Reverts #88260. The test fails on the GCC7 buildbot.
Hey @jreiffers: it seems the test fails on the gcc7 bot: https://lab.llvm.org/buildbot/#/builders/264/builds/9257/steps/6/logs/FAIL__MLIR__convert-to-standard_mlir ; I reverted in the meantime. Feel free to reland when you have a fix! |
Ugh. Sorry about that and thanks. |
This ports openxla/xla#10503 by @pearu. In addition to the filecheck test here, the accuracy was tested with XLA's complex_unary_op_test and its MLIR emitters. This is a fixed version of llvm#88260. The previous version relied on implementation-specific behavior in the order of evaluation of maxAbsOfRealPlusOneAndImagMinusOne's operands.
This ports openxla/xla#10503 by @pearu. In addition to the filecheck test here, the accuracy was tested with XLA's complex_unary_op_test and its MLIR emitters. This is a fixed version of #88260. The previous version relied on implementation-specific behavior in the order of evaluation of maxAbsOfRealPlusOneAndImagMinusOne's operands.
This ports openxla/xla#10503 by @pearu. The new implementation matches mpmath's results for most inputs, see caveats in the linked pull request. In addition to the filecheck test here, the accuracy was tested with XLA's complex_unary_op_test and its MLIR emitters.