Skip to content

Commit

Permalink
Implement Float16 runtime intrinsics.
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt committed Oct 14, 2020
1 parent adb50bc commit 2c77780
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 17 deletions.
19 changes: 13 additions & 6 deletions src/APInt-C.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "APInt-C.h"
#include "julia.h"
#include "julia_assert.h"
#include "julia_internal.h"

using namespace llvm;

Expand Down Expand Up @@ -312,14 +313,16 @@ void LLVMByteSwap(unsigned numbits, integerPart *pa, integerPart *pr) {
ASSIGN(r, a)
}

void LLVMFPtoInt(unsigned numbits, integerPart *pa, unsigned onumbits, integerPart *pr, bool isSigned, bool *isExact) {
void LLVMFPtoInt(unsigned numbits, void *pa, unsigned onumbits, integerPart *pr, bool isSigned, bool *isExact) {
double Val;
if (numbits == 32)
if (numbits == 16)
Val = __gnu_h2f_ieee(*(uint16_t*)pa);
else if (numbits == 32)
Val = *(float*)pa;
else if (numbits == 64)
Val = *(double*)pa;
else
jl_error("FPtoSI: runtime floating point intrinsics are not implemented for bit sizes other than 32 and 64");
jl_error("FPtoSI: runtime floating point intrinsics are not implemented for bit sizes other than 16, 32 and 64");
unsigned onumbytes = RoundUpToAlignment(onumbits, host_char_bit) / host_char_bit;
if (onumbits <= 64) { // fast-path, if possible
if (isSigned) {
Expand Down Expand Up @@ -387,12 +390,14 @@ void LLVMSItoFP(unsigned numbits, integerPart *pa, unsigned onumbits, integerPar
CREATE(a)
val = a.roundToDouble(true);
}
if (onumbits == 32)
if (onumbits == 16)
*(uint16_t*)pr = __gnu_f2h_ieee(val);
else if (onumbits == 32)
*(float*)pr = val;
else if (onumbits == 64)
*(double*)pr = val;
else
jl_error("SItoFP: runtime floating point intrinsics are not implemented for bit sizes other than 32 and 64");
jl_error("SItoFP: runtime floating point intrinsics are not implemented for bit sizes other than 16, 32 and 64");
}

extern "C" JL_DLLEXPORT
Expand All @@ -402,7 +407,9 @@ void LLVMUItoFP(unsigned numbits, integerPart *pa, unsigned onumbits, integerPar
CREATE(a)
val = a.roundToDouble(false);
}
if (onumbits == 32)
if (onumbits == 16)
*(uint16_t*)pr = __gnu_f2h_ieee(val);
else if (onumbits == 32)
*(float*)pr = val;
else if (onumbits == 64)
*(double*)pr = val;
Expand Down
3 changes: 3 additions & 0 deletions src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -1312,6 +1312,9 @@ jl_sym_t *_jl_symbol(const char *str, size_t len) JL_NOTSAFEPOINT;
#define JL_GC_ASSERT_LIVE(x) (void)(x)
#endif

// Conversion helpers between a 16-bit half-precision value and a C float.
// The uint16_t side is the raw 16-bit pattern loaded from / stored to a
// 16-bit float slot by the runtime intrinsics; the float side is the value
// those intrinsics compute with.
// NOTE(review): the definitions are not visible from this header hunk; the
// names suggest IEEE binary16 semantics -- confirm against the implementation.
float __gnu_h2f_ieee(uint16_t param) JL_NOTSAFEPOINT;
uint16_t __gnu_f2h_ieee(float param) JL_NOTSAFEPOINT;

#ifdef __cplusplus
}
#endif
Expand Down
88 changes: 79 additions & 9 deletions src/runtime_intrinsics.c
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,20 @@ static inline void name(unsigned osize, void *pa, void *pr) JL_NOTSAFEPOINT \
OP((c_type*)pr, a); \
}

// half-precision input, evaluated through a float intermediate
// OP::Function macro(output pointer, input)
// name::full name of the generated function
#define un_fintrinsic_half(OP, name) \
static inline void name(unsigned osize, void *pa, void *pr) JL_NOTSAFEPOINT \
{ \
    /* load the raw 16-bit pattern and widen it to float */ \
    uint16_t half_bits = *(uint16_t*)pa; \
    float in_f32 = __gnu_h2f_ieee(half_bits); \
    if (osize != 16) { \
        /* output is not 16 bits wide: OP is handed the output buffer itself */ \
        OP((uint16_t*)pr, in_f32); \
    } \
    else { \
        /* 16-bit output: OP fills a float temporary that is then narrowed */ \
        float out_f32; \
        OP(&out_f32, in_f32); \
        *(uint16_t*)pr = __gnu_f2h_ieee(out_f32); \
    } \
}

// float or integer inputs
// OP::Function macro(inputa, inputb)
// name::unique string
Expand All @@ -224,6 +238,18 @@ static void jl_##name##nbits(unsigned runtime_nbits, void *pa, void *pb, void *p
*(c_type*)pr = (c_type)OP(a, b); \
}

