diff --git a/src/kernels/gpu_reference_kernel/fp8_naive_conv.cpp b/src/kernels/gpu_reference_kernel/fp8_naive_conv.cpp index 3b4eabecfb..f24a2d8813 100644 --- a/src/kernels/gpu_reference_kernel/fp8_naive_conv.cpp +++ b/src/kernels/gpu_reference_kernel/fp8_naive_conv.cpp @@ -63,15 +63,17 @@ struct conditional template using conditional_t = typename conditional::type; + } // namespace std + #else #include // int8_t, int16_t #include // float_t #endif +#else // __HIPCC_RTC__ +#include #endif // __HIPCC_RTC__ -#include // std::numeric_limits - #define MIOPEN_ENABLE_F8_DEVICE_CODE 1 #include "hip_float8.hpp" diff --git a/src/kernels/hip_f8_impl.hpp b/src/kernels/hip_f8_impl.hpp index 03b7f901bf..c8d49cd474 100644 --- a/src/kernels/hip_f8_impl.hpp +++ b/src/kernels/hip_f8_impl.hpp @@ -87,8 +87,8 @@ MIOPEN_HIP_HOST_DEVICE uint8_t cast_to_f8_no_range_reduce(T _x, template MIOPEN_HIP_HOST_DEVICE uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng) { - constexpr bool is_half = std::is_same::value; - constexpr bool is_float = std::is_same::value; + constexpr bool is_half = __is_same_as(T, half); + constexpr bool is_float = __is_same_as(T, float); static_assert(wm + we == 7, "wm+we==7"); static_assert(is_half || is_float, "Only half and float can be cast to f8"); @@ -273,8 +273,8 @@ MIOPEN_HIP_HOST_DEVICE uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng) template MIOPEN_HIP_HOST_DEVICE T cast_from_f8(uint8_t x) { - constexpr bool is_half = std::is_same::value; - constexpr bool is_float = std::is_same::value; + constexpr bool is_half = __is_same_as(T, half); + constexpr bool is_float = __is_same_as(T, float); static_assert(is_half || is_float, "only half and float are supported"); constexpr int weo = is_half ? 5 : 8; diff --git a/src/kernels/hip_float8.hpp b/src/kernels/hip_float8.hpp index a9b2a559a8..2947d6d713 100644 --- a/src/kernels/hip_float8.hpp +++ b/src/kernels/hip_float8.hpp @@ -83,6 +83,9 @@ inline MIOPEN_HIP_HOST_DEVICE bool get_hip_f8_bias_mode() #endif } +template +class numeric_limits; + template struct hip_f8 { @@ -262,8 +265,7 @@ struct hip_f8 inline MIOPEN_HIP_HOST_DEVICE bool operator==(const hip_f8& rhs) const { - if((rhs.is_zero() && this->is_zero()) || - (fabs(rhs - *this) < std::numeric_limits>::epsilon())) + if((rhs.is_zero() && this->is_zero()) || (this->data == rhs.data)) { return true; } @@ -487,19 +489,6 @@ MIOPEN_HIP_HOST_DEVICE T F8_Max() x.bits = 0x7F; return x.value; } -} // namespace miopen_f8 - -// define numeric limits for the new data type -namespace std { -inline bool isfinite(miopen_f8::hip_f8 x) // NOLINT -{ - return x.is_inf(); -} - -inline bool isfinite(miopen_f8::hip_f8 x) // NOLINT -{ - return x.is_inf(); -} template <> class numeric_limits> @@ -555,7 +544,44 @@ class numeric_limits> } }; +} // namespace miopen_f8 + +#ifndef __HIPCC_RTC__ +namespace std { +inline bool isfinite(miopen_f8::hip_f8 x) // NOLINT +{ + return x.is_inf(); +} + +inline bool isfinite(miopen_f8::hip_f8 x) // NOLINT +{ + return x.is_inf(); +} + +inline bool isnan(miopen_f8::hip_f8 x) // NOLINT +{ + return x.is_nan(); +} + +inline bool isnan(miopen_f8::hip_f8 x) // NOLINT +{ + return x.is_nan(); +} + +template <> +class numeric_limits> + : public miopen_f8::numeric_limits> +{ +}; + +template <> +class numeric_limits> + : public miopen_f8::numeric_limits> +{ +}; + } // namespace std +#endif template struct hip_f8x4