Skip to content

Commit

Permalink
Emit aliases to FP16 conversion routines
Browse files Browse the repository at this point in the history
  • Loading branch information
KristofferC committed Dec 21, 2022
1 parent 21a2c43 commit 93915f8
Show file tree
Hide file tree
Showing 9 changed files with 98 additions and 36 deletions.
6 changes: 3 additions & 3 deletions src/APInt-C.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ void LLVMByteSwap(unsigned numbits, integerPart *pa, integerPart *pr) {
void LLVMFPtoInt(unsigned numbits, void *pa, unsigned onumbits, integerPart *pr, bool isSigned, bool *isExact) {
double Val;
if (numbits == 16)
Val = __gnu_h2f_ieee(*(uint16_t*)pa);
Val = julia__gnu_h2f_ieee(*(uint16_t*)pa);
else if (numbits == 32)
Val = *(float*)pa;
else if (numbits == 64)
Expand Down Expand Up @@ -391,7 +391,7 @@ void LLVMSItoFP(unsigned numbits, integerPart *pa, unsigned onumbits, integerPar
val = a.roundToDouble(true);
}
if (onumbits == 16)
*(uint16_t*)pr = __gnu_f2h_ieee(val);
*(uint16_t*)pr = julia__gnu_f2h_ieee(val);
else if (onumbits == 32)
*(float*)pr = val;
else if (onumbits == 64)
Expand All @@ -408,7 +408,7 @@ void LLVMUItoFP(unsigned numbits, integerPart *pa, unsigned onumbits, integerPar
val = a.roundToDouble(false);
}
if (onumbits == 16)
*(uint16_t*)pr = __gnu_f2h_ieee(val);
*(uint16_t*)pr = julia__gnu_f2h_ieee(val);
else if (onumbits == 32)
*(float*)pr = val;
else if (onumbits == 64)
Expand Down
33 changes: 33 additions & 0 deletions src/aotcompile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
#include <llvm/Support/CodeGen.h>
#endif

#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/LegacyPassManagers.h>
#include <llvm/Transforms/Utils/Cloning.h>

Expand Down Expand Up @@ -276,6 +277,24 @@ static void jl_ci_cache_lookup(const jl_cgparams_t &cgparams, jl_method_instance
*ci_out = codeinst;
}

static void injectCRTAlias(Module &M, StringRef name, StringRef alias, FunctionType *FT)
{
Function *target = M.getFunction(alias);
if (!target) {
target = Function::Create(FT, Function::ExternalLinkage, alias, M);
}
// Weak so that this does not get discarded
// maybe use llvm.compiler.used instead?
Function *interposer = Function::Create(FT, Function::WeakAnyLinkage, name, M);

llvm::IRBuilder<> builder(BasicBlock::Create(M.getContext(), "top", interposer));
SmallVector<Value *, 4> CallArgs;
for (auto &arg : interposer->args())
CallArgs.push_back(&arg);
auto val = builder.CreateCall(target, CallArgs);
builder.CreateRet(val);
}

// takes the running content that has collected in the shadow module and dump it to disk
// this builds the object file portion of the sysimage files for fast startup, and can
// also be used be extern consumers like GPUCompiler.jl to obtain a module containing
Expand Down Expand Up @@ -554,6 +573,20 @@ void jl_dump_native(void *native_code,
"jl_RTLD_DEFAULT_handle_pointer"));
}

// We would like to emit an alias or an weakref alias to redirect these symbols
// but LLVM doesn't let us emit a GlobalAlias to a declaration...
// So for now we inject a definition of these functions that calls our runtime functions.
injectCRTAlias(*data->M, "__gnu_h2f_ieee", "julia__gnu_h2f_ieee",
FunctionType::get(Type::getFloatTy(Context), { Type::getHalfTy(Context) }, false));
injectCRTAlias(*data->M, "__extendhfsf2", "julia__gnu_h2f_ieee",
FunctionType::get(Type::getFloatTy(Context), { Type::getHalfTy(Context) }, false));
injectCRTAlias(*data->M, "__gnu_f2h_ieee", "julia__gnu_f2h_ieee",
FunctionType::get(Type::getHalfTy(Context), { Type::getFloatTy(Context) }, false));
injectCRTAlias(*data->M, "__truncsfhf2", "julia__gnu_f2h_ieee",
FunctionType::get(Type::getHalfTy(Context), { Type::getFloatTy(Context) }, false));
injectCRTAlias(*data->M, "__truncdfhf2", "julia__truncdfhf2",
FunctionType::get(Type::getHalfTy(Context), { Type::getDoubleTy(Context) }, false));

