Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add BFloat16 runtime intrinsics. #51790

Merged
merged 1 commit into from
Oct 25, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 107 additions & 26 deletions src/runtime_intrinsics.c
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,14 @@ static inline uint16_t double_to_bfloat(double param) JL_NOTSAFEPOINT
return float_to_bfloat(temp);
}

// Widen a 16-bit bfloat pattern to a float.
// The 16 input bits become the high half of a 32-bit word (low half zero),
// and the bytes of that word are copied into the returned float.
static inline float bfloat_to_float(uint16_t param) JL_NOTSAFEPOINT
{
    enum { BFLOAT_SHIFT = 16 };
    const uint32_t word = (uint32_t)param << BFLOAT_SHIFT;
    float widened;
    // byte copy instead of a pointer cast, so no object is read through an incompatible type
    memcpy(&widened, &word, sizeof widened);
    return widened;
}

// bfloat16 conversion API

// starting with GCC 13 and Clang 17, we have __bf16 on most platforms
Expand Down Expand Up @@ -726,25 +734,39 @@ static inline unsigned jl_##name##nbits(unsigned runtime_nbits, void *pa) JL_NOT
// nbits::number of bits in the *input*
// c_type::c_type corresponding to nbits
#define un_fintrinsic_ctype(OP, name, c_type) \
static inline void name(unsigned osize, void *pa, void *pr) JL_NOTSAFEPOINT \
static inline void name(unsigned osize, jl_value_t *ty, void *pa, void *pr) JL_NOTSAFEPOINT \
{ \
c_type a = *(c_type*)pa; \
OP((c_type*)pr, a); \
OP(ty, (c_type*)pr, a); \
}

#define un_fintrinsic_half(OP, name) \
static inline void name(unsigned osize, void *pa, void *pr) JL_NOTSAFEPOINT \
static inline void name(unsigned osize, jl_value_t *ty, void *pa, void *pr) JL_NOTSAFEPOINT \
{ \
uint16_t a = *(uint16_t*)pa; \
float A = half_to_float(a); \
if (osize == 16) { \
float R; \
OP(&R, A); \
OP(ty, &R, A); \
*(uint16_t*)pr = float_to_half(R); \
} else { \
OP((uint16_t*)pr, A); \
OP(ty, (uint16_t*)pr, A); \
} \
}
}

// unary float intrinsic generator for a 16-bit input read as bfloat
// OP::Function macro(ty, result_pointer, input), expanded inside the generated body
// name::identifier of the generated static inline function
// The input bits are widened with bfloat_to_float before OP runs.
// When osize == 16 OP writes into a float temporary, which is then narrowed
// with float_to_bfloat and stored through pr; for any other osize OP is handed
// pr itself and is the only code that writes the result.
// NOTE(review): OP is expanded textually, so it can see the locals of the
// generated function; the fptrunc macro, for one, reads `osize` by name.
// Confirm against every OP in use before renaming anything here.
#define un_fintrinsic_bfloat(OP, name) \
static inline void name(unsigned osize, jl_value_t *ty, void *pa, void *pr) JL_NOTSAFEPOINT \
{ \
uint16_t a = *(uint16_t*)pa; \
float A = bfloat_to_float(a); \
if (osize == 16) { \
float R; \
OP(ty, &R, A); \
*(uint16_t*)pr = float_to_bfloat(R); \
} else { \
OP(ty, (uint16_t*)pr, A); \
} \
}

// float or integer inputs
// OP::Function macro(inputa, inputb)
Expand All @@ -769,6 +791,19 @@ static void jl_##name##16(unsigned runtime_nbits, void *pa, void *pb, void *pr)
runtime_nbits = 16; \
float R = OP(A, B); \
*(uint16_t*)pr = float_to_half(R); \
*(uint16_t*)pr = float_to_half(R); \
}

