Skip to content

Commit

Permalink
Fix log1p inaccuracies on complex inputs with large absolute values.
Browse files Browse the repository at this point in the history
  • Loading branch information
pearu committed Apr 3, 2024
1 parent f8e10a9 commit 2b8f253
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 12 deletions.
2 changes: 1 addition & 1 deletion xla/python/xla_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@

# Just an internal arbitrary increasing number to help with backward-compatible
# changes. In JAX, reference this via jax._src.lib.xla_extension_version.
_version = 253
_version = 254

# Version number for MLIR:Python components.
mlir_api_version = 55
Expand Down
55 changes: 44 additions & 11 deletions xla/service/elemental_ir_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -974,21 +974,54 @@ absl::StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
return EmitComplexLog(op, operand_value);
}
case HloOpcode::kLog1p: {
// log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1)
// log((a+1)+bi) = .5*log(a*a + 2*a + 1 + b*b) + i*atan2(b, a+1)
// log((a+1)+bi) = .5*log1p(a*a + 2*a + b*b) + i*atan2(b, a+1)
// log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1)
// log((a+1)+bi) = .5*log(a*a + 2*a + 1 + b*b) + i*atan2(b, a+1)
// log((a+1)+bi) = .5*log1p(a*a + 2*a + b*b) + i*atan2(b, a+1)
//
// that is accurate only when |a| is relatively small while
// large |a| and |b| lead to multiplication overflow in the real
// part.
//
// The following expression for the real part:
//
// log1p(a+bi).real = log(hypot(a+1, b))
// = log(max(|a+1|, |b|) * sqrt(1 + (min(|a+1|, |b|) / max(|a+1|, b))^2))
// [to fix overflow for maximal values of |a+1| and |b|]
// = log(max(|a+1|, |b|)) + log(sqrt(1 + (min(|a+1|, |b|) / max(|a+1|, b))^2))
// = log(max(|a+1|, |b|)) + 0.5*log1p((min(|a+1|, |b|) / max(|a+1|, b))^2)
// [to fix inaccuracies for small a, we'll use log1p]
// = log1p((1 + a > |b| ? a : max(|a+1|, |b|) - 1) + 0.5*log1p((min(|a+1|, |b|) / max(|a+1|, b))^2)
//
// is accurate on the whole complex plane except when |b| is
// small and a is very close to -|b|^2/2 that leads to
// substraction errors when adding the two log1p values as in
// log1p(-|b|^2) + log1p(|b|^2)
// TODO: improve the accuracy for the case above.

auto a = EmitExtractReal(operand_value);
auto b = EmitExtractImag(operand_value);
llvm::Type* llvm_ty = a->getType();
auto one = llvm::ConstantFP::get(llvm_ty, 1.0);
auto two = llvm::ConstantFP::get(llvm_ty, 2.0);
auto a_plus_one = FAdd(a, one);
auto sum_sq = FAdd(FAdd(FMul(a, a), FMul(two, a)), FMul(b, b));
TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog1p(component_type, sum_sq));
TF_ASSIGN_OR_RETURN(auto angle,
EmitAtan2(component_type, b, a_plus_one, ""));
auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5);
return EmitComposeComplex(op, FMul(one_half, log_sum_sq), angle);
auto half = llvm::ConstantFP::get(llvm_ty, 0.5);

auto a1 = FAdd(a, one);
auto abs_a1 = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {a1}, {llvm_ty}, b_);
auto abs_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {b}, {llvm_ty}, b_);

auto max_abs_of_a1_and_b = EmitFloatMax(abs_a1, abs_b, "");
auto min_abs_of_a1_and_b = EmitFloatMin(abs_a1, abs_b, "");

auto max_abs_of_a1_and_b_minus_one = Select(FCmpOGT(a1, abs_b), a, FSub(max_abs_of_a1_and_b, one));
auto min_max_ratio = FDiv(min_abs_of_a1_and_b, max_abs_of_a1_and_b);

TF_ASSIGN_OR_RETURN(auto log_of_max_abs_of_a1_and_b, EmitLog1p(component_type, max_abs_of_a1_and_b_minus_one));
TF_ASSIGN_OR_RETURN(auto log_of_sqrt_part, EmitLog1p(component_type, FMul(min_max_ratio, min_max_ratio)));

auto r = FAdd(FMul(half, log_of_sqrt_part), log_of_max_abs_of_a1_and_b);
auto real_part = Select(FCmpUNO(r, r), min_abs_of_a1_and_b, r); // handles nan and inf values correctly

TF_ASSIGN_OR_RETURN(auto imag_part, EmitAtan2(component_type, b, a1, ""));
return EmitComposeComplex(op, real_part, imag_part);
}
case HloOpcode::kConvert: {
PrimitiveType from_type = op->operand(0)->shape().element_type();
Expand Down

0 comments on commit 2b8f253

Please sign in to comment.