Add BFloat16 support to the runtime floating-point intrinsics.

* Add bfloat_to_float(), which widens a 16-bit bfloat by placing its bits in
  the upper half of a 32-bit float.
* Add bfloat variants of the unary, binary, boolean and ternary intrinsic
  macros; each converts its operands to float, applies OP, and (where a
  value is produced) converts the result back with float_to_bfloat().
* Pass the result type through the unary float intrinsic ops so that
  fptrunc can pick the Float16 or the BFloat16 encoding for 16-bit output.
* In every 2-byte dispatch, Float16 is selected by an explicit type check and
  all other 2-byte types take the bfloat path (the explicit BFloat16 check
  is left commented out).

diff --git a/src/runtime_intrinsics.c b/src/runtime_intrinsics.c
index 588c0359f70be..0c4021570ef21 100644
--- a/src/runtime_intrinsics.c
+++ b/src/runtime_intrinsics.c
@@ -328,6 +328,14 @@ static inline uint16_t double_to_bfloat(double param) JL_NOTSAFEPOINT
     return float_to_bfloat(temp);
 }
 
+static inline float bfloat_to_float(uint16_t param) JL_NOTSAFEPOINT
+{
+    uint32_t bits = ((uint32_t)param) << 16;
+    float result;
+    memcpy(&result, &bits, sizeof(result));
+    return result;
+}
+
 // bfloat16 conversion API
 
 // starting with GCC 13 and Clang 17, we have __bf16 on most platforms
@@ -726,25 +734,39 @@ static inline unsigned jl_##name##nbits(unsigned runtime_nbits, void *pa) JL_NOT
 // nbits::number of bits in the *input*
 // c_type::c_type corresponding to nbits
 #define un_fintrinsic_ctype(OP, name, c_type) \
-static inline void name(unsigned osize, void *pa, void *pr) JL_NOTSAFEPOINT \
+static inline void name(unsigned osize, jl_value_t *ty, void *pa, void *pr) JL_NOTSAFEPOINT \
 { \
     c_type a = *(c_type*)pa; \
-    OP((c_type*)pr, a); \
+    OP(ty, (c_type*)pr, a); \
 }
 
 #define un_fintrinsic_half(OP, name) \
-static inline void name(unsigned osize, void *pa, void *pr) JL_NOTSAFEPOINT \
+static inline void name(unsigned osize, jl_value_t *ty, void *pa, void *pr) JL_NOTSAFEPOINT \
 { \
     uint16_t a = *(uint16_t*)pa; \
     float A = half_to_float(a); \
     if (osize == 16) { \
         float R; \
-        OP(&R, A); \
+        OP(ty, &R, A); \
         *(uint16_t*)pr = float_to_half(R); \
     } else { \
-        OP((uint16_t*)pr, A); \
+        OP(ty, (uint16_t*)pr, A); \
     } \
-  }
+}
+
+#define un_fintrinsic_bfloat(OP, name) \
+static inline void name(unsigned osize, jl_value_t *ty, void *pa, void *pr) JL_NOTSAFEPOINT \
+{ \
+    uint16_t a = *(uint16_t*)pa; \
+    float A = bfloat_to_float(a); \
+    if (osize == 16) { \
+        float R; \
+        OP(ty, &R, A); \
+        *(uint16_t*)pr = float_to_bfloat(R); \
+    } else { \
+        OP(ty, (uint16_t*)pr, A); \
+    } \
+}
 
 // float or integer inputs
 // OP::Function macro(inputa, inputb)
@@ -769,6 +791,18 @@ static void jl_##name##16(unsigned runtime_nbits, void *pa, void *pb, void *pr)
     runtime_nbits = 16; \
     float R = OP(A, B); \
     *(uint16_t*)pr = float_to_half(R); \