// do the actual work
auto add_output = [&] (Module &M, StringRef unopt_bc_Name, StringRef bc_Name, StringRef obj_Name, StringRef asm_Name) {
PM.run(M);
Expand Down
11 changes: 3 additions & 8 deletions src/intrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1476,22 +1476,17 @@ static inline uint16_t float_to_half(float param)

#if !defined(_OS_DARWIN_) // xcode already links compiler-rt

extern "C" JL_DLLEXPORT float __gnu_h2f_ieee(uint16_t param)
extern "C" JL_DLLEXPORT float julia__gnu_h2f_ieee(uint16_t param)
{
return half_to_float(param);
}

extern "C" JL_DLLEXPORT float __extendhfsf2(uint16_t param)
{
return half_to_float(param);
}

extern "C" JL_DLLEXPORT uint16_t __gnu_f2h_ieee(float param)
extern "C" JL_DLLEXPORT uint16_t julia__gnu_f2h_ieee(float param)
{
return float_to_half(param);
}

extern "C" JL_DLLEXPORT uint16_t __truncdfhf2(double param)
extern "C" JL_DLLEXPORT uint16_t julia__truncdfhf2(double param)
{
return float_to_half((float)param);
}
Expand Down
18 changes: 16 additions & 2 deletions src/jitlayers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -737,12 +737,26 @@ JuliaOJIT::JuliaOJIT(TargetMachine &TM, LLVMContext *LLVMCtx)
}

JD.addToLinkOrder(GlobalJD, orc::JITDylibLookupFlags::MatchExportedSymbolsOnly);

orc::SymbolAliasMap jl_crt = {
{ mangle("__gnu_h2f_ieee"), { mangle("julia__gnu_h2f_ieee"), JITSymbolFlags::Exported } },
{ mangle("__extendhfsf2"), { mangle("julia__gnu_h2f_ieee"), JITSymbolFlags::Exported } },
{ mangle("__gnu_f2h_ieee"), { mangle("julia__gnu_f2h_ieee"), JITSymbolFlags::Exported } },
{ mangle("__truncsfhf2"), { mangle("julia__gnu_f2h_ieee"), JITSymbolFlags::Exported } },
{ mangle("__truncdfhf2"), { mangle("julia__truncdfhf2"), JITSymbolFlags::Exported } }
};
cantFail(GlobalJD.define(orc::symbolAliases(jl_crt)));
}

void JuliaOJIT::addGlobalMapping(StringRef Name, uint64_t Addr)
orc::SymbolStringPtr JuliaOJIT::mangle(StringRef Name)
{
std::string MangleName = getMangledName(Name);
cantFail(JD.define(orc::absoluteSymbols({{ES.intern(MangleName), JITEvaluatedSymbol::fromPointer((void*)Addr)}})));
return ES.intern(MangleName);
}

void JuliaOJIT::addGlobalMapping(StringRef Name, uint64_t Addr)
{
cantFail(JD.define(orc::absoluteSymbols({{mangle(Name), JITEvaluatedSymbol::fromPointer((void*)Addr)}})));
}

void JuliaOJIT::addModule(std::unique_ptr<Module> M)
Expand Down
1 change: 1 addition & 0 deletions src/jitlayers.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ class JuliaOJIT {
const object::ObjectFile &Obj,
const RuntimeDyld::LoadedObjectInfo &LoadedObjectInfo);
#endif
orc::SymbolStringPtr mangle(StringRef Name);
void addGlobalMapping(StringRef Name, uint64_t Addr);
void addModule(std::unique_ptr<Module> M);
#if JL_LLVM_VERSION < 120000
Expand Down
6 changes: 0 additions & 6 deletions src/julia.expmap
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,6 @@
environ;
__progname;

