Skip to content

Commit

Permalink
Add native support for BFloat16.
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt committed Oct 2, 2023
1 parent e9d633f commit d893ef1
Show file tree
Hide file tree
Showing 14 changed files with 78 additions and 10 deletions.
2 changes: 2 additions & 0 deletions base/boot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,8 @@ primitive type Float16 <: AbstractFloat 16 end
primitive type Float32 <: AbstractFloat 32 end
primitive type Float64 <: AbstractFloat 64 end

primitive type BFloat16 <: AbstractFloat 16 end

#primitive type Bool <: Integer 8 end
abstract type AbstractChar end
primitive type Char <: AbstractChar 32 end
Expand Down
7 changes: 5 additions & 2 deletions src/abi_x86_64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@ struct Classification {
void classifyType(Classification& accum, jl_datatype_t *dt, uint64_t offset) const
{
// Floating point types
if (dt == jl_float64_type || dt == jl_float32_type) {
if (dt == jl_float64_type || dt == jl_float32_type || dt == jl_float16_type ||
dt == jl_bfloat16_type) {
accum.addField(offset, Sse);
}
// Misc types
Expand Down Expand Up @@ -239,7 +240,9 @@ Type *preferred_llvm_type(jl_datatype_t *dt, bool isret, LLVMContext &ctx) const
types[0] = Type::getIntNTy(ctx, nbits);
break;
case Sse:
if (size <= 4)
if (size <= 2)
types[0] = Type::getHalfTy(ctx);
else if (size <= 4)
types[0] = Type::getFloatTy(ctx);
else
types[0] = Type::getDoubleTy(ctx);
Expand Down
8 changes: 6 additions & 2 deletions src/aotcompile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,6 @@ static void reportWriterError(const ErrorInfoBase &E)
jl_safe_printf("ERROR: failed to emit output file %s\n", err.c_str());
}

#if JULIA_FLOAT16_ABI == 1
static void injectCRTAlias(Module &M, StringRef name, StringRef alias, FunctionType *FT)
{
Function *target = M.getFunction(alias);
Expand All @@ -514,7 +513,7 @@ static void injectCRTAlias(Module &M, StringRef name, StringRef alias, FunctionT
auto val = builder.CreateCall(target, CallArgs);
builder.CreateRet(val);
}
#endif

void multiversioning_preannotate(Module &M);

// See src/processor.h for documentation about this table. Corresponds to jl_image_shard_t.
Expand Down Expand Up @@ -1061,6 +1060,11 @@ static AOTOutputs add_output_impl(Module &M, TargetMachine &SourceTM, ShardTimer
#else
emitFloat16Wrappers(M, false);
#endif

injectCRTAlias(M, "__truncsfbf2", "julia__truncsfbf2",
FunctionType::get(Type::getBFloatTy(M.getContext()), { Type::getFloatTy(M.getContext()) }, false));
injectCRTAlias(M, "__truncsdbf2", "julia__truncdfbf2",
FunctionType::get(Type::getBFloatTy(M.getContext()), { Type::getDoubleTy(M.getContext()) }, false));
}
timers.optimize.stopTimer();
}
Expand Down
2 changes: 1 addition & 1 deletion src/ccall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1127,7 +1127,7 @@ std::string generate_func_sig(const char *fname)
// see pull req #978. need to annotate signext/zeroext for
// small integer arguments.
jl_datatype_t *bt = (jl_datatype_t*)tti;
if (jl_datatype_size(bt) < 4 && bt != jl_float16_type) {
if (jl_datatype_size(bt) < 4) {
if (jl_signed_type && jl_subtype(tti, (jl_value_t*)jl_signed_type))
ab.addAttribute(Attribute::SExt);
else
Expand Down
2 changes: 2 additions & 0 deletions src/cgutils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,8 @@ static Type *bitstype_to_llvm(jl_value_t *bt, LLVMContext &ctxt, bool llvmcall =
return getFloatTy(ctxt);
if (bt == (jl_value_t*)jl_float64_type)
return getDoubleTy(ctxt);
if (bt == (jl_value_t*)jl_bfloat16_type)
return getBFloatTy(ctxt);
if (jl_is_llvmpointer_type(bt)) {
jl_value_t *as_param = jl_tparam1(bt);
int as;
Expand Down
3 changes: 3 additions & 0 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,9 @@ auto getFloatTy(LLVMContext &ctxt) {
auto getDoubleTy(LLVMContext &ctxt) {
return Type::getDoubleTy(ctxt);
}
auto getBFloatTy(LLVMContext &ctxt) {
return Type::getBFloatTy(ctxt);
}
auto getFP128Ty(LLVMContext &ctxt) {
return Type::getFP128Ty(ctxt);
}
Expand Down
2 changes: 1 addition & 1 deletion src/intrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ static Type *INTT(Type *t, const DataLayout &DL)
return getInt64Ty(ctxt);
if (t == getFloatTy(ctxt))
return getInt32Ty(ctxt);
if (t == getHalfTy(ctxt))
if (t == getHalfTy(ctxt) || t == getBFloatTy(ctxt))
return getInt16Ty(ctxt);
unsigned nb = t->getPrimitiveSizeInBits();
assert(t != getVoidTy(ctxt) && nb > 0);
Expand Down
8 changes: 5 additions & 3 deletions src/jitlayers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1727,16 +1727,18 @@ JuliaOJIT::JuliaOJIT()
ExternalJD.addToLinkOrder(GlobalJD, orc::JITDylibLookupFlags::MatchExportedSymbolsOnly);
ExternalJD.addToLinkOrder(JD, orc::JITDylibLookupFlags::MatchExportedSymbolsOnly);

#if JULIA_FLOAT16_ABI == 1
orc::SymbolAliasMap jl_crt = {
#if JULIA_FLOAT16_ABI == 1
{ mangle("__gnu_h2f_ieee"), { mangle("julia__gnu_h2f_ieee"), JITSymbolFlags::Exported } },
{ mangle("__extendhfsf2"), { mangle("julia__gnu_h2f_ieee"), JITSymbolFlags::Exported } },
{ mangle("__gnu_f2h_ieee"), { mangle("julia__gnu_f2h_ieee"), JITSymbolFlags::Exported } },
{ mangle("__truncsfhf2"), { mangle("julia__gnu_f2h_ieee"), JITSymbolFlags::Exported } },
{ mangle("__truncdfhf2"), { mangle("julia__truncdfhf2"), JITSymbolFlags::Exported } }
{ mangle("__truncdfhf2"), { mangle("julia__truncdfhf2"), JITSymbolFlags::Exported } },
#endif
{ mangle("__truncsfbf2"), { mangle("julia__truncsfbf2"), JITSymbolFlags::Exported } },
{ mangle("__truncdfbf2"), { mangle("julia__truncdfbf2"), JITSymbolFlags::Exported } },
};
cantFail(GlobalJD.define(orc::symbolAliases(jl_crt)));
#endif

#ifdef MSAN_EMUTLS_WORKAROUND
orc::SymbolMap msan_crt;
Expand Down
1 change: 1 addition & 0 deletions src/jl_exported_data.inc
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
XX(jl_float16_type) \
XX(jl_float32_type) \
XX(jl_float64_type) \
XX(jl_bfloat16_type) \
XX(jl_floatingpoint_type) \
XX(jl_function_type) \
XX(jl_binding_type) \
Expand Down
2 changes: 2 additions & 0 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -3403,6 +3403,8 @@ void post_boot_hooks(void)
//XX(float32);
jl_float64_type = (jl_datatype_t*)core("Float64");
//XX(float64);
jl_bfloat16_type = (jl_datatype_t*)core("BFloat16");
//XX(bfloat16);
jl_floatingpoint_type = (jl_datatype_t*)core("AbstractFloat");
jl_number_type = (jl_datatype_t*)core("Number");
jl_signed_type = (jl_datatype_t*)core("Signed");
Expand Down
1 change: 1 addition & 0 deletions src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,7 @@ extern JL_DLLIMPORT jl_datatype_t *jl_uint64_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_datatype_t *jl_float16_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_datatype_t *jl_float32_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_datatype_t *jl_float64_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_datatype_t *jl_bfloat16_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_datatype_t *jl_floatingpoint_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_datatype_t *jl_number_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_datatype_t *jl_void_type JL_GLOBALLY_ROOTED; // deprecated
Expand Down
2 changes: 2 additions & 0 deletions src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -1663,6 +1663,8 @@ jl_sym_t *_jl_symbol(const char *str, size_t len) JL_NOTSAFEPOINT;
JL_DLLEXPORT float julia__gnu_h2f_ieee(uint16_t param) JL_NOTSAFEPOINT;
JL_DLLEXPORT uint16_t julia__gnu_f2h_ieee(float param) JL_NOTSAFEPOINT;
JL_DLLEXPORT uint16_t julia__truncdfhf2(double param) JL_NOTSAFEPOINT;
JL_DLLEXPORT float julia__truncsfbf2(float param) JL_NOTSAFEPOINT;
JL_DLLEXPORT float julia__truncdfbf2(double param) JL_NOTSAFEPOINT;
//JL_DLLEXPORT double julia__extendhfdf2(uint16_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT int32_t julia__fixhfsi(uint16_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT int64_t julia__fixhfdi(uint16_t n) JL_NOTSAFEPOINT;
Expand Down
45 changes: 45 additions & 0 deletions src/runtime_intrinsics.c
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,51 @@ JL_DLLEXPORT uint16_t julia__truncdfhf2(double param)
return float_to_half(res);
}

JL_DLLEXPORT float julia__truncsfbf2(float param) JL_NOTSAFEPOINT
{
uint16_t result;

if (isnan(param))
result = 0x7fc0;
else {
uint32_t bits = *((uint32_t*) &param);

// round to nearest even
uint32_t bit_above_round = (bits >> 17) & 1;
uint32_t round_bit = (bits >> 16) & 1;
uint32_t sticky_bit = (bits & 0xFFFF) != 0;
if (round_bit && (sticky_bit || bit_above_round))
bits += 0x10000; // Add 1 to bit just above the target bits

result = (uint16_t)(bits >> 16);
}

// on x86, bfloat16 needs to be returned in XMM. only GCC 13 provides the necessary ABI
// support in the form of the __bf16 type; older versions only provide __bfloat16 which
// is simply a typedef for short (i16). so use float, which is passed in XMM too.
uint32_t result_32bit = (uint32_t)result;
return *(float*)&result_32bit;
}

JL_DLLEXPORT float julia__truncdfbf2(double param) JL_NOTSAFEPOINT
{
float res = (float)param;
uint32_t resi;
memcpy(&resi, &res, sizeof(res));

// Handle subnormals: If this logic is activated, it indicates that when we
// cast our double to a float, the float is a subnormal number. However,
// bfloat16 uses the same exponent as float32, so we don't need special handling
// for subnormals when truncating to bfloat16.

if ((resi & 0x1ffu) == 0x100u) { // if we are halfway between 2 bfloat16 values
// adjust the value by 1 ULP in the direction that will make bfloat16(res) give the right answer
resi += (fabs(res) < fabs(param)) - (fabs(param) < fabs(res));
memcpy(&res, &resi, sizeof(res));
}
return julia__truncsfbf2(res);
}

//JL_DLLEXPORT double julia__extendhfdf2(uint16_t n) { return (double)julia__gnu_h2f_ieee(n); }
//JL_DLLEXPORT int32_t julia__fixhfsi(uint16_t n) { return (int32_t)julia__gnu_h2f_ieee(n); }
//JL_DLLEXPORT int64_t julia__fixhfdi(uint16_t n) { return (int64_t)julia__gnu_h2f_ieee(n); }
Expand Down
3 changes: 2 additions & 1 deletion src/staticdata.c
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ extern "C" {
// TODO: put WeakRefs on the weak_refs list during deserialization
// TODO: handle finalizers

#define NUM_TAGS 159
#define NUM_TAGS 160

// An array of references that need to be restored from the sysimg
// This is a manually constructed dual of the gvars array, which would be produced by codegen for Julia code, for C.
Expand Down Expand Up @@ -194,6 +194,7 @@ jl_value_t **const*const get_tags(void) {
INSERT_TAG(jl_float16_type);
INSERT_TAG(jl_float32_type);
INSERT_TAG(jl_float64_type);
INSERT_TAG(jl_bfloat16_type);
INSERT_TAG(jl_floatingpoint_type);
INSERT_TAG(jl_number_type);
INSERT_TAG(jl_signed_type);
Expand Down

0 comments on commit d893ef1

Please sign in to comment.