Skip to content

Commit

Permalink
PR #10503: Fix log1p inaccuracies on complex inputs with large absolu…
Browse files Browse the repository at this point in the history
…te values.

Imported from GitHub PR #10503

As in the title.

Tests and improvement reports are in jax-ml/jax#20144.

Accuracy tests are enabled in jax-ml/jax#20436
Copybara import of the project:

--
2b8f253 by Pearu Peterson <pearu.peterson@gmail.com>:

Fix log1p inaccuracies on complex inputs with large absolute values.

--
d35cef4 by Pearu Peterson <pearu.peterson@gmail.com>:

Add tests to complex Log1p

Merging this change closes #10503

COPYBARA_INTEGRATE_REVIEW=#10503 from pearu:pearu/log1p d35cef4
PiperOrigin-RevId: 621917683
  • Loading branch information
pearu authored and copybara-github committed Apr 4, 2024
1 parent 2421769 commit 2a29934
Show file tree
Hide file tree
Showing 6 changed files with 1,859 additions and 12 deletions.
2 changes: 1 addition & 1 deletion xla/python/xla_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@

# Just an internal arbitrary increasing number to help with backward-compatible
# changes. In JAX, reference this via jax._src.lib.xla_extension_version.
_version = 253
_version = 254

# Version number for MLIR:Python components.
mlir_api_version = 55
Expand Down
66 changes: 55 additions & 11 deletions xla/service/elemental_ir_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -974,21 +974,65 @@ absl::StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
return EmitComplexLog(op, operand_value);
}
case HloOpcode::kLog1p: {
// log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1)
// log((a+1)+bi) = .5*log(a*a + 2*a + 1 + b*b) + i*atan2(b, a+1)
// log((a+1)+bi) = .5*log1p(a*a + 2*a + b*b) + i*atan2(b, a+1)
// log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1)
// log((a+1)+bi) = .5*log(a*a + 2*a + 1 + b*b) + i*atan2(b, a+1)
// log((a+1)+bi) = .5*log1p(a*a + 2*a + b*b) + i*atan2(b, a+1)
//
// that is accurate only when |a| is relatively small while
// large |a| and |b| lead to multiplication overflow in the real
// part.
//
// The following expression for the real part:
//
// log1p(a+bi).real = log(hypot(a+1, b))
// = log(max(|a+1|, |b|) * sqrt(1 + (min(|a+1|, |b|) /
// max(|a+1|, b))^2)) [to fix overflow for maximal values
// of |a+1| and |b|] = log(max(|a+1|, |b|)) + log(sqrt(1
// + (min(|a+1|, |b|) / max(|a+1|, b))^2)) =
// log(max(|a+1|, |b|)) + 0.5*log1p((min(|a+1|, |b|) /
// max(|a+1|, b))^2) [to fix inaccuracies for small a,
// we'll use log1p] = log1p((1 + a > |b| ? a : max(|a+1|,
// |b|) - 1) + 0.5*log1p((min(|a+1|, |b|) / max(|a+1|,
// b))^2)
//
// is accurate on the whole complex plane except when |b| is
// small and a is very close to -|b|^2/2 that leads to
// substraction errors when adding the two log1p values as in
// log1p(-|b|^2) + log1p(|b|^2)
// TODO: improve the accuracy for the case above.

auto a = EmitExtractReal(operand_value);
auto b = EmitExtractImag(operand_value);
llvm::Type* llvm_ty = a->getType();
auto one = llvm::ConstantFP::get(llvm_ty, 1.0);
auto two = llvm::ConstantFP::get(llvm_ty, 2.0);
auto a_plus_one = FAdd(a, one);
auto sum_sq = FAdd(FAdd(FMul(a, a), FMul(two, a)), FMul(b, b));
TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog1p(component_type, sum_sq));
TF_ASSIGN_OR_RETURN(auto angle,
EmitAtan2(component_type, b, a_plus_one, ""));
auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5);
return EmitComposeComplex(op, FMul(one_half, log_sum_sq), angle);
auto half = llvm::ConstantFP::get(llvm_ty, 0.5);

auto a1 = FAdd(a, one);
auto abs_a1 = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {a1},
{llvm_ty}, b_);
auto abs_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {b},
{llvm_ty}, b_);

auto max_abs_of_a1_and_b = EmitFloatMax(abs_a1, abs_b, "");
auto min_abs_of_a1_and_b = EmitFloatMin(abs_a1, abs_b, "");

auto max_abs_of_a1_and_b_minus_one =
Select(FCmpOGT(a1, abs_b), a, FSub(max_abs_of_a1_and_b, one));
auto min_max_ratio = FDiv(min_abs_of_a1_and_b, max_abs_of_a1_and_b);

TF_ASSIGN_OR_RETURN(
auto log_of_max_abs_of_a1_and_b,
EmitLog1p(component_type, max_abs_of_a1_and_b_minus_one));
TF_ASSIGN_OR_RETURN(
auto log_of_sqrt_part,
EmitLog1p(component_type, FMul(min_max_ratio, min_max_ratio)));

auto r = FAdd(FMul(half, log_of_sqrt_part), log_of_max_abs_of_a1_and_b);
auto real_part = Select(FCmpUNO(r, r), min_abs_of_a1_and_b,
r); // handles nan and inf values correctly

TF_ASSIGN_OR_RETURN(auto imag_part, EmitAtan2(component_type, b, a1, ""));
return EmitComposeComplex(op, real_part, imag_part);
}
case HloOpcode::kConvert: {
PrimitiveType from_type = op->operand(0)->shape().element_type();
Expand Down
23 changes: 23 additions & 0 deletions xla/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,29 @@ xla_test(
],
)

xla_test(
name = "complex_unary_op_test",
srcs = [
"complex_unary_op_samples.h",
"complex_unary_op_test.cc",
],
backends = [
"cpu",
"gpu",
],
deps = [
":client_library_test_base",
":literal_test_util",
":test_macros_header",
":xla_internal_test_main",
"//xla:xla_data_proto_cc",
"//xla/client:global_data",
"//xla/client:local_client",
"//xla/client:xla_builder",
"@tsl//tsl/platform:test",
],
)

xla_test(
name = "scalar_computations_test",
srcs = ["scalar_computations_test.cc"],
Expand Down
Loading

0 comments on commit 2a29934

Please sign in to comment.