Fix complex log1p accuracy with large abs values. #88260

Merged
merged 1 commit into from
Apr 10, 2024

jreiffers
Member

@jreiffers jreiffers commented Apr 10, 2024

This ports openxla/xla#10503 by @pearu. The new implementation matches mpmath's results for most inputs; see the caveats in the linked pull request. In addition to the FileCheck test here, the accuracy was tested with XLA's complex_unary_op_test and its MLIR emitters.

@jreiffers jreiffers requested a review from akuegel April 10, 2024 12:10
@llvmbot llvmbot added the mlir label Apr 10, 2024
@llvmbot
Member

llvmbot commented Apr 10, 2024

@llvm/pr-subscribers-mlir

Author: Johannes Reifferscheid (jreiffers)

Changes

This ports openxla/xla#10503 by @pearu. The new implementation matches mpmath's results for most inputs; see the caveats in the linked pull request. In addition to the FileCheck test here, the accuracy was tested with XLA's complex_unary_op_test and its MLIR emitters.


Full diff: https://github.com/llvm/llvm-project/pull/88260.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp (+26-24)
  • (modified) mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir (+31-17)
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 9c3c4d96a301ef..0aa1de5fa5d9a1 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -570,37 +570,39 @@ struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
                   ConversionPatternRewriter &rewriter) const override {
     auto type = cast<ComplexType>(adaptor.getComplex().getType());
     auto elementType = cast<FloatType>(type.getElementType());
-    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
+    arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
     mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
 
-    Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
-    Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
+    Value real = b.create<complex::ReOp>(adaptor.getComplex());
+    Value imag = b.create<complex::ImOp>(adaptor.getComplex());
 
     Value half = b.create<arith::ConstantOp>(elementType,
                                              b.getFloatAttr(elementType, 0.5));
     Value one = b.create<arith::ConstantOp>(elementType,
                                             b.getFloatAttr(elementType, 1));
-    Value two = b.create<arith::ConstantOp>(elementType,
-                                            b.getFloatAttr(elementType, 2));
-
-    // log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1)
-    // log((a+1)+bi) = .5*log(a*a + 2*a + 1 + b*b) + i*atan2(b, a+1)
-    // log((a+1)+bi) = .5*log1p(a*a + 2*a + b*b) + i*atan2(b, a+1)
-    Value sumSq = b.create<arith::MulFOp>(real, real, fmf.getValue());
-    sumSq = b.create<arith::AddFOp>(
-        sumSq, b.create<arith::MulFOp>(real, two, fmf.getValue()),
-        fmf.getValue());
-    sumSq = b.create<arith::AddFOp>(
-        sumSq, b.create<arith::MulFOp>(imag, imag, fmf.getValue()),
-        fmf.getValue());
-    Value logSumSq =
-        b.create<math::Log1pOp>(elementType, sumSq, fmf.getValue());
-    Value resultReal = b.create<arith::MulFOp>(logSumSq, half, fmf.getValue());
-
-    Value realPlusOne = b.create<arith::AddFOp>(real, one, fmf.getValue());
-
-    Value resultImag =
-        b.create<math::Atan2Op>(elementType, imag, realPlusOne, fmf.getValue());
+    Value realPlusOne = b.create<arith::AddFOp>(real, one, fmf);
+    Value absRealPlusOne = b.create<math::AbsFOp>(realPlusOne, fmf);
+    Value absImag = b.create<math::AbsFOp>(imag, fmf);
+
+    Value maxAbs = b.create<arith::MaximumFOp>(absRealPlusOne, absImag, fmf);
+    Value minAbs = b.create<arith::MinimumFOp>(absRealPlusOne, absImag, fmf);
+
+    Value maxAbsOfRealPlusOneAndImagMinusOne = b.create<arith::SelectOp>(
+        b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, realPlusOne, absImag,
+                                fmf),
+        real, b.create<arith::SubFOp>(maxAbs, one, fmf));
+    Value minMaxRatio = b.create<arith::DivFOp>(minAbs, maxAbs, fmf);
+    Value logOfMaxAbsOfRealPlusOneAndImag =
+        b.create<math::Log1pOp>(maxAbsOfRealPlusOneAndImagMinusOne, fmf);
+    Value logOfSqrtPart = b.create<math::Log1pOp>(
+        b.create<arith::MulFOp>(minMaxRatio, minMaxRatio, fmf), fmf);
+    Value r = b.create<arith::AddFOp>(
+        b.create<arith::MulFOp>(half, logOfSqrtPart, fmf),
+        logOfMaxAbsOfRealPlusOneAndImag, fmf);
+    Value resultReal = b.create<arith::SelectOp>(
+        b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, r, r, fmf), minAbs,
+        r);
+    Value resultImag = b.create<math::Atan2Op>(imag, realPlusOne, fmf);
     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
                                                    resultImag);
     return success();
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index f5d9499eadda48..43918904a09f40 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -300,15 +300,22 @@ func.func @complex_log1p(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
 // CHECK: %[[ONE_HALF:.*]] = arith.constant 5.000000e-01 : f32
 // CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK: %[[TWO:.*]] = arith.constant 2.000000e+00 : f32
-// CHECK: %[[SQ_SUM_0:.*]] = arith.mulf %[[REAL]], %[[REAL]] : f32
-// CHECK: %[[TWO_REAL:.*]] = arith.mulf %[[REAL]], %[[TWO]] : f32
-// CHECK: %[[SQ_SUM_1:.*]] = arith.addf %[[SQ_SUM_0]], %[[TWO_REAL]] : f32
-// CHECK: %[[SQ_IMAG:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] : f32
-// CHECK: %[[SQ_SUM_2:.*]] = arith.addf %[[SQ_SUM_1]], %[[SQ_IMAG]] : f32
-// CHECK: %[[LOG_SQ_SUM:.*]] = math.log1p %[[SQ_SUM_2]] : f32
-// CHECK: %[[RESULT_REAL:.*]] = arith.mulf %[[LOG_SQ_SUM]], %[[ONE_HALF]] : f32
 // CHECK: %[[REAL_PLUS_ONE:.*]] = arith.addf %[[REAL]], %[[ONE]] : f32
+// CHECK: %[[ABS_REAL_PLUS_ONE:.*]] = math.absf %[[REAL_PLUS_ONE]] : f32
+// CHECK: %[[ABS_IMAG:.*]] = math.absf %[[IMAG]] : f32
+// CHECK: %[[MAX:.*]] = arith.maximumf %[[ABS_REAL_PLUS_ONE]], %[[ABS_IMAG]] : f32
+// CHECK: %[[MIN:.*]] = arith.minimumf %[[ABS_REAL_PLUS_ONE]], %[[ABS_IMAG]] : f32
+// CHECK: %[[CMPF:.*]] = arith.cmpf ogt, %[[REAL_PLUS_ONE]], %[[ABS_IMAG]] : f32
+// CHECK: %[[MAX_MINUS_ONE:.*]] = arith.subf %[[MAX]], %cst_0 : f32
+// CHECK: %[[SELECT:.*]] = arith.select %[[CMPF]], %0, %[[MAX_MINUS_ONE]] : f32
+// CHECK: %[[MIN_MAX_RATIO:.*]] = arith.divf %[[MIN]], %[[MAX]] : f32
+// CHECK: %[[LOG_1:.*]] = math.log1p %[[SELECT]] : f32
+// CHECK: %[[RATIO_SQ:.*]] = arith.mulf %[[MIN_MAX_RATIO]], %[[MIN_MAX_RATIO]] : f32
+// CHECK: %[[LOG_SQ:.*]] = math.log1p %[[RATIO_SQ]] : f32
+// CHECK: %[[HALF_LOG_SQ:.*]] = arith.mulf %cst, %[[LOG_SQ]] : f32
+// CHECK: %[[R:.*]] = arith.addf %[[HALF_LOG_SQ]], %[[LOG_1]] : f32
+// CHECK: %[[ISNAN:.*]] = arith.cmpf uno, %[[R]], %[[R]] : f32
+// CHECK: %[[RESULT_REAL:.*]] = arith.select %[[ISNAN]], %[[MIN]], %[[R]] : f32
 // CHECK: %[[RESULT_IMAG:.*]] = math.atan2 %[[IMAG]], %[[REAL_PLUS_ONE]] : f32
 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
@@ -963,15 +970,22 @@ func.func @complex_log1p_with_fmf(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
 // CHECK: %[[ONE_HALF:.*]] = arith.constant 5.000000e-01 : f32
 // CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK: %[[TWO:.*]] = arith.constant 2.000000e+00 : f32
-// CHECK: %[[SQ_SUM_0:.*]] = arith.mulf %[[REAL]], %[[REAL]] fastmath<nnan,contract> : f32
-// CHECK: %[[TWO_REAL:.*]] = arith.mulf %[[REAL]], %[[TWO]] fastmath<nnan,contract> : f32
-// CHECK: %[[SQ_SUM_1:.*]] = arith.addf %[[SQ_SUM_0]], %[[TWO_REAL]] fastmath<nnan,contract> : f32
-// CHECK: %[[SQ_IMAG:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[SQ_SUM_2:.*]] = arith.addf %[[SQ_SUM_1]], %[[SQ_IMAG]] fastmath<nnan,contract> : f32
-// CHECK: %[[LOG_SQ_SUM:.*]] = math.log1p %[[SQ_SUM_2]] fastmath<nnan,contract> : f32
-// CHECK: %[[RESULT_REAL:.*]] = arith.mulf %[[LOG_SQ_SUM]], %[[ONE_HALF]] fastmath<nnan,contract> : f32
-// CHECK: %[[REAL_PLUS_ONE:.*]] = arith.addf %[[REAL]], %[[ONE]] fastmath<nnan,contract> : f32
+// CHECK: %[[REAL_PLUS_ONE:.*]] = arith.addf %[[REAL]], %[[ONE]] fastmath<nnan,contract>  : f32
+// CHECK: %[[ABS_REAL_PLUS_ONE:.*]] = math.absf %[[REAL_PLUS_ONE]] fastmath<nnan,contract>  : f32
+// CHECK: %[[ABS_IMAG:.*]] = math.absf %[[IMAG]] fastmath<nnan,contract>  : f32
+// CHECK: %[[MAX:.*]] = arith.maximumf %[[ABS_REAL_PLUS_ONE]], %[[ABS_IMAG]] fastmath<nnan,contract>  : f32
+// CHECK: %[[MIN:.*]] = arith.minimumf %[[ABS_REAL_PLUS_ONE]], %[[ABS_IMAG]] fastmath<nnan,contract>  : f32
+// CHECK: %[[CMPF:.*]] = arith.cmpf ogt, %[[REAL_PLUS_ONE]], %[[ABS_IMAG]] fastmath<nnan,contract>  : f32
+// CHECK: %[[MAX_MINUS_ONE:.*]] = arith.subf %[[MAX]], %cst_0 fastmath<nnan,contract>  : f32
+// CHECK: %[[SELECT:.*]] = arith.select %[[CMPF]], %0, %[[MAX_MINUS_ONE]] : f32
+// CHECK: %[[MIN_MAX_RATIO:.*]] = arith.divf %[[MIN]], %[[MAX]] fastmath<nnan,contract>  : f32
+// CHECK: %[[LOG_1:.*]] = math.log1p %[[SELECT]] fastmath<nnan,contract> : f32
+// CHECK: %[[RATIO_SQ:.*]] = arith.mulf %[[MIN_MAX_RATIO]], %[[MIN_MAX_RATIO]] fastmath<nnan,contract>  : f32
+// CHECK: %[[LOG_SQ:.*]] = math.log1p %[[RATIO_SQ]] fastmath<nnan,contract>  : f32
+// CHECK: %[[HALF_LOG_SQ:.*]] = arith.mulf %cst, %[[LOG_SQ]] fastmath<nnan,contract>  : f32
+// CHECK: %[[R:.*]] = arith.addf %[[HALF_LOG_SQ]], %[[LOG_1]] fastmath<nnan,contract>  : f32
+// CHECK: %[[ISNAN:.*]] = arith.cmpf uno, %[[R]], %[[R]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT_REAL:.*]] = arith.select %[[ISNAN]], %[[MIN]], %[[R]] : f32
 // CHECK: %[[RESULT_IMAG:.*]] = math.atan2 %[[IMAG]], %[[REAL_PLUS_ONE]] fastmath<nnan,contract> : f32
 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
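
For readers who want to follow the arithmetic without the builder calls, here is a scalar C++ sketch of what the new lowering appears to compute. With w = (a+1) + bi, and max/min the larger/smaller of |a+1| and |b|, it relies on the identity log|w| = log(max) + 0.5 * log1p((min/max)^2), which avoids squaring large components. The name `log1pSketch` is hypothetical and not part of the patch, and `std::max`/`std::min` only approximate `arith.maximumf`/`arith.minimumf`, which treat NaN differently.

```cpp
#include <algorithm>
#include <cmath>
#include <complex>

// Hypothetical scalar mirror of the ops emitted by the new Log1pOpConversion.
// Assumption: std::max/std::min stand in for arith.maximumf/arith.minimumf,
// so NaN inputs are not modelled faithfully here.
std::complex<double> log1pSketch(std::complex<double> z) {
  const double real = z.real();
  const double imag = z.imag();

  const double realPlusOne = real + 1.0;
  const double absRealPlusOne = std::fabs(realPlusOne);
  const double absImag = std::fabs(imag);

  const double maxAbs = std::max(absRealPlusOne, absImag);
  const double minAbs = std::min(absRealPlusOne, absImag);

  // When a+1 is the dominant component, max - 1 is exactly a, so the
  // original real part is reused instead of computing (a + 1) - 1.
  const double maxMinusOne = realPlusOne > absImag ? real : maxAbs - 1.0;

  // log|w| = log1p(max - 1) + 0.5 * log1p((min / max)^2)
  const double ratio = minAbs / maxAbs;
  const double r = std::log1p(maxMinusOne) + 0.5 * std::log1p(ratio * ratio);

  // If r is NaN (for example inf / inf in the ratio), fall back to minAbs,
  // as the select on the UNO comparison does in the patch.
  const double resultReal = std::isnan(r) ? minAbs : r;

  return {resultReal, std::atan2(imag, realPlusOne)};
}
```

As an illustration, for a = b = 1e200 the removed formula would need a*a, which overflows in double precision, whereas this sketch only forms the ratio min/max = 1 and stays finite. For a tiny input such as a = 1e-20, b = 0, the select passes a straight to `log1p`, so the small value is not lost to rounding in a + 1.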

@jreiffers jreiffers merged commit 49ef12a into llvm:main Apr 10, 2024
5 of 6 checks passed
joker-eph added a commit that referenced this pull request Apr 10, 2024
joker-eph added a commit that referenced this pull request Apr 10, 2024
@joker-eph
Collaborator

Hey @jreiffers: it seems the test fails on the gcc7 bot: https://lab.llvm.org/buildbot/#/builders/264/builds/9257/steps/6/logs/FAIL__MLIR__convert-to-standard_mlir. I reverted in the meantime; feel free to reland when you have a fix!

@jreiffers
Member Author

Ugh. Sorry about that and thanks.

jreiffers added a commit to jreiffers/llvm-project that referenced this pull request Apr 11, 2024
This ports openxla/xla#10503 by @pearu. In addition
to the filecheck test here, the accuracy was tested with XLA's
complex_unary_op_test and its MLIR emitters.

This is a fixed version of
llvm#88260. The previous version
relied on implementation-specific behavior in the order of evaluation of
maxAbsOfRealPlusOneAndImagMinusOne's operands.
jreiffers added a commit that referenced this pull request Apr 11, 2024
This ports openxla/xla#10503 by @pearu. In addition to the filecheck
test here, the accuracy was tested with XLA's complex_unary_op_test and
its MLIR emitters.

This is a fixed version of
#88260. The previous version
relied on implementation-specific behavior in the order of evaluation of
maxAbsOfRealPlusOneAndImagMinusOne's operands.
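
The evaluation-order problem described in the commit message can be illustrated outside MLIR. In the first version, the `arith::CmpFOp` and `arith::SubFOp` were created inside the argument list of the `arith::SelectOp` builder call. C++ does not specify the order in which function arguments are evaluated, so different compilers could presumably emit the two ops in different orders, which a fixed sequence of FileCheck lines cannot match. The sketch below uses a hypothetical `Recorder` type (not an MLIR API) that logs op names in creation order, and shows that hoisting each sub-expression into its own statement pins the order down.

```cpp
#include <string>
#include <vector>

// Hypothetical stand-in for an op builder: it only records the order in
// which "ops" are created. The int return value plays the role of a Value.
struct Recorder {
  std::vector<std::string> ops;
  int emit(const std::string &name, int /*lhs*/ = -1, int /*rhs*/ = -1) {
    ops.push_back(name);
    return static_cast<int>(ops.size()) - 1;
  }
};

// Nested form: both inner emit() calls are arguments of the outer call, so
// their relative order is unspecified and may differ between compilers.
std::vector<std::string> nestedOrder() {
  Recorder b;
  b.emit("select", b.emit("cmpf"), b.emit("subf"));
  return b.ops;
}

// Hoisted form: each op is created in its own statement, so the order is
// cmpf, subf, select on every conforming compiler.
std::vector<std::string> hoistedOrder() {
  Recorder b;
  int cmp = b.emit("cmpf");
  int sub = b.emit("subf");
  b.emit("select", cmp, sub);
  return b.ops;
}
```

Only `hoistedOrder()` gives a sequence that a test may safely hard-code. For `nestedOrder()`, the only portable guarantee is that "select" is recorded last, after both of its arguments have been evaluated.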