Skip to content

Commit

Permalink
[SYCL] Fix sycl::vec::convert<> to allow conversion to and from `sy…
Browse files Browse the repository at this point in the history
…cl::vec` of `bfloat16` type to that of other data types (#14105)

Follow-up of and blocked by: #14085

After this change:
On host, conversion between `vec<bfloat16>` and `vec<float>` happens
element by element. On device, the SPIR-V instructions
`OpConvertFToBF16INTEL` and `OpConvertBF16ToFINTEL`
(https://github.com/intel/llvm/blob/sycl/sycl/doc/design/spirv-extensions/SPV_INTEL_bfloat16_conversion.asciidoc)
are used for vector conversion.
  • Loading branch information
uditagarwal97 authored Jun 21, 2024
1 parent f944ed6 commit 02c6bba
Show file tree
Hide file tree
Showing 7 changed files with 949 additions and 24 deletions.
8 changes: 8 additions & 0 deletions sycl/include/sycl/detail/generic_type_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
#include <sycl/half_type.hpp> // for BIsRepresentationT
#include <sycl/multi_ptr.hpp> // for multi_ptr, address_spa...

#include <sycl/ext/oneapi/bfloat16.hpp> // for bfloat16 storage type.

#include <cstddef> // for byte
#include <cstdint> // for uint8_t
#include <limits> // for numeric_limits
Expand Down Expand Up @@ -386,7 +388,13 @@ template <typename T> auto convertToOpenCLType(T &&x) {
static_assert(sizeof(OpenCLType) == sizeof(T));
return static_cast<OpenCLType>(x);
} else if constexpr (is_bfloat16_v<no_ref>) {
// On host, don't interpret BF16 as uint16.
#ifdef __SYCL_DEVICE_ONLY__
using OpenCLType = sycl::ext::oneapi::detail::Bfloat16StorageT;
return sycl::bit_cast<OpenCLType>(x);
#else
return std::forward<T>(x);
#endif
} else if constexpr (std::is_floating_point_v<no_ref>) {
static_assert(std::is_same_v<no_ref, float> ||
std::is_same_v<no_ref, double>,
Expand Down
305 changes: 304 additions & 1 deletion sycl/include/sycl/detail/vector_convert.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,100 @@
#include <sycl/detail/generic_type_traits.hpp> // for is_sigeninteger, is_s...
#include <sycl/exception.hpp> // for errc

#include <sycl/ext/oneapi/bfloat16.hpp> // bfloat16

#ifndef __SYCL_DEVICE_ONLY__
#include <cfenv> // for fesetround, fegetround
#endif

#include <type_traits>

// Declarations of the externally-defined __imf_* BF16 conversion functions.
// They are visible only in device compilation for SPIR / SPIR-V targets
// (the original note describes this as "only Intel devices").
// BF16 values cross this interface as raw uint16_t.
// The _rn/_ru/_rd/_rz suffixes are selected by the EXPAND_BF16_TYPE macro
// later in this file: automatic and rte use _rn, rtp uses _ru, rtn uses _rd,
// and rtz uses _rz.
#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__))
extern "C" {
// BF16 (as uint16_t) to float and to integer types, one function per
// rounding suffix, plus the *_as_* functions at the end of this group.
extern __DPCPP_SYCL_EXTERNAL float __imf_bfloat162float(uint16_t x);
extern __DPCPP_SYCL_EXTERNAL unsigned int __imf_bfloat162uint_rd(uint16_t x);
extern __DPCPP_SYCL_EXTERNAL unsigned int __imf_bfloat162uint_rn(uint16_t x);
extern __DPCPP_SYCL_EXTERNAL unsigned int __imf_bfloat162uint_ru(uint16_t x);
extern __DPCPP_SYCL_EXTERNAL unsigned int __imf_bfloat162uint_rz(uint16_t x);
extern __DPCPP_SYCL_EXTERNAL unsigned short
__imf_bfloat162ushort_rd(uint16_t x);
extern __DPCPP_SYCL_EXTERNAL unsigned short
__imf_bfloat162ushort_rn(uint16_t x);
extern __DPCPP_SYCL_EXTERNAL unsigned short
__imf_bfloat162ushort_ru(uint16_t x);
extern __DPCPP_SYCL_EXTERNAL unsigned short
__imf_bfloat162ushort_rz(uint16_t x);
extern __DPCPP_SYCL_EXTERNAL unsigned long long
__imf_bfloat162ull_rd(uint16_t x);
extern __DPCPP_SYCL_EXTERNAL unsigned long long
__imf_bfloat162ull_rn(uint16_t x);
extern __DPCPP_SYCL_EXTERNAL unsigned long long
__imf_bfloat162ull_ru(uint16_t x);
extern __DPCPP_SYCL_EXTERNAL unsigned long long
__imf_bfloat162ull_rz(uint16_t x);
extern __DPCPP_SYCL_EXTERNAL int __imf_bfloat162int_rd(uint16_t x);
extern __DPCPP_SYCL_EXTERNAL int __imf_bfloat162int_rn(uint16_t x);
extern __DPCPP_SYCL_EXTERNAL int __imf_bfloat162int_ru(uint16_t x);
extern __DPCPP_SYCL_EXTERNAL int __imf_bfloat162int_rz(uint16_t x);
extern __DPCPP_SYCL_EXTERNAL short __imf_bfloat162short_rd(uint16_t x);
extern __DPCPP_SYCL_EXTERNAL short __imf_bfloat162short_rn(uint16_t x);
extern __DPCPP_SYCL_EXTERNAL short __imf_bfloat162short_ru(uint16_t x);
extern __DPCPP_SYCL_EXTERNAL short __imf_bfloat162short_rz(uint16_t x);
extern __DPCPP_SYCL_EXTERNAL long long __imf_bfloat162ll_rd(uint16_t x);
extern __DPCPP_SYCL_EXTERNAL long long __imf_bfloat162ll_rn(uint16_t x);
extern __DPCPP_SYCL_EXTERNAL long long __imf_bfloat162ll_ru(uint16_t x);
extern __DPCPP_SYCL_EXTERNAL long long __imf_bfloat162ll_rz(uint16_t x);
extern __DPCPP_SYCL_EXTERNAL short __imf_bfloat16_as_short(uint16_t x);
extern __DPCPP_SYCL_EXTERNAL unsigned short
__imf_bfloat16_as_ushort(uint16_t x);

// float, double and integer types to BF16 (returned as uint16_t).
// Note that only a single, suffix-less double variant is declared.
extern __DPCPP_SYCL_EXTERNAL uint16_t __imf_float2bfloat16(float x);
extern __DPCPP_SYCL_EXTERNAL uint16_t __imf_float2bfloat16_rd(float x);
extern __DPCPP_SYCL_EXTERNAL uint16_t __imf_float2bfloat16_rn(float x);
extern __DPCPP_SYCL_EXTERNAL uint16_t __imf_float2bfloat16_ru(float x);
extern __DPCPP_SYCL_EXTERNAL uint16_t __imf_float2bfloat16_rz(float x);
extern __DPCPP_SYCL_EXTERNAL uint16_t
__imf_ushort2bfloat16_rd(unsigned short x);
extern __DPCPP_SYCL_EXTERNAL uint16_t
__imf_ushort2bfloat16_rn(unsigned short x);
extern __DPCPP_SYCL_EXTERNAL uint16_t
__imf_ushort2bfloat16_ru(unsigned short x);
extern __DPCPP_SYCL_EXTERNAL uint16_t
__imf_ushort2bfloat16_rz(unsigned short x);
extern __DPCPP_SYCL_EXTERNAL uint16_t __imf_uint2bfloat16_rd(unsigned int x);
extern __DPCPP_SYCL_EXTERNAL uint16_t __imf_uint2bfloat16_rn(unsigned int x);
extern __DPCPP_SYCL_EXTERNAL uint16_t __imf_uint2bfloat16_ru(unsigned int x);
extern __DPCPP_SYCL_EXTERNAL uint16_t __imf_uint2bfloat16_rz(unsigned int x);
extern __DPCPP_SYCL_EXTERNAL uint16_t
__imf_ull2bfloat16_rd(unsigned long long x);
extern __DPCPP_SYCL_EXTERNAL uint16_t
__imf_ull2bfloat16_rn(unsigned long long x);
extern __DPCPP_SYCL_EXTERNAL uint16_t
__imf_ull2bfloat16_ru(unsigned long long x);
extern __DPCPP_SYCL_EXTERNAL uint16_t
__imf_ull2bfloat16_rz(unsigned long long x);
extern __DPCPP_SYCL_EXTERNAL uint16_t __imf_short2bfloat16_rd(short x);
extern __DPCPP_SYCL_EXTERNAL uint16_t __imf_short2bfloat16_rn(short x);
extern __DPCPP_SYCL_EXTERNAL uint16_t __imf_short2bfloat16_ru(short x);
extern __DPCPP_SYCL_EXTERNAL uint16_t __imf_short2bfloat16_rz(short x);
extern __DPCPP_SYCL_EXTERNAL uint16_t __imf_int2bfloat16_rd(int x);
extern __DPCPP_SYCL_EXTERNAL uint16_t __imf_int2bfloat16_rn(int x);
extern __DPCPP_SYCL_EXTERNAL uint16_t __imf_int2bfloat16_ru(int x);
extern __DPCPP_SYCL_EXTERNAL uint16_t __imf_int2bfloat16_rz(int x);
extern __DPCPP_SYCL_EXTERNAL uint16_t __imf_ll2bfloat16_rd(long long x);
extern __DPCPP_SYCL_EXTERNAL uint16_t __imf_ll2bfloat16_rn(long long x);
extern __DPCPP_SYCL_EXTERNAL uint16_t __imf_ll2bfloat16_ru(long long x);
extern __DPCPP_SYCL_EXTERNAL uint16_t __imf_ll2bfloat16_rz(long long x);
extern __DPCPP_SYCL_EXTERNAL uint16_t __imf_double2bfloat16(double x);
extern __DPCPP_SYCL_EXTERNAL uint16_t __imf_short_as_bfloat16(short x);
extern __DPCPP_SYCL_EXTERNAL uint16_t
__imf_ushort_as_bfloat16(unsigned short x);
}
#endif // __SYCL_DEVICE_ONLY__ && (defined(__SPIR__) || defined(__SPIRV__))

namespace sycl {

// Rounding mode selector used as a template argument by the conversion
// helpers. The BF16 helpers later in this file map `automatic` and `rte` to
// the `_rn` intrinsics, `rtp` to `_ru`, `rtn` to `_rd`, and `rtz` to `_rz`.
// NOTE(review): presumably rte = to nearest even, rtz = toward zero,
// rtp = toward +infinity, rtn = toward -infinity, as in the SYCL
// specification - confirm against the spec.
enum class rounding_mode { automatic = 0, rte = 1, rtz = 2, rtp = 3, rtn = 4 };
Expand All @@ -81,6 +169,10 @@ inline double trunc(double);
#endif
namespace detail {

// Forward declaration of the generic conversion entry point (defined further
// down in this file). It is needed because the host-side BF16 scalar helper
// calls convertImpl before the definition is seen.
template <typename FromT, typename ToT, sycl::rounding_mode RoundingMode,
int VecSize, typename NativeFromT, typename NativeToT>
NativeToT convertImpl(NativeFromT);

// Compile-time predicate: true when both the source type T and the
// destination type R satisfy is_sigeninteger_v.
template <typename T, typename R>
using is_sint_to_sint =
std::bool_constant<is_sigeninteger_v<T> && is_sigeninteger_v<R>>;
Expand Down Expand Up @@ -123,6 +215,8 @@ using is_float_to_float =
std::bool_constant<detail::is_floating_point<T>::value &&
detail::is_floating_point<R>::value>;

// Short alias so the conversion helpers in this namespace can refer to the
// oneAPI bfloat16 extension type without full qualification.
using bfloat16 = sycl::ext::oneapi::bfloat16;

#ifndef __SYCL_DEVICE_ONLY__
template <typename From, typename To, int VecSize,
typename Enable = std::enable_if_t<VecSize == 1>>
Expand Down Expand Up @@ -196,8 +290,29 @@ template <typename From, typename To, int VecSize,
To ConvertFToU(From Value) {
return ConvertFToS<From, To, VecSize, Enable, roundingMode>(Value);
}
#else

// Host-side scalar conversion from bfloat16 to NativeToT.
// The input is first widened to float, which is lossless. When float is the
// requested result that value is returned as is; otherwise the generic
// float -> NativeToT conversion is applied with the requested RoundingMode.
template <typename NativeToT, sycl::rounding_mode RoundingMode>
inline NativeToT ConvertFromBF16Scalar(bfloat16 val) {
  const float Widened = static_cast<float>(val);

  if constexpr (!std::is_same_v<NativeToT, float>) {
    // Second hop: float to the requested destination type.
    return convertImpl<float, NativeToT, RoundingMode, 1, float, NativeToT>(
        Widened);
  } else {
    return Widened;
  }
}

// Host-side scalar conversion from NativeFromT to bfloat16.
// The rounding mode is handed to the bfloat16 conversion helper as the
// integer value of the sycl::rounding_mode enumerator.
template <typename NativeFromT, sycl::rounding_mode RoundingMode>
bfloat16 ConvertToBF16Scalar(NativeFromT val) {
  constexpr int ModeAsInt = static_cast<int>(RoundingMode);
  return sycl::ext::oneapi::detail::ConvertToBfloat16::
      getBfloat16WithRoundingMode<NativeFromT, ModeAsInt>(val);
}

#else
// Bunch of helpers to "specialize" each template for its own destination type
// and vector size.

Expand Down Expand Up @@ -498,8 +613,188 @@ __SYCL_FLOAT_FLOAT_CONVERT_FOR_TYPE(double)
#undef __SYCL_FLOAT_FLOAT_CONVERT
#undef __SYCL_FLOAT_FLOAT_CONVERT_FOR_TYPE

template <typename NativeBFT, typename NativeFloatT, int VecSize>
inline NativeFloatT ConvertBF16ToFVec(NativeBFT vec) {
bfloat16 *src = sycl::bit_cast<bfloat16 *>(&vec);

// OpenCL vector of 3 elements is aligned to 4 multiplied by
// the size of data type.
constexpr int AdjustedSize = (VecSize == 3) ? 4 : VecSize;
float dst[AdjustedSize];
sycl::ext::oneapi::detail::BF16VecToFloatVec<VecSize>(src, dst);

return sycl::bit_cast<NativeFloatT>(dst);
}

template <typename NativeFloatT, typename NativeBFT, int VecSize>
inline NativeBFT ConvertFToBF16Vec(NativeFloatT vec) {
float *src = sycl::bit_cast<float *>(&vec);

// OpenCL vector of 3 elements is aligned to 4 multiplied by
// the size of data type.
constexpr int AdjustedSize = (VecSize == 3) ? 4 : VecSize;
bfloat16 dst[AdjustedSize];

sycl::ext::oneapi::detail::FloatVecToBF16Vec<VecSize>(src, dst);
return sycl::bit_cast<NativeBFT>(dst);
}

// EXPAND_BF16_ROUNDING_MODE(type, type_str, rmode, rmode_str) defines a pair
// of device-side scalar overloads, both enabled only when the native type is
// exactly `type` and RoundingMode == rmode:
//   * ConvertFromBF16Scalar(uint16_t)  : BF16 bits -> type
//   * ConvertToBF16Scalar(NativeFromT) : type -> BF16 bits (uint16_t)
// `type_str` and `rmode_str` are only used to paste together the name of the
// matching __imf_* function in the SPIR / SPIR-V variant.

/* Emit __imf_* calls only for SPIR / SPIR-V targets (Intel hardware). */
#if defined(__SPIR__) || defined(__SPIRV__)
#define EXPAND_BF16_ROUNDING_MODE(type, type_str, rmode, rmode_str)            \
  template <typename NativeToT, sycl::rounding_mode RoundingMode>              \
  std::enable_if_t<(std::is_same_v<NativeToT, type> && RoundingMode == rmode), \
                   NativeToT>                                                  \
  ConvertFromBF16Scalar(uint16_t val) {                                        \
    return __imf_bfloat162##type_str##_##rmode_str(val);                       \
  }                                                                            \
  template <typename NativeFromT, sycl::rounding_mode RoundingMode>            \
  std::enable_if_t<                                                            \
      (std::is_same_v<NativeFromT, type> && RoundingMode == rmode), uint16_t>  \
  ConvertToBF16Scalar(NativeFromT val) {                                       \
    return __imf_##type_str##2bfloat16_##rmode_str(val);                       \
  }

#else // !(defined(__SPIR__) || defined(__SPIRV__))
// On other targets, convert BF16 to float (losslessly) and then convert the
// float to the desired type through the generic convertImpl path. The first
// template argument of convertImpl is the source *type* (float).
#define EXPAND_BF16_ROUNDING_MODE(type, type_str, rmode, rmode_str)            \
  template <typename NativeToT, sycl::rounding_mode RoundingMode>              \
  std::enable_if_t<(std::is_same_v<NativeToT, type> && RoundingMode == rmode), \
                   NativeToT>                                                  \
  ConvertFromBF16Scalar(uint16_t val) {                                        \
    bfloat16 bfval = sycl::bit_cast<bfloat16>(val);                            \
    float fval = static_cast<float>(bfval);                                    \
    return convertImpl<float, NativeToT, RoundingMode, 1, float, NativeToT>(   \
        fval);                                                                 \
  }                                                                            \
  template <typename NativeFromT, sycl::rounding_mode RoundingMode>            \
  std::enable_if_t<                                                            \
      (std::is_same_v<NativeFromT, type> && RoundingMode == rmode), uint16_t>  \
  ConvertToBF16Scalar(NativeFromT val) {                                       \
    constexpr int rm = static_cast<int>(RoundingMode);                         \
    bfloat16 bfval = sycl::ext::oneapi::detail::ConvertToBfloat16::            \
        getBfloat16WithRoundingMode<NativeFromT, rm>(val);                     \
    return sycl::bit_cast<uint16_t>(bfval);                                    \
  }
#endif // defined(__SPIR__) || defined(__SPIRV__)

// EXPAND_BF16_TYPE(type, type_str) instantiates EXPAND_BF16_ROUNDING_MODE once
// per sycl::rounding_mode enumerator for the given native type. The last
// argument picks the __imf_* suffix: automatic and rte both use `rn`,
// rtp uses `ru`, rtn uses `rd`, and rtz uses `rz`.
#define EXPAND_BF16_TYPE(type, type_str) \
EXPAND_BF16_ROUNDING_MODE(type, type_str, sycl::rounding_mode::automatic, \
rn) \
EXPAND_BF16_ROUNDING_MODE(type, type_str, sycl::rounding_mode::rte, rn) \
EXPAND_BF16_ROUNDING_MODE(type, type_str, sycl::rounding_mode::rtp, ru) \
EXPAND_BF16_ROUNDING_MODE(type, type_str, sycl::rounding_mode::rtn, rd) \
EXPAND_BF16_ROUNDING_MODE(type, type_str, sycl::rounding_mode::rtz, rz)

// Integer types for which BF16 scalar conversions are generated.
// NOTE(review): `uint` and `ushort` are not standard C++ type names;
// presumably they are provided by the device compilation environment -
// verify.
// NOTE(review): the signed 64-bit entry uses `long` while the `ll` functions
// are declared with `long long`, and the unsigned entry uses
// `unsigned long long`; confirm these match the native element types on every
// supported target.
EXPAND_BF16_TYPE(uint, uint)
EXPAND_BF16_TYPE(int, int)
EXPAND_BF16_TYPE(ushort, ushort)
EXPAND_BF16_TYPE(short, short)
EXPAND_BF16_TYPE(long, ll)
EXPAND_BF16_TYPE(unsigned long long, ull)

// The helper macros are implementation details; do not leak them.
#undef EXPAND_BF16_TYPE
#undef EXPAND_BF16_ROUNDING_MODE

// Device-side BF16 -> float scalar overload.
// Every BF16 value maps 1:1 (losslessly) to a float, so the result does not
// depend on RoundingMode and all rounding modes are accepted.
template <typename NativeToT, sycl::rounding_mode RoundingMode>
std::enable_if_t<std::is_same_v<NativeToT, float>, NativeToT>
ConvertFromBF16Scalar(uint16_t val) {
  // Reinterpret the raw 16 bits as bfloat16, then widen to float.
  return static_cast<float>(sycl::bit_cast<bfloat16>(val));
}

template <typename NativeFromT, sycl::rounding_mode RoundingMode>
std::enable_if_t<std::is_same_v<NativeFromT, double>, uint16_t>
ConvertToBF16Scalar(NativeFromT val) {
#if defined(__SPIR__) || defined(__SPIRV__)
return __imf_double2bfloat16(val);
#else
constexpr int rm = static_cast<int>(RoundingMode);
bfloat16 bfval =
sycl::ext::oneapi::detail::ConvertToBfloat16::getBfloat16WithRoundingMode<
NativeFromT, rm>(val);
return sycl::bit_cast<uint16_t>(bfval);
#endif
}

// Device-side float -> BF16 scalar overload; returns the raw BF16 bits.
// On SPIR / SPIR-V targets the __imf_float2bfloat16_* function matching
// RoundingMode is called (automatic and rte share the _rn variant). On other
// targets the bfloat16 conversion helper is used with the requested mode.
template <typename NativeFromT, sycl::rounding_mode RoundingMode>
std::enable_if_t<std::is_same_v<NativeFromT, float>, uint16_t>
ConvertToBF16Scalar(NativeFromT val) {

#if defined(__SPIR__) || defined(__SPIRV__)
  if constexpr (RoundingMode == sycl::rounding_mode::automatic ||
                RoundingMode == sycl::rounding_mode::rte)
    return __imf_float2bfloat16_rn(val);
  else if constexpr (RoundingMode == sycl::rounding_mode::rtp)
    return __imf_float2bfloat16_ru(val);
  else if constexpr (RoundingMode == sycl::rounding_mode::rtn)
    return __imf_float2bfloat16_rd(val);
  else if constexpr (RoundingMode == sycl::rounding_mode::rtz)
    return __imf_float2bfloat16_rz(val);
  else
    // The condition is deliberately type-dependent: a plain
    // static_assert(false, ...) in a discarded branch is rejected by
    // compilers that predate CWG2518.
    static_assert(!std::is_same_v<NativeFromT, NativeFromT>,
                  "Invalid rounding mode.");
#else
  constexpr int rm = static_cast<int>(RoundingMode);
  bfloat16 bfval =
      sycl::ext::oneapi::detail::ConvertToBfloat16::getBfloat16WithRoundingMode<
          float, rm>(val);
  return sycl::bit_cast<uint16_t>(bfval);
#endif
}

#endif // __SYCL_DEVICE_ONLY__

// Dispatcher for conversions *from* BF16, covering scalars and vectors.
//  - Device, float destination, VecSize > 1: whole-vector conversion
//    (valid for every rounding mode).
//  - VecSize > 1 otherwise: convert element by element (per the original
//    note, this path is only used on device).
//  - VecSize <= 1: single scalar conversion.
template <typename ToT, typename NativeFromT, typename NativeToT,
          sycl::rounding_mode RoundingMode, int VecSize>
NativeToT ConvertFromBF16(NativeFromT val) {
#ifdef __SYCL_DEVICE_ONLY__
  if constexpr (std::is_same_v<ToT, float> && VecSize > 1)
    return ConvertBF16ToFVec<NativeFromT, NativeToT, VecSize>(val);
  else
#endif
      if constexpr (VecSize <= 1) {
    return ConvertFromBF16Scalar<NativeToT, RoundingMode>(val);
  } else {
    NativeToT Converted;
    for (int Lane = 0; Lane < VecSize; ++Lane)
      Converted[Lane] = ConvertFromBF16Scalar<ToT, RoundingMode>(val[Lane]);
    return Converted;
  }
}

// Dispatcher for conversions *to* BF16, covering scalars and vectors.
//  - Device, float source, VecSize > 1, and RoundingMode automatic or rte:
//    whole-vector conversion.
//  - VecSize > 1 otherwise: convert element by element (per the original
//    note, this path is only used on device).
//  - VecSize <= 1: single scalar conversion.
template <typename FromT, typename NativeFromT, typename NativeToT,
          sycl::rounding_mode RoundingMode, int VecSize>
NativeToT ConvertToBF16(NativeFromT val) {
#ifdef __SYCL_DEVICE_ONLY__
  if constexpr (std::is_same_v<FromT, float> && VecSize > 1 &&
                (RoundingMode == sycl::rounding_mode::automatic ||
                 RoundingMode == sycl::rounding_mode::rte))
    return ConvertFToBF16Vec<NativeFromT, NativeToT, VecSize>(val);
  else
#endif
      if constexpr (VecSize <= 1) {
    return ConvertToBF16Scalar<NativeFromT, RoundingMode>(val);
  } else {
    NativeToT Converted;
    for (int Lane = 0; Lane < VecSize; ++Lane)
      Converted[Lane] = ConvertToBF16Scalar<FromT, RoundingMode>(val[Lane]);
    return Converted;
  }
}

/// Entry point helper for all kinds of converts between scalars and vectors, it
/// dispatches to a right function depending on source and destination types.
///
Expand Down Expand Up @@ -537,6 +832,14 @@ NativeToT convertImpl(NativeFromT Value) {
else if constexpr (is_float_to_float<FromT, ToT>::value)
return FConvert<NativeFromT, NativeToT, VecSize, ElemTy, RoundingMode>(
Value);
// BF16 conversion to other types.
else if constexpr (std::is_same_v<FromT, bfloat16>)
return ConvertFromBF16<ToT, NativeFromT, NativeToT, RoundingMode, VecSize>(
Value);
// conversion from other types to BF16.
else if constexpr (std::is_same_v<ToT, bfloat16>)
return ConvertToBF16<FromT, NativeFromT, NativeToT, RoundingMode, VecSize>(
Value);
else if constexpr (is_float_to_sint<FromT, ToT>::value)
return ConvertFToS<NativeFromT, NativeToT, VecSize, ElemTy, RoundingMode>(
Value);
Expand Down
Loading

0 comments on commit 02c6bba

Please sign in to comment.