Skip to content

Commit

Permalink
Add BFloat16 runtime intrinsics. (#51790)
Browse files Browse the repository at this point in the history
After switching to LLVM for BFloat16 in #51470 (i.e., relying on
`Intrinsics.sub_float` etc. instead of hand-rolling bit-twiddling
implementations), we also need to provide fallback runtime
implementations for these intrinsics. This is too bad; I had hoped to
put as many BFloat16-related things as possible in BFloat16s.jl.

This required modifying the unary operator preprocessor macros in order
to differentiate between Float16 and BFloat16; I didn't generalize that to
all intrinsics as the code is hairy enough already (and it's currently
only useful for fptrunc/fpext).
  • Loading branch information
maleadt authored Oct 25, 2023
1 parent bb138fa commit a1ccf53
Showing 1 changed file with 107 additions and 26 deletions.
133 changes: 107 additions & 26 deletions src/runtime_intrinsics.c
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,14 @@ static inline uint16_t double_to_bfloat(double param) JL_NOTSAFEPOINT
return float_to_bfloat(temp);
}

// Widen a 16-bit bfloat pattern to a float.
// The 16 input bits become the high half of a 32-bit word whose low half is
// zero; that word is then reinterpreted bit-for-bit as a float.
static inline float bfloat_to_float(uint16_t param) JL_NOTSAFEPOINT
{
    const uint32_t widened = (uint32_t)param << 16;
    float value;
    // copy the bits rather than casting pointers, so no aliasing rules are broken
    memcpy(&value, &widened, sizeof value);
    return value;
}

// bfloat16 conversion API

// starting with GCC 13 and Clang 17, we have __bf16 on most platforms
Expand Down Expand Up @@ -726,25 +734,39 @@ static inline unsigned jl_##name##nbits(unsigned runtime_nbits, void *pa) JL_NOT
// nbits::number of bits in the *input*
// c_type::c_type corresponding to nbits
#define un_fintrinsic_ctype(OP, name, c_type) \
static inline void name(unsigned osize, void *pa, void *pr) JL_NOTSAFEPOINT \
static inline void name(unsigned osize, jl_value_t *ty, void *pa, void *pr) JL_NOTSAFEPOINT \
{ \
c_type a = *(c_type*)pa; \
OP((c_type*)pr, a); \
OP(ty, (c_type*)pr, a); \
}

#define un_fintrinsic_half(OP, name) \
static inline void name(unsigned osize, void *pa, void *pr) JL_NOTSAFEPOINT \
static inline void name(unsigned osize, jl_value_t *ty, void *pa, void *pr) JL_NOTSAFEPOINT \
{ \
uint16_t a = *(uint16_t*)pa; \
float A = half_to_float(a); \
if (osize == 16) { \
float R; \
OP(&R, A); \
OP(ty, &R, A); \
*(uint16_t*)pr = float_to_half(R); \
} else { \
OP((uint16_t*)pr, A); \
OP(ty, (uint16_t*)pr, A); \
} \
}
}

#define un_fintrinsic_bfloat(OP, name) \
static inline void name(unsigned osize, jl_value_t *ty, void *pa, void *pr) JL_NOTSAFEPOINT \
{ \
uint16_t a = *(uint16_t*)pa; \
float A = bfloat_to_float(a); \
if (osize == 16) { \
float R; \
OP(ty, &R, A); \
*(uint16_t*)pr = float_to_bfloat(R); \
} else { \
OP(ty, (uint16_t*)pr, A); \
} \
}

// float or integer inputs
// OP::Function macro(inputa, inputb)
Expand All @@ -769,6 +791,19 @@ static void jl_##name##16(unsigned runtime_nbits, void *pa, void *pb, void *pr)
runtime_nbits = 16; \
float R = OP(A, B); \
*(uint16_t*)pr = float_to_half(R); \
*(uint16_t*)pr = float_to_half(R); \
}

