From 93915f8b8d78c401a6e9426a36dac3bd5f88739a Mon Sep 17 00:00:00 2001 From: KristofferC Date: Wed, 21 Dec 2022 14:31:49 +0100 Subject: [PATCH] Emit aliases to FP16 conversion routines --- src/APInt-C.cpp | 6 +++--- src/aotcompile.cpp | 33 +++++++++++++++++++++++++++++++++ src/intrinsics.cpp | 11 +++-------- src/jitlayers.cpp | 18 ++++++++++++++++-- src/jitlayers.h | 1 + src/julia.expmap | 6 ------ src/julia_internal.h | 5 +++-- src/runtime_intrinsics.c | 30 +++++++++++++++--------------- test/intrinsics.jl | 24 ++++++++++++++++++++++++ 9 files changed, 98 insertions(+), 36 deletions(-) diff --git a/src/APInt-C.cpp b/src/APInt-C.cpp index bc0a62e21dd3e..f06d4362bf958 100644 --- a/src/APInt-C.cpp +++ b/src/APInt-C.cpp @@ -316,7 +316,7 @@ void LLVMByteSwap(unsigned numbits, integerPart *pa, integerPart *pr) { void LLVMFPtoInt(unsigned numbits, void *pa, unsigned onumbits, integerPart *pr, bool isSigned, bool *isExact) { double Val; if (numbits == 16) - Val = __gnu_h2f_ieee(*(uint16_t*)pa); + Val = julia__gnu_h2f_ieee(*(uint16_t*)pa); else if (numbits == 32) Val = *(float*)pa; else if (numbits == 64) @@ -391,7 +391,7 @@ void LLVMSItoFP(unsigned numbits, integerPart *pa, unsigned onumbits, integerPar val = a.roundToDouble(true); } if (onumbits == 16) - *(uint16_t*)pr = __gnu_f2h_ieee(val); + *(uint16_t*)pr = julia__gnu_f2h_ieee(val); else if (onumbits == 32) *(float*)pr = val; else if (onumbits == 64) @@ -408,7 +408,7 @@ void LLVMUItoFP(unsigned numbits, integerPart *pa, unsigned onumbits, integerPar val = a.roundToDouble(false); } if (onumbits == 16) - *(uint16_t*)pr = __gnu_f2h_ieee(val); + *(uint16_t*)pr = julia__gnu_f2h_ieee(val); else if (onumbits == 32) *(float*)pr = val; else if (onumbits == 64) diff --git a/src/aotcompile.cpp b/src/aotcompile.cpp index b7bdac7923c2e..1dc6bd3c94877 100644 --- a/src/aotcompile.cpp +++ b/src/aotcompile.cpp @@ -51,6 +51,7 @@ #include #endif +#include #include #include @@ -276,6 +277,24 @@ static void jl_ci_cache_lookup(const jl_cgparams_t &cgparams, jl_method_instance *ci_out = codeinst; } +static void injectCRTAlias(Module &M, StringRef name, StringRef alias, FunctionType *FT) +{ + Function *target = M.getFunction(alias); + if (!target) { + target = Function::Create(FT, Function::ExternalLinkage, alias, M); + } + // Weak so that this does not get discarded + // maybe use llvm.compiler.used instead? + Function *interposer = Function::Create(FT, Function::WeakAnyLinkage, name, M); + + llvm::IRBuilder<> builder(BasicBlock::Create(M.getContext(), "top", interposer)); + SmallVector CallArgs; + for (auto &arg : interposer->args()) + CallArgs.push_back(&arg); + auto val = builder.CreateCall(target, CallArgs); + builder.CreateRet(val); +} + // takes the running content that has collected in the shadow module and dump it to disk // this builds the object file portion of the sysimage files for fast startup, and can // also be used be extern consumers like GPUCompiler.jl to obtain a module containing @@ -554,6 +573,20 @@ void jl_dump_native(void *native_code, "jl_RTLD_DEFAULT_handle_pointer")); } + // We would like to emit an alias or an weakref alias to redirect these symbols + // but LLVM doesn't let us emit a GlobalAlias to a declaration... + // So for now we inject a definition of these functions that calls our runtime functions. + injectCRTAlias(*data->M, "__gnu_h2f_ieee", "julia__gnu_h2f_ieee", + FunctionType::get(Type::getFloatTy(Context), { Type::getHalfTy(Context) }, false)); + injectCRTAlias(*data->M, "__extendhfsf2", "julia__gnu_h2f_ieee", + FunctionType::get(Type::getFloatTy(Context), { Type::getHalfTy(Context) }, false)); + injectCRTAlias(*data->M, "__gnu_f2h_ieee", "julia__gnu_f2h_ieee", + FunctionType::get(Type::getHalfTy(Context), { Type::getFloatTy(Context) }, false)); + injectCRTAlias(*data->M, "__truncsfhf2", "julia__gnu_f2h_ieee", + FunctionType::get(Type::getHalfTy(Context), { Type::getFloatTy(Context) }, false)); + injectCRTAlias(*data->M, "__truncdfhf2", "julia__truncdfhf2", + FunctionType::get(Type::getHalfTy(Context), { Type::getDoubleTy(Context) }, false)); + // do the actual work auto add_output = [&] (Module &M, StringRef unopt_bc_Name, StringRef bc_Name, StringRef obj_Name, StringRef asm_Name) { PM.run(M); diff --git a/src/intrinsics.cpp b/src/intrinsics.cpp index 8620ce0c57dd8..f5a00bdf09cbc 100644 --- a/src/intrinsics.cpp +++ b/src/intrinsics.cpp @@ -1476,22 +1476,17 @@ static inline uint16_t float_to_half(float param) #if !defined(_OS_DARWIN_) // xcode already links compiler-rt -extern "C" JL_DLLEXPORT float __gnu_h2f_ieee(uint16_t param) +extern "C" JL_DLLEXPORT float julia__gnu_h2f_ieee(uint16_t param) { return half_to_float(param); } -extern "C" JL_DLLEXPORT float __extendhfsf2(uint16_t param) -{ - return half_to_float(param); -} - -extern "C" JL_DLLEXPORT uint16_t __gnu_f2h_ieee(float param) +extern "C" JL_DLLEXPORT uint16_t julia__gnu_f2h_ieee(float param) { return float_to_half(param); } -extern "C" JL_DLLEXPORT uint16_t __truncdfhf2(double param) +extern "C" JL_DLLEXPORT uint16_t julia__truncdfhf2(double param) { return float_to_half((float)param); } diff --git a/src/jitlayers.cpp b/src/jitlayers.cpp index 74ed20ce11fb6..9b5f848575cd1 100644 --- a/src/jitlayers.cpp +++ b/src/jitlayers.cpp @@ -737,12 +737,26 @@ JuliaOJIT::JuliaOJIT(TargetMachine &TM, LLVMContext *LLVMCtx) } JD.addToLinkOrder(GlobalJD, orc::JITDylibLookupFlags::MatchExportedSymbolsOnly); + + orc::SymbolAliasMap jl_crt = { + { mangle("__gnu_h2f_ieee"), { mangle("julia__gnu_h2f_ieee"), JITSymbolFlags::Exported } }, + { mangle("__extendhfsf2"), { mangle("julia__gnu_h2f_ieee"), JITSymbolFlags::Exported } }, + { mangle("__gnu_f2h_ieee"), { mangle("julia__gnu_f2h_ieee"), JITSymbolFlags::Exported } }, + { mangle("__truncsfhf2"), { mangle("julia__gnu_f2h_ieee"), JITSymbolFlags::Exported } }, + { mangle("__truncdfhf2"), { mangle("julia__truncdfhf2"), JITSymbolFlags::Exported } } + }; + cantFail(GlobalJD.define(orc::symbolAliases(jl_crt))); } -void JuliaOJIT::addGlobalMapping(StringRef Name, uint64_t Addr) +orc::SymbolStringPtr JuliaOJIT::mangle(StringRef Name) { std::string MangleName = getMangledName(Name); - cantFail(JD.define(orc::absoluteSymbols({{ES.intern(MangleName), JITEvaluatedSymbol::fromPointer((void*)Addr)}}))); + return ES.intern(MangleName); +} + +void JuliaOJIT::addGlobalMapping(StringRef Name, uint64_t Addr) +{ + cantFail(JD.define(orc::absoluteSymbols({{mangle(Name), JITEvaluatedSymbol::fromPointer((void*)Addr)}}))); } void JuliaOJIT::addModule(std::unique_ptr M) diff --git a/src/jitlayers.h b/src/jitlayers.h index 01fcf3c325df9..3c7e9ec113ab6 100644 --- a/src/jitlayers.h +++ b/src/jitlayers.h @@ -185,6 +185,7 @@ class JuliaOJIT { const object::ObjectFile &Obj, const RuntimeDyld::LoadedObjectInfo &LoadedObjectInfo); #endif + orc::SymbolStringPtr mangle(StringRef Name); void addGlobalMapping(StringRef Name, uint64_t Addr); void addModule(std::unique_ptr M); #if JL_LLVM_VERSION < 120000 diff --git a/src/julia.expmap b/src/julia.expmap index baf3220f147ed..aa77f4c8cff31 100644 --- a/src/julia.expmap +++ b/src/julia.expmap @@ -42,12 +42,6 @@ environ; __progname; - /* compiler run-time intrinsics */ - __gnu_h2f_ieee; - __extendhfsf2; - __gnu_f2h_ieee; - __truncdfhf2; - local: *; }; diff --git a/src/julia_internal.h b/src/julia_internal.h index 0b766967918e2..80fea5098b00b 100644 --- a/src/julia_internal.h +++ b/src/julia_internal.h @@ -1363,8 +1363,9 @@ jl_sym_t *_jl_symbol(const char *str, size_t len) JL_NOTSAFEPOINT; #define JL_GC_ASSERT_LIVE(x) (void)(x) #endif -float __gnu_h2f_ieee(uint16_t param) JL_NOTSAFEPOINT; -uint16_t __gnu_f2h_ieee(float param) JL_NOTSAFEPOINT; +JL_DLLEXPORT float julia__gnu_h2f_ieee(uint16_t param) JL_NOTSAFEPOINT; +JL_DLLEXPORT uint16_t julia__gnu_f2h_ieee(float param) JL_NOTSAFEPOINT; +JL_DLLEXPORT uint16_t julia__truncdfhf2(double param) JL_NOTSAFEPOINT; #ifdef __cplusplus } diff --git a/src/runtime_intrinsics.c b/src/runtime_intrinsics.c index 2337abe7d5704..497001866c1b5 100644 --- a/src/runtime_intrinsics.c +++ b/src/runtime_intrinsics.c @@ -169,9 +169,9 @@ static inline unsigned select_by_size(unsigned sz) JL_NOTSAFEPOINT } #define fp_select(a, func) \ - sizeof(a) == sizeof(float) ? func##f((float)a) : func(a) + sizeof(a) <= sizeof(float) ? func##f((float)a) : func(a) #define fp_select2(a, b, func) \ - sizeof(a) == sizeof(float) ? func##f(a, b) : func(a, b) + sizeof(a) <= sizeof(float) ? func##f(a, b) : func(a, b) // fast-function generators // @@ -215,11 +215,11 @@ static inline void name(unsigned osize, void *pa, void *pr) JL_NOTSAFEPOINT \ static inline void name(unsigned osize, void *pa, void *pr) JL_NOTSAFEPOINT \ { \ uint16_t a = *(uint16_t*)pa; \ - float A = __gnu_h2f_ieee(a); \ + float A = julia__gnu_h2f_ieee(a); \ if (osize == 16) { \ float R; \ OP(&R, A); \ - *(uint16_t*)pr = __gnu_f2h_ieee(R); \ + *(uint16_t*)pr = julia__gnu_f2h_ieee(R); \ } else { \ OP((uint16_t*)pr, A); \ } \ @@ -243,11 +243,11 @@ static void jl_##name##16(unsigned runtime_nbits, void *pa, void *pb, void *pr) { \ uint16_t a = *(uint16_t*)pa; \ uint16_t b = *(uint16_t*)pb; \ - float A = __gnu_h2f_ieee(a); \ - float B = __gnu_h2f_ieee(b); \ + float A = julia__gnu_h2f_ieee(a); \ + float B = julia__gnu_h2f_ieee(b); \ runtime_nbits = 16; \ float R = OP(A, B); \ - *(uint16_t*)pr = __gnu_f2h_ieee(R); \ + *(uint16_t*)pr = julia__gnu_f2h_ieee(R); \ } // float or integer inputs, bool output @@ -268,8 +268,8 @@ static int jl_##name##16(unsigned runtime_nbits, void *pa, void *pb) JL_NOTSAFEP { \ uint16_t a = *(uint16_t*)pa; \ uint16_t b = *(uint16_t*)pb; \ - float A = __gnu_h2f_ieee(a); \ - float B = __gnu_h2f_ieee(b); \ + float A = julia__gnu_h2f_ieee(a); \ + float B = julia__gnu_h2f_ieee(b); \ runtime_nbits = 16; \ return OP(A, B); \ } @@ -309,12 +309,12 @@ static void jl_##name##16(unsigned runtime_nbits, void *pa, void *pb, void *pc, uint16_t a = *(uint16_t*)pa; \ uint16_t b = *(uint16_t*)pb; \ uint16_t c = *(uint16_t*)pc; \ - float A = __gnu_h2f_ieee(a); \ - float B = __gnu_h2f_ieee(b); \ - float C = __gnu_h2f_ieee(c); \ + float A = julia__gnu_h2f_ieee(a); \ + float B = julia__gnu_h2f_ieee(b); \ + float C = julia__gnu_h2f_ieee(c); \ runtime_nbits = 16; \ float R = OP(A, B, C); \ - *(uint16_t*)pr = __gnu_f2h_ieee(R); \ + *(uint16_t*)pr = julia__gnu_f2h_ieee(R); \ } @@ -832,7 +832,7 @@ static inline int fpiseq##nbits(c_type a, c_type b) JL_NOTSAFEPOINT { \ fpiseq_n(float, 32) fpiseq_n(double, 64) #define fpiseq(a,b) \ - sizeof(a) == sizeof(float) ? fpiseq32(a, b) : fpiseq64(a, b) + sizeof(a) <= sizeof(float) ? fpiseq32(a, b) : fpiseq64(a, b) #define fpislt_n(c_type, nbits) \ static inline int fpislt##nbits(c_type a, c_type b) JL_NOTSAFEPOINT \ @@ -903,7 +903,7 @@ cvt_iintrinsic(LLVMFPtoUI, fptoui) if (!(osize < 8 * sizeof(a))) \ jl_error("fptrunc: output bitsize must be < input bitsize"); \ else if (osize == 16) \ - *(uint16_t*)pr = __gnu_f2h_ieee(a); \ + *(uint16_t*)pr = julia__gnu_f2h_ieee(a); \ else if (osize == 32) \ *(float*)pr = a; \ else if (osize == 64) \ diff --git a/test/intrinsics.jl b/test/intrinsics.jl index 47560d7dbd626..64ebec8b1fca4 100644 --- a/test/intrinsics.jl +++ b/test/intrinsics.jl @@ -152,3 +152,27 @@ end @test_intrinsic Core.Intrinsics.fptosi Int Float16(3.3) 3 @test_intrinsic Core.Intrinsics.fptoui UInt Float16(3.3) UInt(3) end + +if Sys.ARCH == :aarch64 + # On AArch64 we are following the `_Float16` ABI. Buthe these functions expect `Int16`. + # TODO: SHould we have `Chalf == Int16` and `Cfloat16 == Float16`? + extendhfsf2(x::Float16) = ccall("extern __extendhfsf2", llvmcall, Float32, (Int16,), reinterpret(Int16, x)) + gnu_h2f_ieee(x::Float16) = ccall("extern __gnu_h2f_ieee", llvmcall, Float32, (Int16,), reinterpret(Int16, x)) + truncsfhf2(x::Float32) = reinterpret(Float16, ccall("extern __truncsfhf2", llvmcall, Int16, (Float32,), x)) + gnu_f2h_ieee(x::Float32) = reinterpret(Float16, ccall("extern __gnu_f2h_ieee", llvmcall, Int16, (Float32,), x)) + truncdfhf2(x::Float64) = reinterpret(Float16, ccall("extern __truncdfhf2", llvmcall, Int16, (Float64,), x)) +else + extendhfsf2(x::Float16) = ccall("extern __extendhfsf2", llvmcall, Float32, (Float16,), x) + gnu_h2f_ieee(x::Float16) = ccall("extern __gnu_h2f_ieee", llvmcall, Float32, (Float16,), x) + truncsfhf2(x::Float32) = ccall("extern __truncsfhf2", llvmcall, Float16, (Float32,), x) + gnu_f2h_ieee(x::Float32) = ccall("extern __gnu_f2h_ieee", llvmcall, Float16, (Float32,), x) + truncdfhf2(x::Float64) = ccall("extern __truncdfhf2", llvmcall, Float16, (Float64,), x) +end + +@testset "Float16 intrinsics (crt)" begin + @test extendhfsf2(Float16(3.3)) == 3.3007812f0 + @test gnu_h2f_ieee(Float16(3.3)) == 3.3007812f0 + @test truncsfhf2(3.3f0) == Float16(3.3) + @test gnu_f2h_ieee(3.3f0) == Float16(3.3) + @test truncdfhf2(3.3) == Float16(3.3) +end