Skip to content

Commit

Permalink
Add ARM tolerances to exhaustive tests
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 679641677
  • Loading branch information
Gregory Pataky authored and Google-ML-Automation committed Sep 27, 2024
1 parent ce7c012 commit b7dbeb6
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 26 deletions.
76 changes: 50 additions & 26 deletions xla/tests/exhaustive/exhaustive_unary_complex_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,16 +138,19 @@ UNARY_TEST_COMPLEX_64(Sqrt, {
Run(Sqrt, [](complex64 x) { return std::sqrt(x); }, error_spec_gen);
})

double RsqrtCpuGpuAbsErr(complex64 x) {
return std::sqrt(std::numeric_limits<float>::min());
template <typename NativeT, typename ComponentNativeT>
double RsqrtCpuGpuAbsErr(NativeT x) {
return std::sqrt(std::numeric_limits<ComponentNativeT>::min());
}

double RsqrtCpuGpuRelErr(complex64 x) {
template <typename NativeT, typename ComponentNativeT>
double RsqrtCpuGpuRelErr(NativeT x) {
// As noted above for Sqrt, the accuracy of sqrt degrades severely for
// inputs with subnormal entries.
constexpr double eps = std::numeric_limits<float>::epsilon();
constexpr double norm_min = std::numeric_limits<float>::min();
constexpr double denorm_min = std::numeric_limits<float>::denorm_min();
constexpr double eps = std::numeric_limits<ComponentNativeT>::epsilon();
constexpr double norm_min = std::numeric_limits<ComponentNativeT>::min();
constexpr double denorm_min =
std::numeric_limits<ComponentNativeT>::denorm_min();
if (std::abs(x) < norm_min) {
// Gradually loosen the relative tolerance as abs(x) becomes smaller
// than norm_min, letting it reach 100% when abs(x) = 10 * denorm_min.
Expand All @@ -164,9 +167,16 @@ UNARY_TEST_COMPLEX_64(Rsqrt, {
if (IsCpu()) {
error_spec_gen = +[](complex64 x) {
return ErrorSpec::Builder()
.abs_err(RsqrtCpuGpuAbsErr(x))
.rel_err(RsqrtCpuGpuRelErr(x))
.abs_err(RsqrtCpuGpuAbsErr<NativeT, ComponentNativeT>(x))
.rel_err(RsqrtCpuGpuRelErr<NativeT, ComponentNativeT>(x))
#ifdef __aarch64__
// TODO(b/365620546): ARM and x86 handle complex(inf, nan)
// differently.
.skip_comparison(x.real() == 0.0f ||
(std::isinf(x.real()) && std::isnan(x.imag())))
#else
.skip_comparison(x.real() == 0.0f)
#endif
.strict_signed_zeros(false)
.build();
};
Expand All @@ -175,8 +185,8 @@ UNARY_TEST_COMPLEX_64(Rsqrt, {
if (IsGpu()) {
error_spec_gen = +[](complex64 x) {
return ErrorSpec::Builder()
.abs_err(RsqrtCpuGpuAbsErr(x))
.rel_err(RsqrtCpuGpuRelErr(x))
.abs_err(RsqrtCpuGpuAbsErr<NativeT, ComponentNativeT>(x))
.rel_err(RsqrtCpuGpuRelErr<NativeT, ComponentNativeT>(x))
.strict_signed_zeros(false)
.build();
};
Expand Down Expand Up @@ -286,24 +296,38 @@ UNARY_TEST_COMPLEX_128(Sqrt, {
})

UNARY_TEST_COMPLEX_128(Rsqrt, {
ErrorSpecGen error_spec_gen = +[](complex128 x) {
// As noted above for Sqrt, the accuracy of sqrt degrades severely for
// inputs with subnormal entries.
constexpr double norm_min = std::numeric_limits<double>::min();
constexpr double denorm_min = std::numeric_limits<double>::denorm_min();
if (std::abs(x) < norm_min) {
// Gradually loosen the relative tolerance as abs(x) becomes smaller
// than norm_min, letting it reach 100% when abs(x) = 10 * denorm_min.
ErrorSpecGen error_spec_gen = +[](complex128) {
return ErrorSpec::Builder().strict_signed_zeros().build();
};

if (IsCpu()) {
error_spec_gen = +[](complex128 x) {
return ErrorSpec::Builder()
.abs_err(std::sqrt(std::numeric_limits<double>::min()))
.rel_err(10 * denorm_min / std::abs(x))
.abs_err(RsqrtCpuGpuAbsErr<NativeT, ComponentNativeT>(x))
.rel_err(RsqrtCpuGpuRelErr<NativeT, ComponentNativeT>(x))
#ifdef __aarch64__
// TODO(b/365620546): ARM and x86 handle complex(inf, nan)
// differently.
.skip_comparison(x.real() == 0.0f ||
(std::isinf(x.real()) && std::isnan(x.imag())))
#else
.skip_comparison(x.real() == 0.0f)
#endif
.strict_signed_zeros(false)
.build();
}
return ErrorSpec::Builder()
.abs_err(std::sqrt(std::numeric_limits<double>::min()))
.rel_err(50 * std::numeric_limits<double>::epsilon())
.build();
};
};
}

if (IsGpu()) {
error_spec_gen = +[](complex128 x) {
return ErrorSpec::Builder()
.abs_err(RsqrtCpuGpuAbsErr<NativeT, ComponentNativeT>(x))
.rel_err(RsqrtCpuGpuRelErr<NativeT, ComponentNativeT>(x))
.strict_signed_zeros(false)
.build();
};
}

Run(
Rsqrt, [](complex128 x) { return complex128(1, 0) / std::sqrt(x); },
error_spec_gen);
Expand Down
28 changes: 28 additions & 0 deletions xla/tests/exhaustive/exhaustive_unary_test_functions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,20 @@ UNARY_TEST(Sin, {
NativeT eps = std::numeric_limits<NativeT>::epsilon();
return ErrorSpec::Builder().abs_err(0).rel_err(2 * eps).build();
})
.CpuArmError(+[](NativeT val) {
// Flushes subnormals and minimum positive output to 0.
NativeT output = static_cast<NativeT>(std::sin(val));
// TODO(b/365622116): Understand why ARM flushes these but x86 doesn't.
if (IsSubnormalOrMinNormal(output)) {
return ErrorSpec::Builder()
.abs_err(std::numeric_limits<NativeT>::min())
.build();
}

// This error spec corresponds to a maximum relative error of 2 ULP.
NativeT eps = std::numeric_limits<NativeT>::epsilon();
return ErrorSpec::Builder().abs_err(0).rel_err(2 * eps).build();
})
.OutputRangeCheck(
+[](NativeInputs in, NativeT out) { return !(out < -1 || out > 1); })
.Run();
Expand All @@ -222,6 +236,20 @@ UNARY_TEST(Tan, {
NativeT eps = std::numeric_limits<NativeT>::epsilon();
return ErrorSpec::Builder().abs_err(0).rel_err(4 * eps).build();
})
.CpuArmError(+[](NativeT val) {
// Flushes positive subnormals and minimum positive output to 0.
NativeT output = static_cast<NativeT>(std::tan(val));
// TODO(b/365622116): Understand why ARM flushes these but x86 doesn't.
if (IsSubnormalOrMinNormal(output)) {
return ErrorSpec::Builder()
.abs_err(std::numeric_limits<NativeT>::min())
.build();
}

// This error spec corresponds to a maximum relative error of 4 ULP.
NativeT eps = std::numeric_limits<NativeT>::epsilon();
return ErrorSpec::Builder().abs_err(0).rel_err(4 * eps).build();
})
.Run();
})

Expand Down

0 comments on commit b7dbeb6

Please sign in to comment.