// two half-precision inputs, half-precision output; OP itself runs on floats
// OP::Function macro(inputa, inputb)
// name::unique string
#define bi_intrinsic_half(OP, name) \
static void jl_##name##16(unsigned runtime_nbits, void *pa, void *pb, void *pr) JL_NOTSAFEPOINT \
{ \
    /* the generated function always works on 16-bit operands */ \
    runtime_nbits = 16; \
    /* widen both operands from their 16-bit patterns */ \
    float lhs_f32 = __gnu_h2f_ieee(*(uint16_t*)pa); \
    float rhs_f32 = __gnu_h2f_ieee(*(uint16_t*)pb); \
    /* compute in float, then narrow the result back to 16 bits */ \
    float res_f32 = OP(lhs_f32, rhs_f32); \
    *(uint16_t*)pr = __gnu_f2h_ieee(res_f32); \
}

// float or integer inputs, bool output
// OP::Function macro(inputa, inputb)
// name::unique string
Expand All @@ -237,6 +263,18 @@ static int jl_##name##nbits(unsigned runtime_nbits, void *pa, void *pb) JL_NOTSA
return OP(a, b); \
}

// two half-precision inputs, int (boolean) result; OP itself runs on floats
// OP::Function macro(inputa, inputb)
// name::unique string
#define bool_intrinsic_half(OP, name) \
static int jl_##name##16(unsigned runtime_nbits, void *pa, void *pb) JL_NOTSAFEPOINT \
{ \
    /* the generated function always works on 16-bit operands */ \
    runtime_nbits = 16; \
    /* widen both operands from their 16-bit patterns */ \
    float lhs_f32 = __gnu_h2f_ieee(*(uint16_t*)pa); \
    float rhs_f32 = __gnu_h2f_ieee(*(uint16_t*)pb); \
    /* the predicate is evaluated on the widened values */ \
    return OP(lhs_f32, rhs_f32); \
}


// integer inputs, with precondition test
// OP::Function macro(inputa, inputb)
// name::unique string
Expand Down Expand Up @@ -265,6 +303,20 @@ static void jl_##name##nbits(unsigned runtime_nbits, void *pa, void *pb, void *p
*(c_type*)pr = (c_type)OP(a, b, c); \
}

// three half-precision inputs, half-precision output; OP itself runs on floats
// OP::Function macro(inputa, inputb, inputc)
// name::unique string
#define ter_intrinsic_half(OP, name) \
static void jl_##name##16(unsigned runtime_nbits, void *pa, void *pb, void *pc, void *pr) JL_NOTSAFEPOINT \
{ \
    /* the generated function always works on 16-bit operands */ \
    runtime_nbits = 16; \
    /* widen all three operands from their 16-bit patterns */ \
    float x_f32 = __gnu_h2f_ieee(*(uint16_t*)pa); \
    float y_f32 = __gnu_h2f_ieee(*(uint16_t*)pb); \
    float z_f32 = __gnu_h2f_ieee(*(uint16_t*)pc); \
    /* compute in float, then narrow the result back to 16 bits */ \
    float res_f32 = OP(x_f32, y_f32, z_f32); \
    *(uint16_t*)pr = __gnu_f2h_ieee(res_f32); \
}