+}
+
+#define bi_intrinsic_bfloat(OP, name) \
+static void jl_##name##bf16(unsigned runtime_nbits, void *pa, void *pb, void *pr) JL_NOTSAFEPOINT \
+{ \
+    uint16_t a = *(uint16_t*)pa; \
+    uint16_t b = *(uint16_t*)pb; \
+    float A = bfloat_to_float(a); \
+    float B = bfloat_to_float(b); \
+    runtime_nbits = 16; \
+    float R = OP(A, B); \
+    *(uint16_t*)pr = float_to_bfloat(R); \
 }
 
 // float or integer inputs, bool output
@@ -795,6 +829,17 @@ static int jl_##name##16(unsigned runtime_nbits, void *pa, void *pb) JL_NOTSAFEP
     return OP(A, B); \
 }
 
+#define bool_intrinsic_bfloat(OP, name) \
+static int jl_##name##bf16(unsigned runtime_nbits, void *pa, void *pb) JL_NOTSAFEPOINT \
+{ \
+    uint16_t a = *(uint16_t*)pa; \
+    uint16_t b = *(uint16_t*)pb; \
+    float A = bfloat_to_float(a); \
+    float B = bfloat_to_float(b); \
+    runtime_nbits = 16; \
+    return OP(A, B); \
+}
+
 
 // integer inputs, with precondition test
 // OP::Function macro(inputa, inputb)
@@ -836,6 +881,20 @@ static void jl_##name##16(unsigned runtime_nbits, void *pa, void *pb, void *pc,
     runtime_nbits = 16; \
     float R = OP(A, B, C); \
     *(uint16_t*)pr = float_to_half(R); \
+}
+
+#define ter_intrinsic_bfloat(OP, name) \
+static void jl_##name##bf16(unsigned runtime_nbits, void *pa, void *pb, void *pc, void *pr) JL_NOTSAFEPOINT \
+{ \
+    uint16_t a = *(uint16_t*)pa; \
+    uint16_t b = *(uint16_t*)pb; \
+    uint16_t c = *(uint16_t*)pc; \
+    float A = bfloat_to_float(a); \
+    float B = bfloat_to_float(b); \
+    float C = bfloat_to_float(c); \
+    runtime_nbits = 16; \
+    float R = OP(A, B, C); \
+    *(uint16_t*)pr = float_to_bfloat(R); \
 }
 
 
