Skip to content

Commit

Permalink
Patch necessary to make FP8 convolution compile with hiprtc (ROCm#2584)
Browse files Browse the repository at this point in the history
  • Loading branch information
umangyadav authored Dec 7, 2023
1 parent 8ee1ad7 commit 7ae1553
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 21 deletions.
6 changes: 4 additions & 2 deletions src/kernels/gpu_reference_kernel/fp8_naive_conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,17 @@ struct conditional<false, X, Y>

template <bool predicate, typename X, typename Y>
using conditional_t = typename conditional<predicate, X, Y>::type;

} // namespace std

#else
#include <cstdint> // int8_t, int16_t
#include <cmath> // float_t
#endif
#else // __HIPCC_RTC__
#include <limits>
#endif // __HIPCC_RTC__

#include <limits> // std::numeric_limits

#define MIOPEN_ENABLE_F8_DEVICE_CODE 1
#include "hip_float8.hpp"

Expand Down
8 changes: 4 additions & 4 deletions src/kernels/hip_f8_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ MIOPEN_HIP_HOST_DEVICE uint8_t cast_to_f8_no_range_reduce(T _x,
template <int wm, int we, typename T, bool negative_zero_nan, bool clip>
MIOPEN_HIP_HOST_DEVICE uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng)
{
constexpr bool is_half = std::is_same<T, half>::value;
constexpr bool is_float = std::is_same<T, float>::value;
constexpr bool is_half = __is_same_as(T, half);
constexpr bool is_float = __is_same_as(T, float);
static_assert(wm + we == 7, "wm+we==7");
static_assert(is_half || is_float, "Only half and float can be cast to f8");

Expand Down Expand Up @@ -273,8 +273,8 @@ MIOPEN_HIP_HOST_DEVICE uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng)
template <int wm, int we, typename T, bool negative_zero_nan>
MIOPEN_HIP_HOST_DEVICE T cast_from_f8(uint8_t x)
{
constexpr bool is_half = std::is_same<T, half>::value;
constexpr bool is_float = std::is_same<T, float>::value;
constexpr bool is_half = __is_same_as(T, half);
constexpr bool is_float = __is_same_as(T, float);
static_assert(is_half || is_float, "only half and float are supported");

constexpr int weo = is_half ? 5 : 8;
Expand Down
56 changes: 41 additions & 15 deletions src/kernels/hip_float8.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ inline MIOPEN_HIP_HOST_DEVICE bool get_hip_f8_bias_mode()
#endif
}

template <typename T>
class numeric_limits;

template <hip_f8_type T>
struct hip_f8
{
Expand Down Expand Up @@ -262,8 +265,7 @@ struct hip_f8

inline MIOPEN_HIP_HOST_DEVICE bool operator==(const hip_f8& rhs) const
{
if((rhs.is_zero() && this->is_zero()) ||
(fabs(rhs - *this) < std::numeric_limits<hip_f8<T>>::epsilon()))
if((rhs.is_zero() && this->is_zero()) || (this->data == rhs.data))
{
return true;
}
Expand Down Expand Up @@ -487,19 +489,6 @@ MIOPEN_HIP_HOST_DEVICE T F8_Max()
x.bits = 0x7F;
return x.value;
}
} // namespace miopen_f8

// define numeric limits for the new data type
namespace std {
inline bool isfinite(miopen_f8::hip_f8<miopen_f8::hip_f8_type::fp8> x) // NOLINT
{
return x.is_inf();
}

inline bool isfinite(miopen_f8::hip_f8<miopen_f8::hip_f8_type::bf8> x) // NOLINT
{
return x.is_inf();
}

template <>
class numeric_limits<miopen_f8::hip_f8<miopen_f8::hip_f8_type::fp8>>
Expand Down Expand Up @@ -555,7 +544,44 @@ class numeric_limits<miopen_f8::hip_f8<miopen_f8::hip_f8_type::bf8>>
}
};

} // namespace miopen_f8

#ifndef __HIPCC_RTC__
namespace std {
inline bool isfinite(miopen_f8::hip_f8<miopen_f8::hip_f8_type::fp8> x) // NOLINT
{
return x.is_inf();
}

inline bool isfinite(miopen_f8::hip_f8<miopen_f8::hip_f8_type::bf8> x) // NOLINT
{
return x.is_inf();
}

inline bool isnan(miopen_f8::hip_f8<miopen_f8::hip_f8_type::fp8> x) // NOLINT
{
return x.is_nan();
}

inline bool isnan(miopen_f8::hip_f8<miopen_f8::hip_f8_type::bf8> x) // NOLINT
{
return x.is_nan();
}

template <>
class numeric_limits<miopen_f8::hip_f8<miopen_f8::hip_f8_type::fp8>>
: public miopen_f8::numeric_limits<miopen_f8::hip_f8<miopen_f8::hip_f8_type::fp8>>
{
};

template <>
class numeric_limits<miopen_f8::hip_f8<miopen_f8::hip_f8_type::bf8>>
: public miopen_f8::numeric_limits<miopen_f8::hip_f8<miopen_f8::hip_f8_type::bf8>>
{
};

} // namespace std
#endif

template <miopen_f8::hip_f8_type T>
struct hip_f8x4
Expand Down

0 comments on commit 7ae1553

Please sign in to comment.