#define bi_intrinsic_bfloat(OP, name) \
static void jl_##name##bf16(unsigned runtime_nbits, void *pa, void *pb, void *pr) JL_NOTSAFEPOINT \
{ \
uint16_t a = *(uint16_t*)pa; \
uint16_t b = *(uint16_t*)pb; \
float A = bfloat_to_float(a); \
float B = bfloat_to_float(b); \
runtime_nbits = 16; \
float R = OP(A, B); \
*(uint16_t*)pr = float_to_bfloat(R); \
}

// float or integer inputs, bool output
Expand All @@ -795,6 +830,17 @@ static int jl_##name##16(unsigned runtime_nbits, void *pa, void *pb) JL_NOTSAFEP
return OP(A, B); \
}

#define bool_intrinsic_bfloat(OP, name) \
static int jl_##name##bf16(unsigned runtime_nbits, void *pa, void *pb) JL_NOTSAFEPOINT \
{ \
uint16_t a = *(uint16_t*)pa; \
uint16_t b = *(uint16_t*)pb; \
float A = bfloat_to_float(a); \
float B = bfloat_to_float(b); \
runtime_nbits = 16; \
return OP(A, B); \
}


// integer inputs, with precondition test
// OP::Function macro(inputa, inputb)
Expand Down Expand Up @@ -836,6 +882,21 @@ static void jl_##name##16(unsigned runtime_nbits, void *pa, void *pb, void *pc,
runtime_nbits = 16; \
float R = OP(A, B, C); \
*(uint16_t*)pr = float_to_half(R); \
*(uint16_t*)pr = float_to_half(R); \
}

#define ter_intrinsic_bfloat(OP, name) \
static void jl_##name##bf16(unsigned runtime_nbits, void *pa, void *pb, void *pc, void *pr) JL_NOTSAFEPOINT \
{ \
uint16_t a = *(uint16_t*)pa; \
uint16_t b = *(uint16_t*)pb; \
uint16_t c = *(uint16_t*)pc; \
float A = bfloat_to_float(a); \
float B = bfloat_to_float(b); \
float C = bfloat_to_float(c); \
runtime_nbits = 16; \
float R = OP(A, B, C); \
*(uint16_t*)pr = float_to_bfloat(R); \
}