/* compiler run-time intrinsics */
__gnu_h2f_ieee;
__extendhfsf2;
__gnu_f2h_ieee;
__truncdfhf2;

local:
*;
};
5 changes: 3 additions & 2 deletions src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -1363,8 +1363,9 @@ jl_sym_t *_jl_symbol(const char *str, size_t len) JL_NOTSAFEPOINT;
#define JL_GC_ASSERT_LIVE(x) (void)(x)
#endif

float __gnu_h2f_ieee(uint16_t param) JL_NOTSAFEPOINT;
uint16_t __gnu_f2h_ieee(float param) JL_NOTSAFEPOINT;
JL_DLLEXPORT float julia__gnu_h2f_ieee(uint16_t param) JL_NOTSAFEPOINT;
JL_DLLEXPORT uint16_t julia__gnu_f2h_ieee(float param) JL_NOTSAFEPOINT;
JL_DLLEXPORT uint16_t julia__truncdfhf2(double param) JL_NOTSAFEPOINT;

#ifdef __cplusplus
}
Expand Down
30 changes: 15 additions & 15 deletions src/runtime_intrinsics.c
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,9 @@ static inline unsigned select_by_size(unsigned sz) JL_NOTSAFEPOINT
}

#define fp_select(a, func) \
sizeof(a) == sizeof(float) ? func##f((float)a) : func(a)
sizeof(a) <= sizeof(float) ? func##f((float)a) : func(a)
#define fp_select2(a, b, func) \
sizeof(a) == sizeof(float) ? func##f(a, b) : func(a, b)
sizeof(a) <= sizeof(float) ? func##f(a, b) : func(a, b)

// fast-function generators //

Expand Down Expand Up @@ -215,11 +215,11 @@ static inline void name(unsigned osize, void *pa, void *pr) JL_NOTSAFEPOINT \
static inline void name(unsigned osize, void *pa, void *pr) JL_NOTSAFEPOINT \
{ \
uint16_t a = *(uint16_t*)pa; \
float A = __gnu_h2f_ieee(a); \
float A = julia__gnu_h2f_ieee(a); \
if (osize == 16) { \
float R; \
OP(&R, A); \
*(uint16_t*)pr = __gnu_f2h_ieee(R); \
*(uint16_t*)pr = julia__gnu_f2h_ieee(R); \
} else { \
OP((uint16_t*)pr, A); \
} \
Expand All @@ -243,11 +243,11 @@ static void jl_##name##16(unsigned runtime_nbits, void *pa, void *pb, void *pr)
{ \
uint16_t a = *(uint16_t*)pa; \
uint16_t b = *(uint16_t*)pb; \
float A = __gnu_h2f_ieee(a); \
float B = __gnu_h2f_ieee(b); \
float A = julia__gnu_h2f_ieee(a); \
float B = julia__gnu_h2f_ieee(b); \
runtime_nbits = 16; \
float R = OP(A, B); \
*(uint16_t*)pr = __gnu_f2h_ieee(R); \
*(uint16_t*)pr = julia__gnu_f2h_ieee(R); \
}

// float or integer inputs, bool output
Expand All @@ -268,8 +268,8 @@ static int jl_##name##16(unsigned runtime_nbits, void *pa, void *pb) JL_NOTSAFEP
{ \
uint16_t a = *(uint16_t*)pa; \
uint16_t b = *(uint16_t*)pb; \
float A = __gnu_h2f_ieee(a); \
float B = __gnu_h2f_ieee(b); \
float A = julia__gnu_h2f_ieee(a); \
float B = julia__gnu_h2f_ieee(b); \
runtime_nbits = 16; \
return OP(A, B); \
}
Expand Down Expand Up @@ -309,12 +309,12 @@ static void jl_##name##16(unsigned runtime_nbits, void *pa, void *pb, void *pc,
uint16_t a = *(uint16_t*)pa; \
uint16_t b = *(uint16_t*)pb; \
uint16_t c = *(uint16_t*)pc; \
float A = __gnu_h2f_ieee(a); \
float B = __gnu_h2f_ieee(b); \
float C = __gnu_h2f_ieee(c); \
float A = julia__gnu_h2f_ieee(a); \
float B = julia__gnu_h2f_ieee(b); \
float C = julia__gnu_h2f_ieee(c); \
runtime_nbits = 16; \
float R = OP(A, B, C); \
*(uint16_t*)pr = __gnu_f2h_ieee(R); \
*(uint16_t*)pr = julia__gnu_f2h_ieee(R); \
}