// binary float intrinsic generator for two 16-bit operands read as bfloat
// OP::Function macro(inputa, inputb), evaluated on the operands widened to float
// name::pasted into the generated function name, giving jl_<name>bf16
// Both operands are loaded as uint16_t and widened with bfloat_to_float; the
// float value of OP is narrowed with float_to_bfloat and stored through pr.
// runtime_nbits is overwritten with 16 before OP is expanded, whatever the
// caller passed in.
#define bi_intrinsic_bfloat(OP, name) \
static void jl_##name##bf16(unsigned runtime_nbits, void *pa, void *pb, void *pr) JL_NOTSAFEPOINT \
{ \
uint16_t a = *(uint16_t*)pa; \
uint16_t b = *(uint16_t*)pb; \
float A = bfloat_to_float(a); \
float B = bfloat_to_float(b); \
runtime_nbits = 16; \
float R = OP(A, B); \
*(uint16_t*)pr = float_to_bfloat(R); \
}

// float or integer inputs, bool output
Expand All @@ -795,6 +830,17 @@ static int jl_##name##16(unsigned runtime_nbits, void *pa, void *pb) JL_NOTSAFEP
return OP(A, B); \
}

// binary float predicate generator for two 16-bit operands read as bfloat
// OP::Function macro(inputa, inputb), evaluated on the operands widened to float
// name::pasted into the generated function name, giving jl_<name>bf16
// Both operands are loaded as uint16_t and widened with bfloat_to_float; the
// value of OP is returned directly as an int, with no result pointer.
// runtime_nbits is overwritten with 16 before OP is expanded.
#define bool_intrinsic_bfloat(OP, name) \
static int jl_##name##bf16(unsigned runtime_nbits, void *pa, void *pb) JL_NOTSAFEPOINT \
{ \
uint16_t a = *(uint16_t*)pa; \
uint16_t b = *(uint16_t*)pb; \
float A = bfloat_to_float(a); \
float B = bfloat_to_float(b); \
runtime_nbits = 16; \
return OP(A, B); \
}


// integer inputs, with precondition test
// OP::Function macro(inputa, inputb)
Expand Down Expand Up @@ -836,6 +882,21 @@ static void jl_##name##16(unsigned runtime_nbits, void *pa, void *pb, void *pc,
runtime_nbits = 16; \
float R = OP(A, B, C); \
*(uint16_t*)pr = float_to_half(R); \
*(uint16_t*)pr = float_to_half(R); \
}

// ternary float intrinsic generator for three 16-bit operands read as bfloat
// OP::Function macro(inputa, inputb, inputc), evaluated on the operands widened to float
// name::pasted into the generated function name, giving jl_<name>bf16
// All three operands are loaded as uint16_t and widened with bfloat_to_float;
// the float value of OP is narrowed with float_to_bfloat and stored through pr.
// runtime_nbits is overwritten with 16 before OP is expanded.
#define ter_intrinsic_bfloat(OP, name) \
static void jl_##name##bf16(unsigned runtime_nbits, void *pa, void *pb, void *pc, void *pr) JL_NOTSAFEPOINT \
{ \
uint16_t a = *(uint16_t*)pa; \
uint16_t b = *(uint16_t*)pb; \
uint16_t c = *(uint16_t*)pc; \
float A = bfloat_to_float(a); \
float B = bfloat_to_float(b); \
float C = bfloat_to_float(c); \
runtime_nbits = 16; \
float R = OP(A, B, C); \
*(uint16_t*)pr = float_to_bfloat(R); \
}