Expand Down Expand Up @@ -980,12 +1041,13 @@ static inline jl_value_t *jl_intrinsic_cvt(jl_value_t *ty, jl_value_t *a, const
// floating point

#define un_fintrinsic_withtype(OP, name) \
un_fintrinsic_bfloat(OP, jl_##name##bf16) \
un_fintrinsic_half(OP, jl_##name##16) \
un_fintrinsic_ctype(OP, jl_##name##32, float) \
un_fintrinsic_ctype(OP, jl_##name##64, double) \
JL_DLLEXPORT jl_value_t *jl_##name(jl_value_t *ty, jl_value_t *a) \
{ \
return jl_fintrinsic_1(ty, a, #name, jl_##name##16, jl_##name##32, jl_##name##64); \
return jl_fintrinsic_1(ty, a, #name, jl_##name##bf16, jl_##name##16, jl_##name##32, jl_##name##64); \
}

#define un_fintrinsic(OP, name) \
Expand All @@ -995,9 +1057,9 @@ JL_DLLEXPORT jl_value_t *jl_##name(jl_value_t *a) \
return jl_##name##_withtype(jl_typeof(a), a); \
}

typedef void (fintrinsic_op1)(unsigned, void*, void*);
typedef void (fintrinsic_op1)(unsigned, jl_value_t*, void*, void*);

static inline jl_value_t *jl_fintrinsic_1(jl_value_t *ty, jl_value_t *a, const char *name, fintrinsic_op1 *halfop, fintrinsic_op1 *floatop, fintrinsic_op1 *doubleop)
static inline jl_value_t *jl_fintrinsic_1(jl_value_t *ty, jl_value_t *a, const char *name, fintrinsic_op1 *bfloatop, fintrinsic_op1 *halfop, fintrinsic_op1 *floatop, fintrinsic_op1 *doubleop)
{
jl_task_t *ct = jl_current_task;
if (!jl_is_primitivetype(jl_typeof(a)))
Expand All @@ -1011,13 +1073,16 @@ static inline jl_value_t *jl_fintrinsic_1(jl_value_t *ty, jl_value_t *a, const c
switch (sz) {
/* choose the right size c-type operation based on the input */
case 2:
halfop(sz2 * host_char_bit, pa, pr);
if (jl_typeof(a) == (jl_value_t*)jl_float16_type)
halfop(sz2 * host_char_bit, ty, pa, pr);
else /*if (jl_typeof(a) == (jl_value_t*)jl_bfloat16_type)*/
bfloatop(sz2 * host_char_bit, ty, pa, pr);
break;
case 4:
floatop(sz2 * host_char_bit, pa, pr);
floatop(sz2 * host_char_bit, ty, pa, pr);
break;
case 8:
doubleop(sz2 * host_char_bit, pa, pr);
doubleop(sz2 * host_char_bit, ty, pa, pr);
break;
default:
jl_errorf("%s: runtime floating point intrinsics are not implemented for bit sizes other than 16, 32 and 64", name);
Expand Down Expand Up @@ -1189,6 +1254,7 @@ static inline jl_value_t *jl_intrinsiclambda_checkeddiv(jl_value_t *ty, void *pa
// floating point

#define bi_fintrinsic(OP, name) \
bi_intrinsic_bfloat(OP, name) \
bi_intrinsic_half(OP, name) \
bi_intrinsic_ctype(OP, name, 32, float) \
bi_intrinsic_ctype(OP, name, 64, double) \
Expand All @@ -1206,7 +1272,10 @@ JL_DLLEXPORT jl_value_t *jl_##name(jl_value_t *a, jl_value_t *b) \
switch (sz) { \
/* choose the right size c-type operation */ \
case 2: \
jl_##name##16(16, pa, pb, pr); \
if ((jl_datatype_t*)ty == jl_float16_type) \
jl_##name##16(16, pa, pb, pr); \
else /*if ((jl_datatype_t*)ty == jl_bfloat16_type)*/ \
jl_##name##bf16(16, pa, pb, pr); \
break; \
case 4: \
jl_##name##32(32, pa, pb, pr); \
Expand All @@ -1221,6 +1290,7 @@ JL_DLLEXPORT jl_value_t *jl_##name(jl_value_t *a, jl_value_t *b) \
}

#define bool_fintrinsic(OP, name) \
bool_intrinsic_bfloat(OP, name) \
bool_intrinsic_half(OP, name) \
bool_intrinsic_ctype(OP, name, 32, float) \
bool_intrinsic_ctype(OP, name, 64, double) \
Expand All @@ -1237,7 +1307,10 @@ JL_DLLEXPORT jl_value_t *jl_##name(jl_value_t *a, jl_value_t *b) \
switch (sz) { \
/* choose the right size c-type operation */ \
case 2: \
cmp = jl_##name##16(16, pa, pb); \
if ((jl_datatype_t*)ty == jl_float16_type) \
cmp = jl_##name##16(16, pa, pb); \
else /*if ((jl_datatype_t*)ty == jl_bfloat16_type)*/ \
cmp = jl_##name##bf16(16, pa, pb); \
break; \
case 4: \
cmp = jl_##name##32(32, pa, pb); \
Expand All @@ -1252,6 +1325,7 @@ JL_DLLEXPORT jl_value_t *jl_##name(jl_value_t *a, jl_value_t *b) \
}

#define ter_fintrinsic(OP, name) \
ter_intrinsic_bfloat(OP, name) \
ter_intrinsic_half(OP, name) \
ter_intrinsic_ctype(OP, name, 32, float) \
ter_intrinsic_ctype(OP, name, 64, double) \
Expand All @@ -1269,7 +1343,10 @@ JL_DLLEXPORT jl_value_t *jl_##name(jl_value_t *a, jl_value_t *b, jl_value_t *c)
switch (sz) { \
/* choose the right size c-type operation */ \
case 2: \
jl_##name##16(16, pa, pb, pc, pr); \
if ((jl_datatype_t*)ty == jl_float16_type) \
jl_##name##16(16, pa, pb, pc, pr); \
else /*if ((jl_datatype_t*)ty == jl_bfloat16_type)*/ \
jl_##name##bf16(16, pa, pb, pc, pr); \
break; \
case 4: \
jl_##name##32(32, pa, pb, pc, pr); \
Expand All @@ -1285,7 +1362,7 @@ JL_DLLEXPORT jl_value_t *jl_##name(jl_value_t *a, jl_value_t *b, jl_value_t *c)

// arithmetic
#define neg(a) -a
#define neg_float(pr, a) *pr = -a
#define neg_float(ty, pr, a) *pr = -a
un_iintrinsic_fast(LLVMNeg, neg, neg_int, u)
#define add(a,b) a + b
bi_iintrinsic_fast(LLVMAdd, add, add_int, u)
Expand Down Expand Up @@ -1501,18 +1578,22 @@ cvt_iintrinsic(LLVMUItoFP, uitofp)
cvt_iintrinsic(LLVMFPtoSI, fptosi)
cvt_iintrinsic(LLVMFPtoUI, fptoui)

#define fptrunc(pr, a) \
#define fptrunc(tr, pr, a) \
if (!(osize < 8 * sizeof(a))) \
jl_error("fptrunc: output bitsize must be < input bitsize"); \
else if (osize == 16) \
*(uint16_t*)pr = float_to_half(a); \
else if (osize == 16) { \
if ((jl_datatype_t*)tr == jl_float16_type) \
*(uint16_t*)pr = float_to_half(a); \
else /*if ((jl_datatype_t*)tr == jl_bfloat16_type)*/ \
*(uint16_t*)pr = float_to_bfloat(a); \
} \
else if (osize == 32) \
*(float*)pr = a; \
else if (osize == 64) \
*(double*)pr = a; \
else \
jl_error("fptrunc: runtime floating point intrinsics are not implemented for bit sizes other than 16, 32 and 64");
#define fpext(pr, a) \
#define fpext(tr, pr, a) \
if (!(osize >= 8 * sizeof(a))) \
jl_error("fpext: output bitsize must be >= input bitsize"); \
if (osize == 32) \
Expand Down Expand Up @@ -1569,12 +1650,12 @@ checked_iintrinsic_div(LLVMRem_uov, checked_urem_int, u)
#define flipsign(a, b) \
(b >= 0) ? a : -a
bi_iintrinsic_fast(jl_LLVMFlipSign, flipsign, flipsign_int, )
#define abs_float(pr, a) *pr = fp_select(a, fabs)
#define ceil_float(pr, a) *pr = fp_select(a, ceil)
#define floor_float(pr, a) *pr = fp_select(a, floor)
#define trunc_float(pr, a) *pr = fp_select(a, trunc)
#define rint_float(pr, a) *pr = fp_select(a, rint)
#define sqrt_float(pr, a) *pr = fp_select(a, sqrt)
#define abs_float(ty, pr, a) *pr = fp_select(a, fabs)
#define ceil_float(ty, pr, a) *pr = fp_select(a, ceil)
#define floor_float(ty, pr, a) *pr = fp_select(a, floor)
#define trunc_float(ty, pr, a) *pr = fp_select(a, trunc)
#define rint_float(ty, pr, a) *pr = fp_select(a, rint)
#define sqrt_float(ty, pr, a) *pr = fp_select(a, sqrt)
#define copysign_float(a, b) fp_select2(a, b, copysign)

un_fintrinsic(abs_float,abs_float)
Expand Down

0 comments on commit a1ccf53

Please sign in to comment.