Expand Down Expand Up @@ -832,7 +832,7 @@ static inline int fpiseq##nbits(c_type a, c_type b) JL_NOTSAFEPOINT { \
fpiseq_n(float, 32)
fpiseq_n(double, 64)
#define fpiseq(a,b) \
sizeof(a) == sizeof(float) ? fpiseq32(a, b) : fpiseq64(a, b)
sizeof(a) <= sizeof(float) ? fpiseq32(a, b) : fpiseq64(a, b)

#define fpislt_n(c_type, nbits) \
static inline int fpislt##nbits(c_type a, c_type b) JL_NOTSAFEPOINT \
Expand Down Expand Up @@ -903,7 +903,7 @@ cvt_iintrinsic(LLVMFPtoUI, fptoui)
if (!(osize < 8 * sizeof(a))) \
jl_error("fptrunc: output bitsize must be < input bitsize"); \
else if (osize == 16) \
*(uint16_t*)pr = __gnu_f2h_ieee(a); \
*(uint16_t*)pr = julia__gnu_f2h_ieee(a); \
else if (osize == 32) \
*(float*)pr = a; \
else if (osize == 64) \
Expand Down
24 changes: 24 additions & 0 deletions test/intrinsics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,27 @@ end
@test_intrinsic Core.Intrinsics.fptosi Int Float16(3.3) 3
@test_intrinsic Core.Intrinsics.fptoui UInt Float16(3.3) UInt(3)
end

if Sys.ARCH == :aarch64
# On AArch64 we are following the `_Float16` ABI. Buthe these functions expect `Int16`.
# TODO: SHould we have `Chalf == Int16` and `Cfloat16 == Float16`?
extendhfsf2(x::Float16) = ccall("extern __extendhfsf2", llvmcall, Float32, (Int16,), reinterpret(Int16, x))
gnu_h2f_ieee(x::Float16) = ccall("extern __gnu_h2f_ieee", llvmcall, Float32, (Int16,), reinterpret(Int16, x))
truncsfhf2(x::Float32) = reinterpret(Float16, ccall("extern __truncsfhf2", llvmcall, Int16, (Float32,), x))
gnu_f2h_ieee(x::Float32) = reinterpret(Float16, ccall("extern __gnu_f2h_ieee", llvmcall, Int16, (Float32,), x))
truncdfhf2(x::Float64) = reinterpret(Float16, ccall("extern __truncdfhf2", llvmcall, Int16, (Float64,), x))
else
extendhfsf2(x::Float16) = ccall("extern __extendhfsf2", llvmcall, Float32, (Float16,), x)
gnu_h2f_ieee(x::Float16) = ccall("extern __gnu_h2f_ieee", llvmcall, Float32, (Float16,), x)
truncsfhf2(x::Float32) = ccall("extern __truncsfhf2", llvmcall, Float16, (Float32,), x)
gnu_f2h_ieee(x::Float32) = ccall("extern __gnu_f2h_ieee", llvmcall, Float16, (Float32,), x)
truncdfhf2(x::Float64) = ccall("extern __truncdfhf2", llvmcall, Float16, (Float64,), x)
end

@testset "Float16 intrinsics (crt)" begin
@test extendhfsf2(Float16(3.3)) == 3.3007812f0
@test gnu_h2f_ieee(Float16(3.3)) == 3.3007812f0
@test truncsfhf2(3.3f0) == Float16(3.3)
@test gnu_f2h_ieee(3.3f0) == Float16(3.3)
@test truncdfhf2(3.3) == Float16(3.3)
end

0 comments on commit 93915f8

Please sign in to comment.