// unary operator generator //

Expand Down Expand Up @@ -407,11 +459,12 @@ static inline jl_value_t *jl_intrinsic_cvt(jl_value_t *ty, jl_value_t *a, const
// floating point

#define un_fintrinsic_withtype(OP, name) \
un_fintrinsic_half(OP, jl_##name##16) \
un_fintrinsic_ctype(OP, jl_##name##32, float) \
un_fintrinsic_ctype(OP, jl_##name##64, double) \
JL_DLLEXPORT jl_value_t *jl_##name(jl_value_t *ty, jl_value_t *a) \
{ \
return jl_fintrinsic_1(ty, a, #name, jl_##name##32, jl_##name##64); \
return jl_fintrinsic_1(ty, a, #name, jl_##name##16, jl_##name##32, jl_##name##64); \
}

#define un_fintrinsic(OP, name) \
Expand All @@ -423,7 +476,7 @@ JL_DLLEXPORT jl_value_t *jl_##name(jl_value_t *a) \

typedef void (fintrinsic_op1)(unsigned, void*, void*);

static inline jl_value_t *jl_fintrinsic_1(jl_value_t *ty, jl_value_t *a, const char *name, fintrinsic_op1 *floatop, fintrinsic_op1 *doubleop)
static inline jl_value_t *jl_fintrinsic_1(jl_value_t *ty, jl_value_t *a, const char *name, fintrinsic_op1 *halfop, fintrinsic_op1 *floatop, fintrinsic_op1 *doubleop)
{
jl_ptls_t ptls = jl_get_ptls_states();
if (!jl_is_primitivetype(jl_typeof(a)))
Expand All @@ -436,14 +489,17 @@ static inline jl_value_t *jl_fintrinsic_1(jl_value_t *ty, jl_value_t *a, const c
unsigned sz = jl_datatype_size(jl_typeof(a));
switch (sz) {
/* choose the right size c-type operation based on the input */
case 2:
halfop(sz2 * host_char_bit, pa, pr);
break;
case 4:
floatop(sz2 * host_char_bit, pa, pr);
break;
case 8:
doubleop(sz2 * host_char_bit, pa, pr);
break;
default:
jl_errorf("%s: runtime floating point intrinsics are not implemented for bit sizes other than 32 and 64", name);
jl_errorf("%s: runtime floating point intrinsics are not implemented for bit sizes other than 16, 32 and 64", name);
}
return newv;
}
Expand Down Expand Up @@ -612,6 +668,7 @@ static inline jl_value_t *jl_intrinsiclambda_checkeddiv(jl_value_t *ty, void *pa
// floating point

#define bi_fintrinsic(OP, name) \
bi_intrinsic_half(OP, name) \
bi_intrinsic_ctype(OP, name, 32, float) \
bi_intrinsic_ctype(OP, name, 64, double) \
JL_DLLEXPORT jl_value_t *jl_##name(jl_value_t *a, jl_value_t *b) \
Expand All @@ -627,19 +684,23 @@ JL_DLLEXPORT jl_value_t *jl_##name(jl_value_t *a, jl_value_t *b) \
void *pa = jl_data_ptr(a), *pb = jl_data_ptr(b), *pr = jl_data_ptr(newv); \
switch (sz) { \
/* choose the right size c-type operation */ \
case 2: \
jl_##name##16(16, pa, pb, pr); \
break; \
case 4: \
jl_##name##32(32, pa, pb, pr); \
break; \
case 8: \
jl_##name##64(64, pa, pb, pr); \
break; \
default: \
jl_error(#name ": runtime floating point intrinsics are not implemented for bit sizes other than 32 and 64"); \
jl_error(#name ": runtime floating point intrinsics are not implemented for bit sizes other than 16, 32 and 64"); \
} \
return newv; \
}

#define bool_fintrinsic(OP, name) \
bool_intrinsic_half(OP, name) \
bool_intrinsic_ctype(OP, name, 32, float) \
bool_intrinsic_ctype(OP, name, 64, double) \
JL_DLLEXPORT jl_value_t *jl_##name(jl_value_t *a, jl_value_t *b) \
Expand All @@ -654,6 +715,9 @@ JL_DLLEXPORT jl_value_t *jl_##name(jl_value_t *a, jl_value_t *b) \
int cmp; \
switch (sz) { \
/* choose the right size c-type operation */ \
case 2: \
cmp = jl_##name##16(16, pa, pb); \
break; \
case 4: \
cmp = jl_##name##32(32, pa, pb); \
break; \
Expand All @@ -667,6 +731,7 @@ JL_DLLEXPORT jl_value_t *jl_##name(jl_value_t *a, jl_value_t *b) \
}

#define ter_fintrinsic(OP, name) \
ter_intrinsic_half(OP, name) \
ter_intrinsic_ctype(OP, name, 32, float) \
ter_intrinsic_ctype(OP, name, 64, double) \
JL_DLLEXPORT jl_value_t *jl_##name(jl_value_t *a, jl_value_t *b, jl_value_t *c) \
Expand All @@ -682,14 +747,17 @@ JL_DLLEXPORT jl_value_t *jl_##name(jl_value_t *a, jl_value_t *b, jl_value_t *c)
void *pa = jl_data_ptr(a), *pb = jl_data_ptr(b), *pc = jl_data_ptr(c), *pr = jl_data_ptr(newv); \
switch (sz) { \
/* choose the right size c-type operation */ \
case 2: \
jl_##name##16(16, pa, pb, pc, pr); \
break; \
case 4: \
jl_##name##32(32, pa, pb, pc, pr); \
break; \
case 8: \
jl_##name##64(64, pa, pb, pc, pr); \
break; \
default: \
jl_error(#name ": runtime floating point intrinsics are not implemented for bit sizes other than 32 and 64"); \
jl_error(#name ": runtime floating point intrinsics are not implemented for bit sizes other than 16, 32 and 64"); \
} \
return newv; \
}
Expand Down Expand Up @@ -834,15 +902,17 @@ cvt_iintrinsic(LLVMFPtoUI, fptoui)
#define fptrunc(pr, a) \
if (!(osize < 8 * sizeof(a))) \
jl_error("fptrunc: output bitsize must be < input bitsize"); \
if (osize == 32) \
else if (osize == 16) \
*(uint16_t*)pr = __gnu_f2h_ieee(a); \
else if (osize == 32) \
*(float*)pr = a; \
else if (osize == 64) \
*(double*)pr = a; \
else \
jl_error("fptrunc: runtime floating point intrinsics are not implemented for bit sizes other than 32 and 64");
jl_error("fptrunc: runtime floating point intrinsics are not implemented for bit sizes other than 16, 32 and 64");
#define fpext(pr, a) \
if (!(osize > 8 * sizeof(a))) \
jl_error("fpext: output bitsize must be > input bitsize"); \
if (!(osize >= 8 * sizeof(a))) \
jl_error("fpext: output bitsize must be >= input bitsize"); \
if (osize == 32) \
*(float*)pr = a; \
else if (osize == 64) \
Expand Down
50 changes: 48 additions & 2 deletions test/intrinsics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@ include("testenv.jl")
@testset "runtime intrinsics" begin
@test Core.Intrinsics.add_int(1, 1) == 2
@test Core.Intrinsics.sub_int(1, 1) == 0
@test_throws ErrorException("fpext: output bitsize must be > input bitsize") Core.Intrinsics.fpext(Int32, 0x0000_0000)
@test_throws ErrorException("fpext: output bitsize must be > input bitsize") Core.Intrinsics.fpext(Int32, 0x0000_0000_0000_0000)
@test_throws ErrorException("fpext: output bitsize must be >= input bitsize") Core.Intrinsics.fpext(Int32, 0x0000_0000_0000_0000)
@test_throws ErrorException("fptrunc: output bitsize must be < input bitsize") Core.Intrinsics.fptrunc(Int32, 0x0000_0000)
@test_throws ErrorException("fptrunc: output bitsize must be < input bitsize") Core.Intrinsics.fptrunc(Int64, 0x0000_0000)
@test_throws ErrorException("ZExt: output bitsize must be > input bitsize") Core.Intrinsics.zext_int(Int8, 0x00)
Expand Down Expand Up @@ -106,3 +105,50 @@ end
@test unsafe_load(Ptr{Nothing}(0)) === nothing
struct GhostStruct end
@test unsafe_load(Ptr{GhostStruct}(rand(Int))) === GhostStruct()

# Verify an intrinsic both when called from inside a freshly defined function (the
# compiled path) and when called through `Base.invokelatest` (the runtime version).
# Usage: `@test_intrinsic intr inputs... expected`; the last argument is the expected
# value. The two call paths must agree with `===`, and the function's result must
# compare `==` to the expected value.
macro test_intrinsic(intr, args...)
    expected = last(args)
    operands = Base.front(args)
    quote
        compiled() = $intr($(operands...))
        @test compiled() === Base.invokelatest($intr, $(operands...))
        @test compiled() == $expected
    end
end

# Exercises the floating-point intrinsics on Float16 operands. Every line goes through
# `@test_intrinsic`, so each case is checked both from inside a function and through
# `Base.invokelatest`; the last argument on each line is the expected result.
@testset "Float16 intrinsics" begin
# unary: negation, widening (fpext) to Float32/Float64, narrowing (fptrunc) to Float16
@test_intrinsic Core.Intrinsics.neg_float Float16(3.3) Float16(-3.3)
@test_intrinsic Core.Intrinsics.fpext Float32 Float16(3.3) 3.3007812f0
@test_intrinsic Core.Intrinsics.fpext Float64 Float16(3.3) 3.30078125
@test_intrinsic Core.Intrinsics.fptrunc Float16 Float32(3.3) Float16(3.3)
@test_intrinsic Core.Intrinsics.fptrunc Float16 Float64(3.3) Float16(3.3)

# binary arithmetic: both operands and the result are Float16
@test_intrinsic Core.Intrinsics.add_float Float16(3.3) Float16(2) Float16(5.3)
@test_intrinsic Core.Intrinsics.sub_float Float16(3.3) Float16(2) Float16(1.301)
@test_intrinsic Core.Intrinsics.mul_float Float16(3.3) Float16(2) Float16(6.6)
@test_intrinsic Core.Intrinsics.div_float Float16(3.3) Float16(2) Float16(1.65)
@test_intrinsic Core.Intrinsics.rem_float Float16(3.3) Float16(2) Float16(1.301)

# ternary: fma and muladd are expected to give the same Float16 result for these inputs
@test_intrinsic Core.Intrinsics.fma_float Float16(3.3) Float16(4.4) Float16(5.5) Float16(20.02)
@test_intrinsic Core.Intrinsics.muladd_float Float16(3.3) Float16(4.4) Float16(5.5) Float16(20.02)

# boolean: comparisons on Float16 operands yield a Bool; each predicate is checked
# once with equal operands and once with a larger left-hand operand
@test_intrinsic Core.Intrinsics.eq_float Float16(3.3) Float16(3.3) true
@test_intrinsic Core.Intrinsics.eq_float Float16(3.3) Float16(2) false
@test_intrinsic Core.Intrinsics.ne_float Float16(3.3) Float16(3.3) false
@test_intrinsic Core.Intrinsics.ne_float Float16(3.3) Float16(2) true
@test_intrinsic Core.Intrinsics.le_float Float16(3.3) Float16(3.3) true
@test_intrinsic Core.Intrinsics.le_float Float16(3.3) Float16(2) false

# conversions: integer -> Float16 (signed and unsigned) and Float16 -> integer;
# the Float16 -> integer cases expect the fractional part of 3.3 to be dropped
@test_intrinsic Core.Intrinsics.sitofp Float16 3 Float16(3f0)
@test_intrinsic Core.Intrinsics.uitofp Float16 UInt(3) Float16(3f0)
@test_intrinsic Core.Intrinsics.fptosi Int Float16(3.3) 3
@test_intrinsic Core.Intrinsics.fptoui UInt Float16(3.3) UInt(3)
end

0 comments on commit 2c77780

Please sign in to comment.