@@ -980,12 +1039,13 @@ static inline jl_value_t *jl_intrinsic_cvt(jl_value_t *ty, jl_value_t *a, const
 
 // floating point
 #define un_fintrinsic_withtype(OP, name) \
+un_fintrinsic_bfloat(OP, jl_##name##bf16) \
 un_fintrinsic_half(OP, jl_##name##16) \
 un_fintrinsic_ctype(OP, jl_##name##32, float) \
 un_fintrinsic_ctype(OP, jl_##name##64, double) \
 JL_DLLEXPORT jl_value_t *jl_##name(jl_value_t *ty, jl_value_t *a) \
 { \
-    return jl_fintrinsic_1(ty, a, #name, jl_##name##16, jl_##name##32, jl_##name##64); \
+    return jl_fintrinsic_1(ty, a, #name, jl_##name##bf16, jl_##name##16, jl_##name##32, jl_##name##64); \
 }
 
 #define un_fintrinsic(OP, name) \
@@ -995,9 +1055,9 @@ JL_DLLEXPORT jl_value_t *jl_##name(jl_value_t *a) \
     return jl_##name##_withtype(jl_typeof(a), a); \
 }
 
-typedef void (fintrinsic_op1)(unsigned, void*, void*);
+typedef void (fintrinsic_op1)(unsigned, jl_value_t*, void*, void*);
 
-static inline jl_value_t *jl_fintrinsic_1(jl_value_t *ty, jl_value_t *a, const char *name, fintrinsic_op1 *halfop, fintrinsic_op1 *floatop, fintrinsic_op1 *doubleop)
+static inline jl_value_t *jl_fintrinsic_1(jl_value_t *ty, jl_value_t *a, const char *name, fintrinsic_op1 *bfloatop, fintrinsic_op1 *halfop, fintrinsic_op1 *floatop, fintrinsic_op1 *doubleop)
 {
     jl_task_t *ct = jl_current_task;
     if (!jl_is_primitivetype(jl_typeof(a)))
@@ -1011,13 +1071,16 @@ static inline jl_value_t *jl_fintrinsic_1(jl_value_t *ty, jl_value_t *a, const c
     switch (sz) {
     /* choose the right size c-type operation based on the input */
     case 2:
-        halfop(sz2 * host_char_bit, pa, pr);
+        if (jl_typeof(a) == (jl_value_t*)jl_float16_type)
+            halfop(sz2 * host_char_bit, ty, pa, pr);
+        else /*if (jl_typeof(a) == (jl_value_t*)jl_bfloat16_type)*/
+            bfloatop(sz2 * host_char_bit, ty, pa, pr);
         break;
     case 4:
-        floatop(sz2 * host_char_bit, pa, pr);
+        floatop(sz2 * host_char_bit, ty, pa, pr);
         break;
     case 8:
-        doubleop(sz2 * host_char_bit, pa, pr);
+        doubleop(sz2 * host_char_bit, ty, pa, pr);
         break;
     default:
         jl_errorf("%s: runtime floating point intrinsics are not implemented for bit sizes other than 16, 32 and 64", name);
@@ -1189,6 +1252,7 @@ static inline jl_value_t *jl_intrinsiclambda_checkeddiv(jl_value_t *ty, void *pa
 
 // floating point
 #define bi_fintrinsic(OP, name) \
+    bi_intrinsic_bfloat(OP, name) \
     bi_intrinsic_half(OP, name) \
     bi_intrinsic_ctype(OP, name, 32, float) \
     bi_intrinsic_ctype(OP, name, 64, double) \
@@ -1206,7 +1270,10 @@ JL_DLLEXPORT jl_value_t *jl_##name(jl_value_t *a, jl_value_t *b) \
     switch (sz) { \
     /* choose the right size c-type operation */ \
     case 2: \
-        jl_##name##16(16, pa, pb, pr); \
+        if ((jl_datatype_t*)ty == jl_float16_type) \
+            jl_##name##16(16, pa, pb, pr); \
+        else /*if ((jl_datatype_t*)ty == jl_bfloat16_type)*/ \
+            jl_##name##bf16(16, pa, pb, pr); \
         break; \
     case 4: \
         jl_##name##32(32, pa, pb, pr); \
@@ -1221,6 +1288,7 @@ JL_DLLEXPORT jl_value_t *jl_##name(jl_value_t *a, jl_value_t *b) \
 }
 
 #define bool_fintrinsic(OP, name) \
+    bool_intrinsic_bfloat(OP, name) \
     bool_intrinsic_half(OP, name) \
     bool_intrinsic_ctype(OP, name, 32, float) \
     bool_intrinsic_ctype(OP, name, 64, double) \
@@ -1237,7 +1305,10 @@ JL_DLLEXPORT jl_value_t *jl_##name(jl_value_t *a, jl_value_t *b) \
     switch (sz) { \
     /* choose the right size c-type operation */ \
     case 2: \
-        cmp = jl_##name##16(16, pa, pb); \
+        if ((jl_datatype_t*)ty == jl_float16_type) \
+            cmp = jl_##name##16(16, pa, pb); \
+        else /*if ((jl_datatype_t*)ty == jl_bfloat16_type)*/ \
+            cmp = jl_##name##bf16(16, pa, pb); \
         break; \
     case 4: \
         cmp = jl_##name##32(32, pa, pb); \
@@ -1252,6 +1323,7 @@ JL_DLLEXPORT jl_value_t *jl_##name(jl_value_t *a, jl_value_t *b) \
 }
 
 #define ter_fintrinsic(OP, name) \
+    ter_intrinsic_bfloat(OP, name) \
     ter_intrinsic_half(OP, name) \
     ter_intrinsic_ctype(OP, name, 32, float) \
     ter_intrinsic_ctype(OP, name, 64, double) \
@@ -1269,7 +1341,10 @@ JL_DLLEXPORT jl_value_t *jl_##name(jl_value_t *a, jl_value_t *b, jl_value_t *c)
     switch (sz) { \
     /* choose the right size c-type operation */ \
     case 2: \
-        jl_##name##16(16, pa, pb, pc, pr); \
+        if ((jl_datatype_t*)ty == jl_float16_type) \
+            jl_##name##16(16, pa, pb, pc, pr); \
+        else /*if ((jl_datatype_t*)ty == jl_bfloat16_type)*/ \
+            jl_##name##bf16(16, pa, pb, pc, pr); \
         break; \
     case 4: \
         jl_##name##32(32, pa, pb, pc, pr); \
@@ -1285,7 +1360,7 @@ JL_DLLEXPORT jl_value_t *jl_##name(jl_value_t *a, jl_value_t *b, jl_value_t *c)
 
 // arithmetic
 #define neg(a) -a
-#define neg_float(pr, a) *pr = -a
+#define neg_float(ty, pr, a) *pr = -a
 un_iintrinsic_fast(LLVMNeg, neg, neg_int, u)
 #define add(a,b) a + b
 bi_iintrinsic_fast(LLVMAdd, add, add_int, u)
@@ -1501,18 +1576,22 @@ cvt_iintrinsic(LLVMUItoFP, uitofp)
 cvt_iintrinsic(LLVMFPtoSI, fptosi)
 cvt_iintrinsic(LLVMFPtoUI, fptoui)
 
-#define fptrunc(pr, a) \
+#define fptrunc(tr, pr, a) \
         if (!(osize < 8 * sizeof(a))) \
             jl_error("fptrunc: output bitsize must be < input bitsize"); \
-        else if (osize == 16) \
-            *(uint16_t*)pr = float_to_half(a); \
+        else if (osize == 16) { \
+            if ((jl_datatype_t*)tr == jl_float16_type) \
+                *(uint16_t*)pr = float_to_half(a); \
+            else /*if ((jl_datatype_t*)tr == jl_bfloat16_type)*/ \
+                *(uint16_t*)pr = float_to_bfloat(a); \
+        } \
         else if (osize == 32) \
             *(float*)pr = a; \
         else if (osize == 64) \
             *(double*)pr = a; \
         else \
             jl_error("fptrunc: runtime floating point intrinsics are not implemented for bit sizes other than 16, 32 and 64");
-#define fpext(pr, a) \
+#define fpext(tr, pr, a) \
         if (!(osize >= 8 * sizeof(a))) \
             jl_error("fpext: output bitsize must be >= input bitsize"); \
         if (osize == 32) \
@@ -1569,12 +1648,12 @@ checked_iintrinsic_div(LLVMRem_uov, checked_urem_int, u)
 #define flipsign(a, b) \
         (b >= 0) ? a : -a
 bi_iintrinsic_fast(jl_LLVMFlipSign, flipsign, flipsign_int, )
-#define abs_float(pr, a) *pr = fp_select(a, fabs)
-#define ceil_float(pr, a) *pr = fp_select(a, ceil)
-#define floor_float(pr, a) *pr = fp_select(a, floor)
-#define trunc_float(pr, a) *pr = fp_select(a, trunc)
-#define rint_float(pr, a) *pr = fp_select(a, rint)
-#define sqrt_float(pr, a) *pr = fp_select(a, sqrt)
+#define abs_float(ty, pr, a) *pr = fp_select(a, fabs)
+#define ceil_float(ty, pr, a) *pr = fp_select(a, ceil)
+#define floor_float(ty, pr, a) *pr = fp_select(a, floor)
+#define trunc_float(ty, pr, a) *pr = fp_select(a, trunc)
+#define rint_float(ty, pr, a) *pr = fp_select(a, rint)
+#define sqrt_float(ty, pr, a) *pr = fp_select(a, sqrt)
 #define copysign_float(a, b) fp_select2(a, b, copysign)
 
 un_fintrinsic(abs_float,abs_float)