diff --git a/base/boot.jl b/base/boot.jl index 637b16e04c13e..7f7f4cf02422d 100644 --- a/base/boot.jl +++ b/base/boot.jl @@ -217,6 +217,8 @@ primitive type Float16 <: AbstractFloat 16 end primitive type Float32 <: AbstractFloat 32 end primitive type Float64 <: AbstractFloat 64 end +primitive type BFloat16 <: AbstractFloat 16 end + #primitive type Bool <: Integer 8 end abstract type AbstractChar end primitive type Char <: AbstractChar 32 end diff --git a/doc/src/base/reflection.md b/doc/src/base/reflection.md index b6246c06472a4..2798cfe2e7530 100644 --- a/doc/src/base/reflection.md +++ b/doc/src/base/reflection.md @@ -52,8 +52,9 @@ the abstract `DataType` [`AbstractFloat`](@ref) has four (concrete) subtypes: ```jldoctest; setup = :(using InteractiveUtils) julia> subtypes(AbstractFloat) -4-element Vector{Any}: +5-element Vector{Any}: BigFloat + Core.BFloat16 Float16 Float32 Float64 diff --git a/src/abi_x86_64.cpp b/src/abi_x86_64.cpp index c3d12417e6de8..7800c44b4d3ae 100644 --- a/src/abi_x86_64.cpp +++ b/src/abi_x86_64.cpp @@ -118,7 +118,7 @@ struct Classification { void classifyType(Classification& accum, jl_datatype_t *dt, uint64_t offset) const { // Floating point types - if (dt == jl_float64_type || dt == jl_float32_type) { + if (dt == jl_float64_type || dt == jl_float32_type || dt == jl_bfloat16_type) { accum.addField(offset, Sse); } // Misc types @@ -239,7 +239,9 @@ Type *preferred_llvm_type(jl_datatype_t *dt, bool isret, LLVMContext &ctx) const types[0] = Type::getIntNTy(ctx, nbits); break; case Sse: - if (size <= 4) + if (size <= 2) + types[0] = Type::getHalfTy(ctx); + else if (size <= 4) types[0] = Type::getFloatTy(ctx); else types[0] = Type::getDoubleTy(ctx); diff --git a/src/aotcompile.cpp b/src/aotcompile.cpp index 3a54e2729ff5f..e3417a4c0dca1 100644 --- a/src/aotcompile.cpp +++ b/src/aotcompile.cpp @@ -497,7 +497,6 @@ static void reportWriterError(const ErrorInfoBase &E) jl_safe_printf("ERROR: failed to emit output file %s\n", err.c_str()); } -#if 
JULIA_FLOAT16_ABI == 1 static void injectCRTAlias(Module &M, StringRef name, StringRef alias, FunctionType *FT) { Function *target = M.getFunction(alias); @@ -514,7 +513,7 @@ static void injectCRTAlias(Module &M, StringRef name, StringRef alias, FunctionT auto val = builder.CreateCall(target, CallArgs); builder.CreateRet(val); } -#endif + void multiversioning_preannotate(Module &M); // See src/processor.h for documentation about this table. Corresponds to jl_image_shard_t. @@ -1061,6 +1060,11 @@ static AOTOutputs add_output_impl(Module &M, TargetMachine &SourceTM, ShardTimer #else emitFloat16Wrappers(M, false); #endif + + injectCRTAlias(M, "__truncsfbf2", "julia__truncsfbf2", + FunctionType::get(Type::getBFloatTy(M.getContext()), { Type::getFloatTy(M.getContext()) }, false)); + injectCRTAlias(M, "__truncdfbf2", "julia__truncdfbf2", + FunctionType::get(Type::getBFloatTy(M.getContext()), { Type::getDoubleTy(M.getContext()) }, false)); } timers.optimize.stopTimer(); } diff --git a/src/ccall.cpp b/src/ccall.cpp index 118803cef1b10..3c42e46d273cf 100644 --- a/src/ccall.cpp +++ b/src/ccall.cpp @@ -1123,22 +1123,21 @@ std::string generate_func_sig(const char *fname) isboxed = false; } else { - if (jl_is_primitivetype(tti)) { + t = _julia_struct_to_llvm(ctx, LLVMCtx, tti, &isboxed, llvmcall); + if (t == getVoidTy(LLVMCtx)) { + return make_errmsg(fname, i + 1, " type doesn't correspond to a C type"); + } + if (jl_is_primitivetype(tti) && t->isIntegerTy()) { // see pull req #978. need to annotate signext/zeroext for // small integer arguments.
jl_datatype_t *bt = (jl_datatype_t*)tti; - if (jl_datatype_size(bt) < 4 && bt != jl_float16_type) { + if (jl_datatype_size(bt) < 4) { if (jl_signed_type && jl_subtype(tti, (jl_value_t*)jl_signed_type)) ab.addAttribute(Attribute::SExt); else ab.addAttribute(Attribute::ZExt); } } - - t = _julia_struct_to_llvm(ctx, LLVMCtx, tti, &isboxed, llvmcall); - if (t == getVoidTy(LLVMCtx)) { - return make_errmsg(fname, i + 1, " type doesn't correspond to a C type"); - } } Type *pat; diff --git a/src/cgutils.cpp b/src/cgutils.cpp index 7dfa509357e5a..91be89ddbe395 100644 --- a/src/cgutils.cpp +++ b/src/cgutils.cpp @@ -665,6 +665,8 @@ static Type *bitstype_to_llvm(jl_value_t *bt, LLVMContext &ctxt, bool llvmcall = return getFloatTy(ctxt); if (bt == (jl_value_t*)jl_float64_type) return getDoubleTy(ctxt); + if (bt == (jl_value_t*)jl_bfloat16_type) + return getBFloatTy(ctxt); if (jl_is_llvmpointer_type(bt)) { jl_value_t *as_param = jl_tparam1(bt); int as; diff --git a/src/codegen.cpp b/src/codegen.cpp index b6d18b23c930e..20f2dfe28165f 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -125,6 +125,9 @@ auto getFloatTy(LLVMContext &ctxt) { auto getDoubleTy(LLVMContext &ctxt) { return Type::getDoubleTy(ctxt); } +auto getBFloatTy(LLVMContext &ctxt) { + return Type::getBFloatTy(ctxt); +} auto getFP128Ty(LLVMContext &ctxt) { return Type::getFP128Ty(ctxt); } diff --git a/src/intrinsics.cpp b/src/intrinsics.cpp index 3e7ace18a1749..1bb68674990b7 100644 --- a/src/intrinsics.cpp +++ b/src/intrinsics.cpp @@ -165,7 +165,7 @@ static Type *INTT(Type *t, const DataLayout &DL) return getInt64Ty(ctxt); if (t == getFloatTy(ctxt)) return getInt32Ty(ctxt); - if (t == getHalfTy(ctxt)) + if (t == getHalfTy(ctxt) || t == getBFloatTy(ctxt)) return getInt16Ty(ctxt); unsigned nb = t->getPrimitiveSizeInBits(); assert(t != getVoidTy(ctxt) && nb > 0); diff --git a/src/jitlayers.cpp b/src/jitlayers.cpp index f0360c6addc95..6c356759cc066 100644 --- a/src/jitlayers.cpp +++ b/src/jitlayers.cpp @@ -1727,16 
+1727,18 @@ JuliaOJIT::JuliaOJIT() ExternalJD.addToLinkOrder(GlobalJD, orc::JITDylibLookupFlags::MatchExportedSymbolsOnly); ExternalJD.addToLinkOrder(JD, orc::JITDylibLookupFlags::MatchExportedSymbolsOnly); -#if JULIA_FLOAT16_ABI == 1 orc::SymbolAliasMap jl_crt = { +#if JULIA_FLOAT16_ABI == 1 { mangle("__gnu_h2f_ieee"), { mangle("julia__gnu_h2f_ieee"), JITSymbolFlags::Exported } }, { mangle("__extendhfsf2"), { mangle("julia__gnu_h2f_ieee"), JITSymbolFlags::Exported } }, { mangle("__gnu_f2h_ieee"), { mangle("julia__gnu_f2h_ieee"), JITSymbolFlags::Exported } }, { mangle("__truncsfhf2"), { mangle("julia__gnu_f2h_ieee"), JITSymbolFlags::Exported } }, - { mangle("__truncdfhf2"), { mangle("julia__truncdfhf2"), JITSymbolFlags::Exported } } + { mangle("__truncdfhf2"), { mangle("julia__truncdfhf2"), JITSymbolFlags::Exported } }, +#endif + { mangle("__truncsfbf2"), { mangle("julia__truncsfbf2"), JITSymbolFlags::Exported } }, + { mangle("__truncdfbf2"), { mangle("julia__truncdfbf2"), JITSymbolFlags::Exported } }, }; cantFail(GlobalJD.define(orc::symbolAliases(jl_crt))); -#endif #ifdef MSAN_EMUTLS_WORKAROUND orc::SymbolMap msan_crt; diff --git a/src/jl_exported_data.inc b/src/jl_exported_data.inc index 2acde218a104c..aa23b9d7b8205 100644 --- a/src/jl_exported_data.inc +++ b/src/jl_exported_data.inc @@ -42,6 +42,7 @@ XX(jl_float16_type) \ XX(jl_float32_type) \ XX(jl_float64_type) \ + XX(jl_bfloat16_type) \ XX(jl_floatingpoint_type) \ XX(jl_function_type) \ XX(jl_binding_type) \ diff --git a/src/jltypes.c b/src/jltypes.c index 998f3fe47f157..33b52158488a3 100644 --- a/src/jltypes.c +++ b/src/jltypes.c @@ -3403,6 +3403,8 @@ void post_boot_hooks(void) //XX(float32); jl_float64_type = (jl_datatype_t*)core("Float64"); //XX(float64); + jl_bfloat16_type = (jl_datatype_t*)core("BFloat16"); + //XX(bfloat16); jl_floatingpoint_type = (jl_datatype_t*)core("AbstractFloat"); jl_number_type = (jl_datatype_t*)core("Number"); jl_signed_type = (jl_datatype_t*)core("Signed"); diff --git 
a/src/julia.h b/src/julia.h index 07f8459d37238..a357bdf558360 100644 --- a/src/julia.h +++ b/src/julia.h @@ -848,6 +848,7 @@ extern JL_DLLIMPORT jl_datatype_t *jl_uint64_type JL_GLOBALLY_ROOTED; extern JL_DLLIMPORT jl_datatype_t *jl_float16_type JL_GLOBALLY_ROOTED; extern JL_DLLIMPORT jl_datatype_t *jl_float32_type JL_GLOBALLY_ROOTED; extern JL_DLLIMPORT jl_datatype_t *jl_float64_type JL_GLOBALLY_ROOTED; +extern JL_DLLIMPORT jl_datatype_t *jl_bfloat16_type JL_GLOBALLY_ROOTED; extern JL_DLLIMPORT jl_datatype_t *jl_floatingpoint_type JL_GLOBALLY_ROOTED; extern JL_DLLIMPORT jl_datatype_t *jl_number_type JL_GLOBALLY_ROOTED; extern JL_DLLIMPORT jl_datatype_t *jl_void_type JL_GLOBALLY_ROOTED; // deprecated diff --git a/src/julia_internal.h b/src/julia_internal.h index 41f976b8585f3..9dff8e75cb2f5 100644 --- a/src/julia_internal.h +++ b/src/julia_internal.h @@ -1663,6 +1663,8 @@ jl_sym_t *_jl_symbol(const char *str, size_t len) JL_NOTSAFEPOINT; JL_DLLEXPORT float julia__gnu_h2f_ieee(uint16_t param) JL_NOTSAFEPOINT; JL_DLLEXPORT uint16_t julia__gnu_f2h_ieee(float param) JL_NOTSAFEPOINT; JL_DLLEXPORT uint16_t julia__truncdfhf2(double param) JL_NOTSAFEPOINT; +JL_DLLEXPORT float julia__truncsfbf2(float param) JL_NOTSAFEPOINT; +JL_DLLEXPORT float julia__truncdfbf2(double param) JL_NOTSAFEPOINT; //JL_DLLEXPORT double julia__extendhfdf2(uint16_t n) JL_NOTSAFEPOINT; //JL_DLLEXPORT int32_t julia__fixhfsi(uint16_t n) JL_NOTSAFEPOINT; //JL_DLLEXPORT int64_t julia__fixhfdi(uint16_t n) JL_NOTSAFEPOINT; diff --git a/src/llvm-demote-float16.cpp b/src/llvm-demote-float16.cpp index 740055730fb90..5d0d9f5d37c40 100644 --- a/src/llvm-demote-float16.cpp +++ b/src/llvm-demote-float16.cpp @@ -1,8 +1,9 @@ // This file is a part of Julia. License is MIT: https://julialang.org/license -// This pass finds floating-point operations on 16-bit (half precision) values, and replaces -// them by equivalent operations on 32-bit (single precision) values surrounded by a fpext -// and fptrunc. 
This ensures that the exact semantics of IEEE floating-point are preserved. +// This pass finds floating-point operations on 16-bit values (half precision and bfloat), +// and replaces them by equivalent operations on 32-bit (single precision) values surrounded +// by a fpext and fptrunc. This ensures that the exact semantics of IEEE floating-point are +// preserved. // // Without this pass, back-ends that do not natively support half-precision (e.g. x86_64) // similarly pattern-match half-precision operations with single-precision equivalents, but @@ -71,10 +72,22 @@ static bool have_fp16(Function &caller, const Triple &TT) { return false; } +static bool have_bf16(Function &caller, const Triple &TT) { + if (caller.hasFnAttribute("julia.hasbf16")) { + return true; + } + + // there are no targets that fully support bfloat yet; + // AVX512BF16 only provides conversion and dot product instructions. + return false; +} + static bool demoteFloat16(Function &F) { auto TT = Triple(F.getParent()->getTargetTriple()); - if (have_fp16(F, TT)) + auto has_fp16 = have_fp16(F, TT); + auto has_bf16 = have_bf16(F, TT); + if (has_fp16 && has_bf16) return false; auto &ctx = F.getContext(); @@ -82,14 +95,17 @@ static bool demoteFloat16(Function &F) SmallVector erase; for (auto &BB : F) { for (auto &I : BB) { - // extend Float16 operands to Float32 + // check whether there are any 16-bit floating point operands to extend bool Float16 = I.getType()->getScalarType()->isHalfTy(); - for (size_t i = 0; !Float16 && i < I.getNumOperands(); i++) { + bool BFloat16 = I.getType()->getScalarType()->isBFloatTy(); + for (size_t i = 0; !BFloat16 && !Float16 && i < I.getNumOperands(); i++) { Value *Op = I.getOperand(i); - if (Op->getType()->getScalarType()->isHalfTy()) + if (!has_fp16 && Op->getType()->getScalarType()->isHalfTy()) Float16 = true; + else if (!has_bf16 && Op->getType()->getScalarType()->isBFloatTy()) + BFloat16 = true; } - if (!Float16) + if (!Float16 && !BFloat16) continue; switch
(I.getOpcode()) { @@ -113,11 +129,16 @@ static bool demoteFloat16(Function &F) IRBuilder<> builder(&I); - // extend Float16 operands to Float32 + // extend 16-bit floating point operands SmallVector Operands(I.getNumOperands()); for (size_t i = 0; i < I.getNumOperands(); i++) { Value *Op = I.getOperand(i); - if (Op->getType()->getScalarType()->isHalfTy()) { + if (!has_fp16 && Op->getType()->getScalarType()->isHalfTy()) { + // extend Float16 to Float32 + ++TotalExt; + Op = builder.CreateFPExt(Op, Op->getType()->getWithNewType(T_float32)); + } else if (!has_bf16 && Op->getType()->getScalarType()->isBFloatTy()) { + // extend BFloat16 to Float32 ++TotalExt; Op = builder.CreateFPExt(Op, Op->getType()->getWithNewType(T_float32)); } @@ -125,7 +146,7 @@ static bool demoteFloat16(Function &F) } // recreate the instruction if any operands changed, - // truncating the result back to Float16 + // truncating the result back to the original type Value *NewI; ++TotalChanged; switch (I.getOpcode()) { diff --git a/src/llvm-multiversioning.cpp b/src/llvm-multiversioning.cpp index da24882d85d6f..22f956294ddd3 100644 --- a/src/llvm-multiversioning.cpp +++ b/src/llvm-multiversioning.cpp @@ -50,7 +50,7 @@ extern Optional always_have_fma(Function&, const Triple &TT); namespace { constexpr uint32_t clone_mask = - JL_TARGET_CLONE_LOOP | JL_TARGET_CLONE_SIMD | JL_TARGET_CLONE_MATH | JL_TARGET_CLONE_CPU | JL_TARGET_CLONE_FLOAT16; + JL_TARGET_CLONE_LOOP | JL_TARGET_CLONE_SIMD | JL_TARGET_CLONE_MATH | JL_TARGET_CLONE_CPU | JL_TARGET_CLONE_FLOAT16 | JL_TARGET_CLONE_BFLOAT16; // Treat identical mapping as missing and return `def` in that case. 
// We mainly need this to identify cloned function using value map after LLVM cloning @@ -126,12 +126,14 @@ static uint32_t collect_func_info(Function &F, const Triple &TT, bool &has_vecca } for (size_t i = 0; i < I.getNumOperands(); i++) { - if(I.getOperand(i)->getType()->isHalfTy()){ + if(I.getOperand(i)->getType()->isHalfTy()) { flag |= JL_TARGET_CLONE_FLOAT16; } - // Check for BFloat16 when they are added to julia can be done here + if(I.getOperand(i)->getType()->isBFloatTy()) { + flag |= JL_TARGET_CLONE_BFLOAT16; + } } - uint32_t veccall_flags = JL_TARGET_CLONE_SIMD | JL_TARGET_CLONE_MATH | JL_TARGET_CLONE_CPU | JL_TARGET_CLONE_FLOAT16; + uint32_t veccall_flags = JL_TARGET_CLONE_SIMD | JL_TARGET_CLONE_MATH | JL_TARGET_CLONE_CPU | JL_TARGET_CLONE_FLOAT16 | JL_TARGET_CLONE_BFLOAT16; if (has_veccall && (flag & veccall_flags) == veccall_flags) { return flag; } diff --git a/src/processor.h b/src/processor.h index a3ebdf4f8c605..696d725ed826b 100644 --- a/src/processor.h +++ b/src/processor.h @@ -41,6 +41,8 @@ enum { JL_TARGET_CLONE_CPU = 1 << 8, // Clone when the function uses fp16 JL_TARGET_CLONE_FLOAT16 = 1 << 9, + // Clone when the function uses bf16 + JL_TARGET_CLONE_BFLOAT16 = 1 << 10, }; #define JL_FEATURE_DEF_NAME(name, bit, llvmver, str) JL_FEATURE_DEF(name, bit, llvmver) diff --git a/src/processor_x86.cpp b/src/processor_x86.cpp index 73e0992bcf37c..13dabd4e42db7 100644 --- a/src/processor_x86.cpp +++ b/src/processor_x86.cpp @@ -961,6 +961,13 @@ static void ensure_jit_target(bool imaging) break; } } + static constexpr uint32_t clone_bf16[] = {Feature::avx512bf16}; + for (auto fe: clone_bf16) { + if (!test_nbit(features0, fe) && test_nbit(t.en.features, fe)) { + t.en.flags |= JL_TARGET_CLONE_BFLOAT16; + break; + } + } } } diff --git a/src/runtime_intrinsics.c b/src/runtime_intrinsics.c index ed320aa9a6c35..b42b7d9832383 100644 --- a/src/runtime_intrinsics.c +++ b/src/runtime_intrinsics.c @@ -217,6 +217,44 @@ JL_DLLEXPORT uint16_t julia__truncdfhf2(double 
param) return float_to_half(res); } +JL_DLLEXPORT float julia__truncsfbf2(float param) JL_NOTSAFEPOINT +{ + uint16_t result; + + if (isnan(param)) + result = 0x7fc0; + else { + uint32_t bits = *((uint32_t*) &param); + + // round to nearest even + bits += 0x7fff + ((bits >> 16) & 1); + result = (uint16_t)(bits >> 16); + } + + // on x86, bfloat16 needs to be returned in XMM. only GCC 13 provides the necessary ABI + // support in the form of the __bf16 type; older versions only provide __bfloat16 which + // is simply a typedef for short (i16). so use float, which is passed in XMM too. + uint32_t result_32bit = (uint32_t)result; + return *(float*)&result_32bit; +} + +JL_DLLEXPORT float julia__truncdfbf2(double param) JL_NOTSAFEPOINT +{ + float res = (float)param; + uint32_t resi; + memcpy(&resi, &res, sizeof(res)); + + // bfloat16 uses the same exponent as float32, so we don't need special handling + // for subnormals when truncating float64 to bfloat16. + + if ((resi & 0xffffu) == 0x8000u) { // if we are halfway between 2 bfloat16 values + // adjust the value by 1 ULP in the direction that will make bfloat16(res) give the right answer + resi += (fabs(res) < fabs(param)) - (fabs(param) < fabs(res)); + memcpy(&res, &resi, sizeof(res)); + } + return julia__truncsfbf2(res); +} + //JL_DLLEXPORT double julia__extendhfdf2(uint16_t n) { return (double)julia__gnu_h2f_ieee(n); } //JL_DLLEXPORT int32_t julia__fixhfsi(uint16_t n) { return (int32_t)julia__gnu_h2f_ieee(n); } //JL_DLLEXPORT int64_t julia__fixhfdi(uint16_t n) { return (int64_t)julia__gnu_h2f_ieee(n); } diff --git a/src/staticdata.c b/src/staticdata.c index 536ca4cd6c3aa..df5652a5719c4 100644 --- a/src/staticdata.c +++ b/src/staticdata.c @@ -99,7 +99,7 @@ extern "C" { // TODO: put WeakRefs on the weak_refs list during deserialization // TODO: handle finalizers -#define NUM_TAGS 159 +#define NUM_TAGS 160 // An array of references that need to be restored from the sysimg // This is a manually constructed dual of the gvars
array, which would be produced by codegen for Julia code, for C. @@ -194,6 +194,7 @@ jl_value_t **const*const get_tags(void) { INSERT_TAG(jl_float16_type); INSERT_TAG(jl_float32_type); INSERT_TAG(jl_float64_type); + INSERT_TAG(jl_bfloat16_type); INSERT_TAG(jl_floatingpoint_type); INSERT_TAG(jl_number_type); INSERT_TAG(jl_signed_type); diff --git a/test/llvmpasses/float16.ll b/test/llvmpasses/float16.ll index b442a39b0050c..0c37be449d959 100644 --- a/test/llvmpasses/float16.ll +++ b/test/llvmpasses/float16.ll @@ -3,9 +3,9 @@ ; RUN: opt -enable-new-pm=1 --opaque-pointers=1 --load-pass-plugin=libjulia-codegen%shlibext -passes='DemoteFloat16' -S %s | FileCheck %s -define half @demotehalf_test(half %a, half %b) #0 { +define half @demote_half_test(half %a, half %b) #0 { top: -; CHECK-LABEL: @demotehalf_test( +; CHECK-LABEL: @demote_half_test( ; CHECK-NEXT: top: ; CHECK-NEXT: %0 = fpext half %a to float ; CHECK-NEXT: %1 = fpext half %b to float @@ -101,5 +101,66 @@ top: ret half %13 } +define bfloat @demote_bfloat_test(bfloat %a, bfloat %b) { +top: +; CHECK-LABEL: @demote_bfloat_test( +; CHECK-NEXT: top: +; CHECK-NEXT: %0 = fpext bfloat %a to float +; CHECK-NEXT: %1 = fpext bfloat %b to float +; CHECK-NEXT: %2 = fadd float %0, %1 +; CHECK-NEXT: %3 = fptrunc float %2 to bfloat +; CHECK-NEXT: %4 = fpext bfloat %3 to float +; CHECK-NEXT: %5 = fpext bfloat %b to float +; CHECK-NEXT: %6 = fadd float %4, %5 +; CHECK-NEXT: %7 = fptrunc float %6 to bfloat +; CHECK-NEXT: %8 = fpext bfloat %7 to float +; CHECK-NEXT: %9 = fpext bfloat %b to float +; CHECK-NEXT: %10 = fadd float %8, %9 +; CHECK-NEXT: %11 = fptrunc float %10 to bfloat +; CHECK-NEXT: %12 = fpext bfloat %11 to float +; CHECK-NEXT: %13 = fpext bfloat %b to float +; CHECK-NEXT: %14 = fmul float %12, %13 +; CHECK-NEXT: %15 = fptrunc float %14 to bfloat +; CHECK-NEXT: %16 = fpext bfloat %15 to float +; CHECK-NEXT: %17 = fpext bfloat %b to float +; CHECK-NEXT: %18 = fdiv float %16, %17 +; CHECK-NEXT: %19 = fptrunc float %18 
to bfloat +; CHECK-NEXT: %20 = insertelement <2 x bfloat> undef, bfloat %a, i32 0 +; CHECK-NEXT: %21 = insertelement <2 x bfloat> %20, bfloat %b, i32 1 +; CHECK-NEXT: %22 = insertelement <2 x bfloat> undef, bfloat %b, i32 0 +; CHECK-NEXT: %23 = insertelement <2 x bfloat> %22, bfloat %b, i32 1 +; CHECK-NEXT: %24 = fpext <2 x bfloat> %21 to <2 x float> +; CHECK-NEXT: %25 = fpext <2 x bfloat> %23 to <2 x float> +; CHECK-NEXT: %26 = fadd <2 x float> %24, %25 +; CHECK-NEXT: %27 = fptrunc <2 x float> %26 to <2 x bfloat> +; CHECK-NEXT: %28 = extractelement <2 x bfloat> %27, i32 0 +; CHECK-NEXT: %29 = extractelement <2 x bfloat> %27, i32 1 +; CHECK-NEXT: %30 = fpext bfloat %28 to float +; CHECK-NEXT: %31 = fpext bfloat %29 to float +; CHECK-NEXT: %32 = fadd float %30, %31 +; CHECK-NEXT: %33 = fptrunc float %32 to bfloat +; CHECK-NEXT: %34 = fpext bfloat %33 to float +; CHECK-NEXT: %35 = fpext bfloat %19 to float +; CHECK-NEXT: %36 = fadd float %34, %35 +; CHECK-NEXT: %37 = fptrunc float %36 to bfloat +; CHECK-NEXT: ret bfloat %37 +; + %0 = fadd bfloat %a, %b + %1 = fadd bfloat %0, %b + %2 = fadd bfloat %1, %b + %3 = fmul bfloat %2, %b + %4 = fdiv bfloat %3, %b + %5 = insertelement <2 x bfloat> undef, bfloat %a, i32 0 + %6 = insertelement <2 x bfloat> %5, bfloat %b, i32 1 + %7 = insertelement <2 x bfloat> undef, bfloat %b, i32 0 + %8 = insertelement <2 x bfloat> %7, bfloat %b, i32 1 + %9 = fadd <2 x bfloat> %6, %8 + %10 = extractelement <2 x bfloat> %9, i32 0 + %11 = extractelement <2 x bfloat> %9, i32 1 + %12 = fadd bfloat %10, %11 + %13 = fadd bfloat %12, %4 + ret bfloat %13 +} + attributes #0 = { "target-features"="-avx512fp16" } attributes #1 = { "target-features"="+avx512fp16" } diff --git a/test/numbers.jl b/test/numbers.jl index be661da6783fe..a9d126aa33d5a 100644 --- a/test/numbers.jl +++ b/test/numbers.jl @@ -2901,6 +2901,7 @@ end let float_types = Set() allsubtypes!(Base, AbstractFloat, float_types) allsubtypes!(Core, AbstractFloat, float_types) + 
filter!(!isequal(Core.BFloat16), float_types) # defined externally @test !isempty(float_types) for T in float_types