From b7dbeb69d0e12b4f3f196c28911f39da1c5a54e3 Mon Sep 17 00:00:00 2001 From: Gregory Pataky Date: Fri, 27 Sep 2024 10:27:44 -0700 Subject: [PATCH] Add ARM tolerances to exhaustive tests PiperOrigin-RevId: 679641677 --- .../exhaustive_unary_complex_test.cc | 76 ++++++++++++------- .../exhaustive_unary_test_functions.cc | 28 +++++++ 2 files changed, 78 insertions(+), 26 deletions(-) diff --git a/xla/tests/exhaustive/exhaustive_unary_complex_test.cc b/xla/tests/exhaustive/exhaustive_unary_complex_test.cc index f93f5e31f544b..d652e56bd1b6f 100644 --- a/xla/tests/exhaustive/exhaustive_unary_complex_test.cc +++ b/xla/tests/exhaustive/exhaustive_unary_complex_test.cc @@ -138,16 +138,19 @@ UNARY_TEST_COMPLEX_64(Sqrt, { Run(Sqrt, [](complex64 x) { return std::sqrt(x); }, error_spec_gen); }) -double RsqrtCpuGpuAbsErr(complex64 x) { - return std::sqrt(std::numeric_limits::min()); +template +double RsqrtCpuGpuAbsErr(NativeT x) { + return std::sqrt(std::numeric_limits::min()); } -double RsqrtCpuGpuRelErr(complex64 x) { +template +double RsqrtCpuGpuRelErr(NativeT x) { // As noted above for Sqrt, the accuracy of sqrt degrades severely for // inputs with inputs with subnormals entries. - constexpr double eps = std::numeric_limits::epsilon(); - constexpr double norm_min = std::numeric_limits::min(); - constexpr double denorm_min = std::numeric_limits::denorm_min(); + constexpr double eps = std::numeric_limits::epsilon(); + constexpr double norm_min = std::numeric_limits::min(); + constexpr double denorm_min = + std::numeric_limits::denorm_min(); if (std::abs(x) < norm_min) { // Gradually loosen the relative tolerance as abs(x) becomes smaller // than norm_min, letting it reach 100% when abs(x) = 10 * denorm_min. @@ -164,9 +167,16 @@ UNARY_TEST_COMPLEX_64(Rsqrt, { if (IsCpu()) { error_spec_gen = +[](complex64 x) { return ErrorSpec::Builder() - .abs_err(RsqrtCpuGpuAbsErr(x)) - .rel_err(RsqrtCpuGpuRelErr(x)) + .abs_err(RsqrtCpuGpuAbsErr(x)) + .rel_err(RsqrtCpuGpuRelErr(x)) +#ifdef __aarch64__ + // TODO(b/365620546): ARM and x86 handle complex(inf, nan) + // differently. + .skip_comparison(x.real() == 0.0f || + (std::isinf(x.real()) && std::isnan(x.imag()))) +#else .skip_comparison(x.real() == 0.0f) +#endif .strict_signed_zeros(false) .build(); }; @@ -175,8 +185,8 @@ UNARY_TEST_COMPLEX_64(Rsqrt, { if (IsGpu()) { error_spec_gen = +[](complex64 x) { return ErrorSpec::Builder() - .abs_err(RsqrtCpuGpuAbsErr(x)) - .rel_err(RsqrtCpuGpuRelErr(x)) + .abs_err(RsqrtCpuGpuAbsErr(x)) + .rel_err(RsqrtCpuGpuRelErr(x)) .strict_signed_zeros(false) .build(); }; @@ -286,24 +296,38 @@ UNARY_TEST_COMPLEX_128(Sqrt, { }) UNARY_TEST_COMPLEX_128(Rsqrt, { - ErrorSpecGen error_spec_gen = +[](complex128 x) { - // As noted above for Sqrt, the accuracy of sqrt degrades severely for - // inputs with inputs with subnormals entries. - constexpr double norm_min = std::numeric_limits::min(); - constexpr double denorm_min = std::numeric_limits::denorm_min(); - if (std::abs(x) < norm_min) { - // Gradually loosen the relative tolerance as abs(x) becomes smaller - // than norm_min, letting it reach 100% when abs(x) = 10 * denorm_min. + ErrorSpecGen error_spec_gen = +[](complex128) { + return ErrorSpec::Builder().strict_signed_zeros().build(); + }; + + if (IsCpu()) { + error_spec_gen = +[](complex128 x) { return ErrorSpec::Builder() - .abs_err(std::sqrt(std::numeric_limits::min())) - .rel_err(10 * denorm_min / std::abs(x)) + .abs_err(RsqrtCpuGpuAbsErr(x)) + .rel_err(RsqrtCpuGpuRelErr(x)) +#ifdef __aarch64__ + // TODO(b/365620546): ARM and x86 handle complex(inf, nan) + // differently. + .skip_comparison(x.real() == 0.0f || + (std::isinf(x.real()) && std::isnan(x.imag()))) +#else + .skip_comparison(x.real() == 0.0f) +#endif + .strict_signed_zeros(false) .build(); - } - return ErrorSpec::Builder() - .abs_err(std::sqrt(std::numeric_limits::min())) - .rel_err(50 * std::numeric_limits::epsilon()) - .build(); - }; + }; + } + + if (IsGpu()) { + error_spec_gen = +[](complex128 x) { + return ErrorSpec::Builder() + .abs_err(RsqrtCpuGpuAbsErr(x)) + .rel_err(RsqrtCpuGpuRelErr(x)) + .strict_signed_zeros(false) + .build(); + }; + } + Run( Rsqrt, [](complex128 x) { return complex128(1, 0) / std::sqrt(x); }, error_spec_gen); diff --git a/xla/tests/exhaustive/exhaustive_unary_test_functions.cc b/xla/tests/exhaustive/exhaustive_unary_test_functions.cc index 9320e0c9ccedb..5baa12f15ca45 100644 --- a/xla/tests/exhaustive/exhaustive_unary_test_functions.cc +++ b/xla/tests/exhaustive/exhaustive_unary_test_functions.cc @@ -211,6 +211,20 @@ UNARY_TEST(Sin, { NativeT eps = std::numeric_limits::epsilon(); return ErrorSpec::Builder().abs_err(0).rel_err(2 * eps).build(); }) + .CpuArmError(+[](NativeT val) { + // Flushes subnormals and minimum positive output to 0. + NativeT output = static_cast(std::sin(val)); + // TODO(b/365622116): Understand why ARM flushes these but x86 doesn't. + if (IsSubnormalOrMinNormal(output)) { + return ErrorSpec::Builder() + .abs_err(std::numeric_limits::min()) + .build(); + } + + // This error spec corresponds to a maximum relative error of 2 ULP. + NativeT eps = std::numeric_limits::epsilon(); + return ErrorSpec::Builder().abs_err(0).rel_err(2 * eps).build(); + }) .OutputRangeCheck( +[](NativeInputs in, NativeT out) { return !(out < -1 || out > 1); }) .Run(); @@ -222,6 +236,20 @@ UNARY_TEST(Tan, { NativeT eps = std::numeric_limits::epsilon(); return ErrorSpec::Builder().abs_err(0).rel_err(4 * eps).build(); }) + .CpuArmError(+[](NativeT val) { + // Flushes positive subnormals and minimum positive output to 0. + NativeT output = static_cast(std::tan(val)); + // TODO(b/365622116): Understand why ARM flushes these but x86 doesn't. + if (IsSubnormalOrMinNormal(output)) { + return ErrorSpec::Builder() + .abs_err(std::numeric_limits::min()) + .build(); + } + + // This error spec corresponds to a maximum relative error of 4 ULP. + NativeT eps = std::numeric_limits::epsilon(); + return ErrorSpec::Builder().abs_err(0).rel_err(4 * eps).build(); + }) .Run(); })