Expand Down Expand Up @@ -980,12 +1041,13 @@ static inline jl_value_t *jl_intrinsic_cvt(jl_value_t *ty, jl_value_t *a, const
// floating point

#define un_fintrinsic_withtype(OP, name) \
un_fintrinsic_bfloat(OP, jl_##name##bf16) \
un_fintrinsic_half(OP, jl_##name##16) \
un_fintrinsic_ctype(OP, jl_##name##32, float) \
un_fintrinsic_ctype(OP, jl_##name##64, double) \
JL_DLLEXPORT jl_value_t *jl_##name(jl_value_t *ty, jl_value_t *a) \
{ \
return jl_fintrinsic_1(ty, a, #name, jl_##name##16, jl_##name##32, jl_##name##64); \
return jl_fintrinsic_1(ty, a, #name, jl_##name##bf16, jl_##name##16, jl_##name##32, jl_##name##64); \
}

#define un_fintrinsic(OP, name) \
Expand All @@ -995,9 +1057,9 @@ JL_DLLEXPORT jl_value_t *jl_##name(jl_value_t *a) \
return jl_##name##_withtype(jl_typeof(a), a); \
}

typedef void (fintrinsic_op1)(unsigned, void*, void*);
typedef void (fintrinsic_op1)(unsigned, jl_value_t*, void*, void*);

static inline jl_value_t *jl_fintrinsic_1(jl_value_t *ty, jl_value_t *a, const char *name, fintrinsic_op1 *halfop, fintrinsic_op1 *floatop, fintrinsic_op1 *doubleop)
static inline jl_value_t *jl_fintrinsic_1(jl_value_t *ty, jl_value_t *a, const char *name, fintrinsic_op1 *bfloatop, fintrinsic_op1 *halfop, fintrinsic_op1 *floatop, fintrinsic_op1 *doubleop)
{
jl_task_t *ct = jl_current_task;
if (!jl_is_primitivetype(jl_typeof(a)))
Expand All @@ -1011,13 +1073,16 @@ static inline jl_value_t *jl_fintrinsic_1(jl_value_t *ty, jl_value_t *a, const c
switch (sz) {
/* choose the right size c-type operation based on the input */
case 2:
halfop(sz2 * host_char_bit, pa, pr);
if (jl_typeof(a) == (jl_value_t*)jl_float16_type)
halfop(sz2 * host_char_bit, ty, pa, pr);
else /*if (jl_typeof(a) == (jl_value_t*)jl_bfloat16_type)*/
bfloatop(sz2 * host_char_bit, ty, pa, pr);
break;
case 4:
floatop(sz2 * host_char_bit, pa, pr);
floatop(sz2 * host_char_bit, ty, pa, pr);
break;
case 8:
doubleop(sz2 * host_char_bit, pa, pr);
doubleop(sz2 * host_char_bit, ty, pa, pr);
break;
default:
jl_errorf("%s: runtime floating point intrinsics are not implemented for bit sizes other than 16, 32 and 64", name);
Expand Down Expand Up @@ -1189,6 +1254,7 @@ static inline jl_value_t *jl_intrinsiclambda_checkeddiv(jl_value_t *ty, void *pa
// floating point

#define bi_fintrinsic(OP, name) \
bi_intrinsic_bfloat(OP, name) \
bi_intrinsic_half(OP, name) \
bi_intrinsic_ctype(OP, name, 32, float) \
bi_intrinsic_ctype(OP, name, 64, double) \
Expand All @@ -1206,7 +1272,10 @@ JL_DLLEXPORT jl_value_t *jl_##name(jl_value_t *a, jl_value_t *b) \
switch (sz) { \
/* choose the right size c-type operation */ \
case 2: \
jl_##name##16(16, pa, pb, pr); \
if ((jl_datatype_t*)ty == jl_float16_type) \
jl_##name##16(16, pa, pb, pr); \
else /*if ((jl_datatype_t*)ty == jl_bfloat16_type)*/ \
jl_##name##bf16(16, pa, pb, pr); \
break; \
case 4: \
jl_##name##32(32, pa, pb, pr); \
Expand All @@ -1221,6 +1290,7 @@ JL_DLLEXPORT jl_value_t *jl_##name(jl_value_t *a, jl_value_t *b) \
}

#define bool_fintrinsic(OP, name) \
bool_intrinsic_bfloat(OP, name) \
bool_intrinsic_half(OP, name) \
bool_intrinsic_ctype(OP, name, 32, float) \
bool_intrinsic_ctype(OP, name, 64, double) \
Expand All @@ -1237,7 +1307,10 @@ JL_DLLEXPORT jl_value_t *jl_##name(jl_value_t *a, jl_value_t *b) \
switch (sz) { \
/* choose the right size c-type operation */ \
case 2: \
cmp = jl_##name##16(16, pa, pb); \
if ((jl_datatype_t*)ty == jl_float16_type) \
cmp = jl_##name##16(16, pa, pb); \
else /*if ((jl_datatype_t*)ty == jl_bfloat16_type)*/ \
cmp = jl_##name##bf16(16, pa, pb); \
break; \
case 4: \
cmp = jl_##name##32(32, pa, pb); \
Expand All @@ -1252,6 +1325,7 @@ JL_DLLEXPORT jl_value_t *jl_##name(jl_value_t *a, jl_value_t *b) \
}

#define ter_fintrinsic(OP, name) \
ter_intrinsic_bfloat(OP, name) \
ter_intrinsic_half(OP, name) \
ter_intrinsic_ctype(OP, name, 32, float) \
ter_intrinsic_ctype(OP, name, 64, double) \
Expand All @@ -1269,7 +1343,10 @@ JL_DLLEXPORT jl_value_t *jl_##name(jl_value_t *a, jl_value_t *b, jl_value_t *c)
switch (sz) { \
/* choose the right size c-type operation */ \
case 2: \
jl_##name##16(16, pa, pb, pc, pr); \
if ((jl_datatype_t*)ty == jl_float16_type) \
jl_##name##16(16, pa, pb, pc, pr); \
else /*if ((jl_datatype_t*)ty == jl_bfloat16_type)*/ \
jl_##name##bf16(16, pa, pb, pc, pr); \
break; \
case 4: \
jl_##name##32(32, pa, pb, pc, pr); \
Expand All @@ -1285,7 +1362,7 @@ JL_DLLEXPORT jl_value_t *jl_##name(jl_value_t *a, jl_value_t *b, jl_value_t *c)

// arithmetic
#define neg(a) -a
#define neg_float(pr, a) *pr = -a
#define neg_float(ty, pr, a) *pr = -a
un_iintrinsic_fast(LLVMNeg, neg, neg_int, u)
#define add(a,b) a + b
bi_iintrinsic_fast(LLVMAdd, add, add_int, u)
Expand Down Expand Up @@ -1501,18 +1578,22 @@ cvt_iintrinsic(LLVMUItoFP, uitofp)
cvt_iintrinsic(LLVMFPtoSI, fptosi)
cvt_iintrinsic(LLVMFPtoUI, fptoui)

#define fptrunc(pr, a) \
#define fptrunc(tr, pr, a) \
if (!(osize < 8 * sizeof(a))) \
jl_error("fptrunc: output bitsize must be < input bitsize"); \
else if (osize == 16) \
*(uint16_t*)pr = float_to_half(a); \
else if (osize == 16) { \
if ((jl_datatype_t*)tr == jl_float16_type) \
*(uint16_t*)pr = float_to_half(a); \
else /*if ((jl_datatype_t*)tr == jl_bfloat16_type)*/ \
*(uint16_t*)pr = float_to_bfloat(a); \
} \
else if (osize == 32) \
*(float*)pr = a; \
else if (osize == 64) \
*(double*)pr = a; \
else \
jl_error("fptrunc: runtime floating point intrinsics are not implemented for bit sizes other than 16, 32 and 64");
#define fpext(pr, a) \
#define fpext(tr, pr, a) \
if (!(osize >= 8 * sizeof(a))) \
jl_error("fpext: output bitsize must be >= input bitsize"); \
if (osize == 32) \
Expand Down Expand Up @@ -1569,12 +1650,12 @@ checked_iintrinsic_div(LLVMRem_uov, checked_urem_int, u)
#define flipsign(a, b) \
(b >= 0) ? a : -a
bi_iintrinsic_fast(jl_LLVMFlipSign, flipsign, flipsign_int, )
#define abs_float(pr, a) *pr = fp_select(a, fabs)
#define ceil_float(pr, a) *pr = fp_select(a, ceil)
#define floor_float(pr, a) *pr = fp_select(a, floor)
#define trunc_float(pr, a) *pr = fp_select(a, trunc)
#define rint_float(pr, a) *pr = fp_select(a, rint)
#define sqrt_float(pr, a) *pr = fp_select(a, sqrt)
#define abs_float(ty, pr, a) *pr = fp_select(a, fabs)
#define ceil_float(ty, pr, a) *pr = fp_select(a, ceil)
#define floor_float(ty, pr, a) *pr = fp_select(a, floor)
#define trunc_float(ty, pr, a) *pr = fp_select(a, trunc)
#define rint_float(ty, pr, a) *pr = fp_select(a, rint)
#define sqrt_float(ty, pr, a) *pr = fp_select(a, sqrt)
#define copysign_float(a, b) fp_select2(a, b, copysign)

un_fintrinsic(abs_float,abs_float)
Expand Down