Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add native support for BFloat16. #51470

Merged
merged 12 commits into from
Oct 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions base/boot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,8 @@ primitive type Float16 <: AbstractFloat 16 end
primitive type Float32 <: AbstractFloat 32 end
primitive type Float64 <: AbstractFloat 64 end

primitive type BFloat16 <: AbstractFloat 16 end

#primitive type Bool <: Integer 8 end
abstract type AbstractChar end
primitive type Char <: AbstractChar 32 end
Expand Down
3 changes: 2 additions & 1 deletion doc/src/base/reflection.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,9 @@ the abstract `DataType` [`AbstractFloat`](@ref) has five (concrete) subtypes:

```jldoctest; setup = :(using InteractiveUtils)
julia> subtypes(AbstractFloat)
4-element Vector{Any}:
5-element Vector{Any}:
BigFloat
Core.BFloat16
Float16
Float32
Float64
Expand Down
6 changes: 4 additions & 2 deletions src/abi_x86_64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ struct Classification {
void classifyType(Classification& accum, jl_datatype_t *dt, uint64_t offset) const
{
// Floating point types
if (dt == jl_float64_type || dt == jl_float32_type) {
if (dt == jl_float64_type || dt == jl_float32_type || dt == jl_bfloat16_type) {
accum.addField(offset, Sse);
}
// Misc types
Expand Down Expand Up @@ -239,7 +239,9 @@ Type *preferred_llvm_type(jl_datatype_t *dt, bool isret, LLVMContext &ctx) const
types[0] = Type::getIntNTy(ctx, nbits);
break;
case Sse:
if (size <= 4)
if (size <= 2)
types[0] = Type::getHalfTy(ctx);
maleadt marked this conversation as resolved.
Show resolved Hide resolved
else if (size <= 4)
types[0] = Type::getFloatTy(ctx);
else
types[0] = Type::getDoubleTy(ctx);
Expand Down
8 changes: 6 additions & 2 deletions src/aotcompile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,6 @@ static void reportWriterError(const ErrorInfoBase &E)
jl_safe_printf("ERROR: failed to emit output file %s\n", err.c_str());
}

#if JULIA_FLOAT16_ABI == 1
static void injectCRTAlias(Module &M, StringRef name, StringRef alias, FunctionType *FT)
{
Function *target = M.getFunction(alias);
Expand All @@ -514,7 +513,7 @@ static void injectCRTAlias(Module &M, StringRef name, StringRef alias, FunctionT
auto val = builder.CreateCall(target, CallArgs);
builder.CreateRet(val);
}
#endif

void multiversioning_preannotate(Module &M);

// See src/processor.h for documentation about this table. Corresponds to jl_image_shard_t.
Expand Down Expand Up @@ -1061,6 +1060,11 @@ static AOTOutputs add_output_impl(Module &M, TargetMachine &SourceTM, ShardTimer
#else
emitFloat16Wrappers(M, false);
#endif

// Define the bfloat16 truncation routines in this module as thin wrappers around
// Julia's own implementations. The wrapper names must match the `__truncsfbf2`
// (float -> bfloat) and `__truncdfbf2` (double -> bfloat) symbols, the same names
// the JIT registers in its alias table.
injectCRTAlias(M, "__truncsfbf2", "julia__truncsfbf2",
FunctionType::get(Type::getBFloatTy(M.getContext()), { Type::getFloatTy(M.getContext()) }, false));
// NOTE: this was previously spelled "__truncsdbf2", which left the real
// double -> bfloat symbol undefined in AOT-compiled images.
injectCRTAlias(M, "__truncdfbf2", "julia__truncdfbf2",
FunctionType::get(Type::getBFloatTy(M.getContext()), { Type::getDoubleTy(M.getContext()) }, false));
}
timers.optimize.stopTimer();
}
Expand Down
13 changes: 6 additions & 7 deletions src/ccall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1123,22 +1123,21 @@ std::string generate_func_sig(const char *fname)
isboxed = false;
}
else {
if (jl_is_primitivetype(tti)) {
t = _julia_struct_to_llvm(ctx, LLVMCtx, tti, &isboxed, llvmcall);
if (t == getVoidTy(LLVMCtx)) {
return make_errmsg(fname, i + 1, " type doesn't correspond to a C type");
}
if (jl_is_primitivetype(tti) && t->isIntegerTy()) {
// see pull req #978. need to annotate signext/zeroext for
// small integer arguments.
jl_datatype_t *bt = (jl_datatype_t*)tti;
if (jl_datatype_size(bt) < 4 && bt != jl_float16_type) {
if (jl_datatype_size(bt) < 4) {
if (jl_signed_type && jl_subtype(tti, (jl_value_t*)jl_signed_type))
ab.addAttribute(Attribute::SExt);
else
ab.addAttribute(Attribute::ZExt);
}
}

t = _julia_struct_to_llvm(ctx, LLVMCtx, tti, &isboxed, llvmcall);
if (t == getVoidTy(LLVMCtx)) {
return make_errmsg(fname, i + 1, " type doesn't correspond to a C type");
}
}

Type *pat;
Expand Down
2 changes: 2 additions & 0 deletions src/cgutils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,8 @@ static Type *bitstype_to_llvm(jl_value_t *bt, LLVMContext &ctxt, bool llvmcall =
return getFloatTy(ctxt);
if (bt == (jl_value_t*)jl_float64_type)
return getDoubleTy(ctxt);
if (bt == (jl_value_t*)jl_bfloat16_type)
return getBFloatTy(ctxt);
if (jl_is_llvmpointer_type(bt)) {
jl_value_t *as_param = jl_tparam1(bt);
int as;
Expand Down
3 changes: 3 additions & 0 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,9 @@ auto getFloatTy(LLVMContext &ctxt) {
// Shorthand for the LLVM `double` type belonging to `ctxt`.
auto getDoubleTy(LLVMContext &ctxt) {
    auto *doubleTy = Type::getDoubleTy(ctxt);
    return doubleTy;
}
// Shorthand for the LLVM `bfloat` type belonging to `ctxt`.
auto getBFloatTy(LLVMContext &ctxt) {
    auto *bfloatTy = Type::getBFloatTy(ctxt);
    return bfloatTy;
}
// Shorthand for the LLVM `fp128` type belonging to `ctxt`.
auto getFP128Ty(LLVMContext &ctxt) {
    auto *fp128Ty = Type::getFP128Ty(ctxt);
    return fp128Ty;
}
Expand Down
2 changes: 1 addition & 1 deletion src/intrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ static Type *INTT(Type *t, const DataLayout &DL)
return getInt64Ty(ctxt);
if (t == getFloatTy(ctxt))
return getInt32Ty(ctxt);
if (t == getHalfTy(ctxt))
if (t == getHalfTy(ctxt) || t == getBFloatTy(ctxt))
return getInt16Ty(ctxt);
unsigned nb = t->getPrimitiveSizeInBits();
assert(t != getVoidTy(ctxt) && nb > 0);
Expand Down
8 changes: 5 additions & 3 deletions src/jitlayers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1727,16 +1727,18 @@ JuliaOJIT::JuliaOJIT()
ExternalJD.addToLinkOrder(GlobalJD, orc::JITDylibLookupFlags::MatchExportedSymbolsOnly);
ExternalJD.addToLinkOrder(JD, orc::JITDylibLookupFlags::MatchExportedSymbolsOnly);

#if JULIA_FLOAT16_ABI == 1
orc::SymbolAliasMap jl_crt = {
#if JULIA_FLOAT16_ABI == 1
{ mangle("__gnu_h2f_ieee"), { mangle("julia__gnu_h2f_ieee"), JITSymbolFlags::Exported } },
{ mangle("__extendhfsf2"), { mangle("julia__gnu_h2f_ieee"), JITSymbolFlags::Exported } },
{ mangle("__gnu_f2h_ieee"), { mangle("julia__gnu_f2h_ieee"), JITSymbolFlags::Exported } },
{ mangle("__truncsfhf2"), { mangle("julia__gnu_f2h_ieee"), JITSymbolFlags::Exported } },
{ mangle("__truncdfhf2"), { mangle("julia__truncdfhf2"), JITSymbolFlags::Exported } }
{ mangle("__truncdfhf2"), { mangle("julia__truncdfhf2"), JITSymbolFlags::Exported } },
#endif
{ mangle("__truncsfbf2"), { mangle("julia__truncsfbf2"), JITSymbolFlags::Exported } },
{ mangle("__truncdfbf2"), { mangle("julia__truncdfbf2"), JITSymbolFlags::Exported } },
};
cantFail(GlobalJD.define(orc::symbolAliases(jl_crt)));
#endif

#ifdef MSAN_EMUTLS_WORKAROUND
orc::SymbolMap msan_crt;
Expand Down
1 change: 1 addition & 0 deletions src/jl_exported_data.inc
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
XX(jl_float16_type) \
XX(jl_float32_type) \
XX(jl_float64_type) \
XX(jl_bfloat16_type) \
XX(jl_floatingpoint_type) \
XX(jl_function_type) \
XX(jl_binding_type) \
Expand Down
2 changes: 2 additions & 0 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -3403,6 +3403,8 @@ void post_boot_hooks(void)
//XX(float32);
jl_float64_type = (jl_datatype_t*)core("Float64");
//XX(float64);
jl_bfloat16_type = (jl_datatype_t*)core("BFloat16");
//XX(bfloat16);
jl_floatingpoint_type = (jl_datatype_t*)core("AbstractFloat");
jl_number_type = (jl_datatype_t*)core("Number");
jl_signed_type = (jl_datatype_t*)core("Signed");
Expand Down
1 change: 1 addition & 0 deletions src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,7 @@ extern JL_DLLIMPORT jl_datatype_t *jl_uint64_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_datatype_t *jl_float16_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_datatype_t *jl_float32_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_datatype_t *jl_float64_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_datatype_t *jl_bfloat16_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_datatype_t *jl_floatingpoint_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_datatype_t *jl_number_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_datatype_t *jl_void_type JL_GLOBALLY_ROOTED; // deprecated
Expand Down
2 changes: 2 additions & 0 deletions src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -1663,6 +1663,8 @@ jl_sym_t *_jl_symbol(const char *str, size_t len) JL_NOTSAFEPOINT;
JL_DLLEXPORT float julia__gnu_h2f_ieee(uint16_t param) JL_NOTSAFEPOINT;
JL_DLLEXPORT uint16_t julia__gnu_f2h_ieee(float param) JL_NOTSAFEPOINT;
JL_DLLEXPORT uint16_t julia__truncdfhf2(double param) JL_NOTSAFEPOINT;
JL_DLLEXPORT float julia__truncsfbf2(float param) JL_NOTSAFEPOINT;
JL_DLLEXPORT float julia__truncdfbf2(double param) JL_NOTSAFEPOINT;
//JL_DLLEXPORT double julia__extendhfdf2(uint16_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT int32_t julia__fixhfsi(uint16_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT int64_t julia__fixhfdi(uint16_t n) JL_NOTSAFEPOINT;
Expand Down
43 changes: 32 additions & 11 deletions src/llvm-demote-float16.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
// This file is a part of Julia. License is MIT: https://julialang.org/license

// This pass finds floating-point operations on 16-bit (half precision) values, and replaces
// them by equivalent operations on 32-bit (single precision) values surrounded by a fpext
// and fptrunc. This ensures that the exact semantics of IEEE floating-point are preserved.
// This pass finds floating-point operations on 16-bit values (half precision and bfloat),
// and replaces them by equivalent operations on 32-bit (single precision) values surrounded
// by a fpext and fptrunc. This ensures that the exact semantics of IEEE floating-point are
// preserved.
//
// Without this pass, back-ends that do not natively support half-precision (e.g. x86_64)
// similarly pattern-match half-precision operations with single-precision equivalents, but
Expand Down Expand Up @@ -71,25 +72,40 @@ static bool have_fp16(Function &caller, const Triple &TT) {
return false;
}

// Returns true only when `caller` carries the "julia.hasbf16" function attribute.
// `TT` is currently not consulted.
static bool have_bf16(Function &caller, const Triple &TT) {
    // There are no targets that fully support bfloat yet;
    // AVX512BF16 only provides conversion and dot product instructions.
    // The explicit function attribute is therefore the only way to opt in.
    return caller.hasFnAttribute("julia.hasbf16");
}

static bool demoteFloat16(Function &F)
{
auto TT = Triple(F.getParent()->getTargetTriple());
if (have_fp16(F, TT))
auto has_fp16 = have_fp16(F, TT);
auto has_bf16 = have_bf16(F, TT);
if (has_fp16 && has_bf16)
return false;

auto &ctx = F.getContext();
auto T_float32 = Type::getFloatTy(ctx);
SmallVector<Instruction *, 0> erase;
for (auto &BB : F) {
for (auto &I : BB) {
// extend Float16 operands to Float32
// check whether there's any 16-bit floating point operands to extend
bool Float16 = I.getType()->getScalarType()->isHalfTy();
for (size_t i = 0; !Float16 && i < I.getNumOperands(); i++) {
bool BFloat16 = I.getType()->getScalarType()->isBFloatTy();
for (size_t i = 0; !BFloat16 && !Float16 && i < I.getNumOperands(); i++) {
Value *Op = I.getOperand(i);
if (Op->getType()->getScalarType()->isHalfTy())
if (!has_fp16 && Op->getType()->getScalarType()->isHalfTy())
Float16 = true;
else if (!has_bf16 && Op->getType()->getScalarType()->isBFloatTy())
BFloat16 = true;
}
if (!Float16)
if (!Float16 && !BFloat16)
continue;

switch (I.getOpcode()) {
Expand All @@ -113,19 +129,24 @@ static bool demoteFloat16(Function &F)

IRBuilder<> builder(&I);

// extend Float16 operands to Float32
// extend 16-bit floating point operands
SmallVector<Value *, 2> Operands(I.getNumOperands());
for (size_t i = 0; i < I.getNumOperands(); i++) {
Value *Op = I.getOperand(i);
if (Op->getType()->getScalarType()->isHalfTy()) {
if (!has_fp16 && Op->getType()->getScalarType()->isHalfTy()) {
// extend Float16 to Float32
++TotalExt;
Op = builder.CreateFPExt(Op, Op->getType()->getWithNewType(T_float32));
} else if (!has_bf16 && Op->getType()->getScalarType()->isBFloatTy()) {
// extend BFloat16 to Float32
++TotalExt;
Op = builder.CreateFPExt(Op, Op->getType()->getWithNewType(T_float32));
}
Operands[i] = Op;
}

// recreate the instruction if any operands changed,
// truncating the result back to Float16
// truncating the result back to the original type
Value *NewI;
++TotalChanged;
switch (I.getOpcode()) {
Expand Down
10 changes: 6 additions & 4 deletions src/llvm-multiversioning.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ extern Optional<bool> always_have_fma(Function&, const Triple &TT);

namespace {
constexpr uint32_t clone_mask =
JL_TARGET_CLONE_LOOP | JL_TARGET_CLONE_SIMD | JL_TARGET_CLONE_MATH | JL_TARGET_CLONE_CPU | JL_TARGET_CLONE_FLOAT16;
JL_TARGET_CLONE_LOOP | JL_TARGET_CLONE_SIMD | JL_TARGET_CLONE_MATH | JL_TARGET_CLONE_CPU | JL_TARGET_CLONE_FLOAT16 | JL_TARGET_CLONE_BFLOAT16;

// Treat identical mapping as missing and return `def` in that case.
// We mainly need this to identify cloned function using value map after LLVM cloning
Expand Down Expand Up @@ -126,12 +126,14 @@ static uint32_t collect_func_info(Function &F, const Triple &TT, bool &has_vecca
}

for (size_t i = 0; i < I.getNumOperands(); i++) {
if(I.getOperand(i)->getType()->isHalfTy()){
if(I.getOperand(i)->getType()->isHalfTy()) {
flag |= JL_TARGET_CLONE_FLOAT16;
}
// Check for BFloat16 when they are added to julia can be done here
if(I.getOperand(i)->getType()->isBFloatTy()) {
flag |= JL_TARGET_CLONE_BFLOAT16;
}
}
uint32_t veccall_flags = JL_TARGET_CLONE_SIMD | JL_TARGET_CLONE_MATH | JL_TARGET_CLONE_CPU | JL_TARGET_CLONE_FLOAT16;
uint32_t veccall_flags = JL_TARGET_CLONE_SIMD | JL_TARGET_CLONE_MATH | JL_TARGET_CLONE_CPU | JL_TARGET_CLONE_FLOAT16 | JL_TARGET_CLONE_BFLOAT16;
if (has_veccall && (flag & veccall_flags) == veccall_flags) {
return flag;
}
Expand Down
2 changes: 2 additions & 0 deletions src/processor.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ enum {
JL_TARGET_CLONE_CPU = 1 << 8,
// Clone when the function uses fp16
JL_TARGET_CLONE_FLOAT16 = 1 << 9,
// Clone when the function uses bf16
JL_TARGET_CLONE_BFLOAT16 = 1 << 10,
};

#define JL_FEATURE_DEF_NAME(name, bit, llvmver, str) JL_FEATURE_DEF(name, bit, llvmver)
Expand Down
7 changes: 7 additions & 0 deletions src/processor_x86.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -961,6 +961,13 @@ static void ensure_jit_target(bool imaging)
break;
}
}
static constexpr uint32_t clone_bf16[] = {Feature::avx512bf16};
for (auto fe: clone_bf16) {
if (!test_nbit(features0, fe) && test_nbit(t.en.features, fe)) {
t.en.flags |= JL_TARGET_CLONE_BFLOAT16;
break;
}
}
}
}

Expand Down
38 changes: 38 additions & 0 deletions src/runtime_intrinsics.c
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,44 @@ JL_DLLEXPORT uint16_t julia__truncdfhf2(double param)
return float_to_half(res);
}

JL_DLLEXPORT float julia__truncsfbf2(float param) JL_NOTSAFEPOINT
{
uint16_t result;

if (isnan(param))
result = 0x7fc0;
else {
uint32_t bits = *((uint32_t*) &param);

// round to nearest even
bits += 0x7fff + ((bits >> 16) & 1);
result = (uint16_t)(bits >> 16);
}

// on x86, bfloat16 needs to be returned in XMM. only GCC 13 provides the necessary ABI
// support in the form of the __bf16 type; older versions only provide __bfloat16 which
// is simply a typedef for short (i16). so use float, which is passed in XMM too.
uint32_t result_32bit = (uint32_t)result;
return *(float*)&result_32bit;
}

JL_DLLEXPORT float julia__truncdfbf2(double param) JL_NOTSAFEPOINT
{
float res = (float)param;
uint32_t resi;
memcpy(&resi, &res, sizeof(res));

// bfloat16 uses the same exponent as float32, so we don't need special handling
// for subnormals when truncating float64 to bfloat16.

if ((resi & 0x1ffu) == 0x100u) { // if we are halfway between 2 bfloat16 values
// adjust the value by 1 ULP in the direction that will make bfloat16(res) give the right answer
resi += (fabs(res) < fabs(param)) - (fabs(param) < fabs(res));
memcpy(&res, &resi, sizeof(res));
}
return julia__truncsfbf2(res);
}

//JL_DLLEXPORT double julia__extendhfdf2(uint16_t n) { return (double)julia__gnu_h2f_ieee(n); }
//JL_DLLEXPORT int32_t julia__fixhfsi(uint16_t n) { return (int32_t)julia__gnu_h2f_ieee(n); }
//JL_DLLEXPORT int64_t julia__fixhfdi(uint16_t n) { return (int64_t)julia__gnu_h2f_ieee(n); }
Expand Down
3 changes: 2 additions & 1 deletion src/staticdata.c
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ extern "C" {
// TODO: put WeakRefs on the weak_refs list during deserialization
// TODO: handle finalizers

#define NUM_TAGS 159
#define NUM_TAGS 160

// An array of references that need to be restored from the sysimg
// This is a manually constructed dual of the gvars array, which would be produced by codegen for Julia code, for C.
Expand Down Expand Up @@ -194,6 +194,7 @@ jl_value_t **const*const get_tags(void) {
INSERT_TAG(jl_float16_type);
INSERT_TAG(jl_float32_type);
INSERT_TAG(jl_float64_type);
INSERT_TAG(jl_bfloat16_type);
INSERT_TAG(jl_floatingpoint_type);
INSERT_TAG(jl_number_type);
INSERT_TAG(jl_signed_type);
Expand Down
Loading