Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
…trunc @round

and expand @sqrt

This revealed that the accuracy of ln is not as good as the current algorithm in
musl and glibc, and should be ported again.

v2: actually include tests
v3: fix reversal of in and out arguments on f128M_sqrt()
    add test for @sqrt on comptime_float
    do not include @nearbyInt() until it works on all targets.
  • Loading branch information
shawnl committed Jun 22, 2019
1 parent ebde2ff commit 71e014c
Show file tree
Hide file tree
Showing 11 changed files with 719 additions and 131 deletions.
85 changes: 83 additions & 2 deletions doc/langref.html.in
Original file line number Diff line number Diff line change
Expand Up @@ -7354,10 +7354,91 @@ test "@setRuntimeSafety" {
<pre>{#syntax#}@sqrt(comptime T: type, value: T) T{#endsyntax#}</pre>
<p>
Performs the square root of a floating point number. Uses a dedicated hardware instruction
when available. Currently only supports f32 and f64 at runtime. f128 at runtime is TODO.
when available. Supports f16, f32, f64, and f128, as well as vectors.
</p>
{#header_close#}
{#header_open|@sin#}
<pre>{#syntax#}@sin(comptime T: type, value: T) T{#endsyntax#}</pre>
<p>
Sine trigometric function on a floating point number. Uses a dedicated hardware instruction
when available. Currently supports f32 and f64.
</p>
{#header_close#}
{#header_open|@cos#}
<pre>{#syntax#}@cos(comptime T: type, value: T) T{#endsyntax#}</pre>
<p>
Cosine trigometric function on a floating point number. Uses a dedicated hardware instruction
when available. Currently supports f32 and f64.
</p>
{#header_close#}
{#header_open|@exp#}
<pre>{#syntax#}@exp(comptime T: type, value: T) T{#endsyntax#}</pre>
<p>
Base-e exponential function on a floating point number. Uses a dedicated hardware instruction
when available. Currently supports f32 and f64.
</p>
{#header_close#}
{#header_open|@exp2#}
<pre>{#syntax#}@exp2(comptime T: type, value: T) T{#endsyntax#}</pre>
<p>
Base-2 exponential function on a floating point number. Uses a dedicated hardware instruction
when available. Currently supports f32 and f64.
</p>
{#header_close#}
{#header_open|@ln#}
<pre>{#syntax#}@ln(comptime T: type, value: T) T{#endsyntax#}</pre>
<p>
Returns the natural logarithm of a floating point number. Uses a dedicated hardware instruction
when available. Currently supports f32 and f64.
</p>
{#header_close#}
{#header_open|@log2#}
<pre>{#syntax#}@log2(comptime T: type, value: T) T{#endsyntax#}</pre>
<p>
Returns the logarithm to the base 2 of a floating point number. Uses a dedicated hardware instruction
when available. Currently supports f32 and f64.
</p>
{#header_close#}
{#header_open|@log10#}
<pre>{#syntax#}@log10(comptime T: type, value: T) T{#endsyntax#}</pre>
<p>
Returns the logarithm to the base 10 of a floating point number. Uses a dedicated hardware instruction
when available. Currently supports f32 and f64.
</p>
{#header_close#}
{#header_open|@fabs#}
<pre>{#syntax#}@fabs(comptime T: type, value: T) T{#endsyntax#}</pre>
<p>
Returns the absolute value of a floating point number. Uses a dedicated hardware instruction
when available. Currently supports f32 and f64.
</p>
{#header_close#}
{#header_open|@floor#}
<pre>{#syntax#}@floor(comptime T: type, value: T) T{#endsyntax#}</pre>
<p>
Returns the largest integral value not greater than the given floating point number. Uses a dedicated hardware instruction
when available. Currently supports f32 and f64.
</p>
{#header_close#}
{#header_open|@ceil#}
<pre>{#syntax#}@ceil(comptime T: type, value: T) T{#endsyntax#}</pre>
<p>
Returns the largest integral value not less than the given floating point number. Uses a dedicated hardware instruction
when available. Currently supports f32 and f64.
</p>
{#header_close#}
{#header_open|@trunc#}
<pre>{#syntax#}@trunc(comptime T: type, value: T) T{#endsyntax#}</pre>
<p>
Rounds the given floating point number to an integer, towards zero. Uses a dedicated hardware instruction
when available. Currently supports f32 and f64.
</p>
{#header_close#}
{#header_open|@round#}
<pre>{#syntax#}@round(comptime T: type, value: T) T{#endsyntax#}</pre>
<p>
This is a low-level intrinsic. Most code can use {#syntax#}std.math.sqrt{#endsyntax#} instead.
Rounds the given floating point number to an integer, away from zero. Uses a dedicated hardware instruction
when available. Currently supports f32 and f64.
</p>
{#header_close#}

Expand Down
26 changes: 20 additions & 6 deletions src/all_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1434,6 +1434,19 @@ enum BuiltinFnId {
BuiltinFnIdRem,
BuiltinFnIdMod,
BuiltinFnIdSqrt,
BuiltinFnIdSin,
BuiltinFnIdCos,
BuiltinFnIdExp,
BuiltinFnIdExp2,
BuiltinFnIdLn,
BuiltinFnIdLog2,
BuiltinFnIdLog10,
BuiltinFnIdFabs,
BuiltinFnIdFloor,
BuiltinFnIdCeil,
BuiltinFnIdTrunc,
BuiltinFnIdNearbyInt,
BuiltinFnIdRound,
BuiltinFnIdTruncate,
BuiltinFnIdIntCast,
BuiltinFnIdFloatCast,
Expand Down Expand Up @@ -1556,9 +1569,7 @@ enum ZigLLVMFnId {
ZigLLVMFnIdPopCount,
ZigLLVMFnIdOverflowArithmetic,
ZigLLVMFnIdFMA,
ZigLLVMFnIdFloor,
ZigLLVMFnIdCeil,
ZigLLVMFnIdSqrt,
ZigLLVMFnIdFloatOp,
ZigLLVMFnIdBswap,
ZigLLVMFnIdBitReverse,
};
Expand All @@ -1585,6 +1596,7 @@ struct ZigLLVMFnKey {
uint32_t bit_count;
} pop_count;
struct {
BuiltinFnId op;
uint32_t bit_count;
uint32_t vector_len; // 0 means not a vector
} floating;
Expand Down Expand Up @@ -2239,6 +2251,7 @@ enum IrInstructionId {
IrInstructionIdAlignOf,
IrInstructionIdOverflowOp,
IrInstructionIdMulAdd,
IrInstructionIdFloatOp,
IrInstructionIdTestErr,
IrInstructionIdUnwrapErrCode,
IrInstructionIdUnwrapErrPayload,
Expand Down Expand Up @@ -2300,7 +2313,6 @@ enum IrInstructionId {
IrInstructionIdAddImplicitReturnType,
IrInstructionIdMergeErrRetTraces,
IrInstructionIdMarkErrRetTracePtr,
IrInstructionIdSqrt,
IrInstructionIdErrSetCast,
IrInstructionIdToBytes,
IrInstructionIdFromBytes,
Expand Down Expand Up @@ -3474,11 +3486,13 @@ struct IrInstructionMarkErrRetTracePtr {
IrInstruction *err_ret_trace_ptr;
};

struct IrInstructionSqrt {
// For float ops which take a single argument
struct IrInstructionFloatOp {
IrInstruction base;

BuiltinFnId op;
IrInstruction *type;
IrInstruction *op;
IrInstruction *op1;
};

struct IrInstructionCheckRuntimeScope {
Expand Down
15 changes: 8 additions & 7 deletions src/analyze.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5736,9 +5736,10 @@ uint32_t zig_llvm_fn_key_hash(ZigLLVMFnKey x) {
return (uint32_t)(x.data.clz.bit_count) * (uint32_t)2428952817;
case ZigLLVMFnIdPopCount:
return (uint32_t)(x.data.clz.bit_count) * (uint32_t)101195049;
case ZigLLVMFnIdFloor:
case ZigLLVMFnIdCeil:
case ZigLLVMFnIdSqrt:
case ZigLLVMFnIdFloatOp:
return (uint32_t)(x.data.floating.bit_count) * ((uint32_t)x.id + 1025) +
(uint32_t)(x.data.floating.vector_len) * (((uint32_t)x.id << 5) + 1025) +
(uint32_t)(x.data.floating.op) * (uint32_t)43789879;
case ZigLLVMFnIdFMA:
return (uint32_t)(x.data.floating.bit_count) * ((uint32_t)x.id + 1025) +
(uint32_t)(x.data.floating.vector_len) * (((uint32_t)x.id << 5) + 1025);
Expand Down Expand Up @@ -5769,10 +5770,10 @@ bool zig_llvm_fn_key_eql(ZigLLVMFnKey a, ZigLLVMFnKey b) {
return a.data.bswap.bit_count == b.data.bswap.bit_count;
case ZigLLVMFnIdBitReverse:
return a.data.bit_reverse.bit_count == b.data.bit_reverse.bit_count;
case ZigLLVMFnIdFloor:
case ZigLLVMFnIdCeil:
case ZigLLVMFnIdSqrt:
return a.data.floating.bit_count == b.data.floating.bit_count;
case ZigLLVMFnIdFloatOp:
return a.data.floating.bit_count == b.data.floating.bit_count &&
a.data.floating.vector_len == b.data.floating.vector_len &&
a.data.floating.op == b.data.floating.op;
case ZigLLVMFnIdFMA:
return a.data.floating.bit_count == b.data.floating.bit_count &&
a.data.floating.vector_len == b.data.floating.vector_len;
Expand Down
68 changes: 36 additions & 32 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -806,7 +806,7 @@ static LLVMValueRef get_int_overflow_fn(CodeGen *g, ZigType *operand_type, AddSu
return fn_val;
}

static LLVMValueRef get_float_fn(CodeGen *g, ZigType *type_entry, ZigLLVMFnId fn_id) {
static LLVMValueRef get_float_fn(CodeGen *g, ZigType *type_entry, ZigLLVMFnId fn_id, BuiltinFnId op) {
assert(type_entry->id == ZigTypeIdFloat ||
type_entry->id == ZigTypeIdVector);

Expand All @@ -817,25 +817,20 @@ static LLVMValueRef get_float_fn(CodeGen *g, ZigType *type_entry, ZigLLVMFnId fn
key.id = fn_id;
key.data.floating.bit_count = (uint32_t)float_type->data.floating.bit_count;
key.data.floating.vector_len = is_vector ? (uint32_t)type_entry->data.vector.len : 0;
key.data.floating.op = op;

auto existing_entry = g->llvm_fn_table.maybe_get(key);
if (existing_entry)
return existing_entry->value;

const char *name;
uint32_t num_args;
if (fn_id == ZigLLVMFnIdFloor) {
name = "floor";
num_args = 1;
} else if (fn_id == ZigLLVMFnIdCeil) {
name = "ceil";
num_args = 1;
} else if (fn_id == ZigLLVMFnIdSqrt) {
name = "sqrt";
num_args = 1;
} else if (fn_id == ZigLLVMFnIdFMA) {
if (fn_id == ZigLLVMFnIdFMA) {
name = "fma";
num_args = 3;
} else if (fn_id == ZigLLVMFnIdFloatOp) {
name = float_op_to_name(op, true);
num_args = 1;
} else {
zig_unreachable();
}
Expand Down Expand Up @@ -2480,22 +2475,17 @@ static LLVMValueRef gen_overflow_shr_op(CodeGen *g, ZigType *type_entry,
return result;
}

static LLVMValueRef gen_floor(CodeGen *g, LLVMValueRef val, ZigType *type_entry) {
if (type_entry->id == ZigTypeIdInt)
static LLVMValueRef gen_float_op(CodeGen *g, LLVMValueRef val, ZigType *type_entry, BuiltinFnId op) {
if ((op == BuiltinFnIdCeil ||
op == BuiltinFnIdFloor) &&
type_entry->id == ZigTypeIdInt)
return val;
assert(type_entry->id == ZigTypeIdFloat);

LLVMValueRef floor_fn = get_float_fn(g, type_entry, ZigLLVMFnIdFloor);
LLVMValueRef floor_fn = get_float_fn(g, type_entry, ZigLLVMFnIdFloatOp, op);
return LLVMBuildCall(g->builder, floor_fn, &val, 1, "");
}

static LLVMValueRef gen_ceil(CodeGen *g, LLVMValueRef val, ZigType *type_entry) {
if (type_entry->id == ZigTypeIdInt)
return val;

LLVMValueRef ceil_fn = get_float_fn(g, type_entry, ZigLLVMFnIdCeil);
return LLVMBuildCall(g->builder, ceil_fn, &val, 1, "");
}

enum DivKind {
DivKindFloat,
DivKindTrunc,
Expand Down Expand Up @@ -2571,7 +2561,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
return result;
case DivKindExact:
if (want_runtime_safety) {
LLVMValueRef floored = gen_floor(g, result, type_entry);
LLVMValueRef floored = gen_float_op(g, result, type_entry, BuiltinFnIdFloor);
LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivExactOk");
LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivExactFail");
LLVMValueRef ok_bit = LLVMBuildFCmp(g->builder, LLVMRealOEQ, floored, result, "");
Expand All @@ -2593,12 +2583,12 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
LLVMBuildCondBr(g->builder, ltz, ltz_block, gez_block);

LLVMPositionBuilderAtEnd(g->builder, ltz_block);
LLVMValueRef ceiled = gen_ceil(g, result, type_entry);
LLVMValueRef ceiled = gen_float_op(g, result, type_entry, BuiltinFnIdCeil);
LLVMBasicBlockRef ceiled_end_block = LLVMGetInsertBlock(g->builder);
LLVMBuildBr(g->builder, end_block);

LLVMPositionBuilderAtEnd(g->builder, gez_block);
LLVMValueRef floored = gen_floor(g, result, type_entry);
LLVMValueRef floored = gen_float_op(g, result, type_entry, BuiltinFnIdFloor);
LLVMBasicBlockRef floored_end_block = LLVMGetInsertBlock(g->builder);
LLVMBuildBr(g->builder, end_block);

Expand All @@ -2610,7 +2600,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
return phi;
}
case DivKindFloor:
return gen_floor(g, result, type_entry);
return gen_float_op(g, result, type_entry, BuiltinFnIdFloor);
}
zig_unreachable();
}
Expand Down Expand Up @@ -5450,10 +5440,10 @@ static LLVMValueRef ir_render_mark_err_ret_trace_ptr(CodeGen *g, IrExecutable *e
return nullptr;
}

static LLVMValueRef ir_render_sqrt(CodeGen *g, IrExecutable *executable, IrInstructionSqrt *instruction) {
LLVMValueRef op = ir_llvm_value(g, instruction->op);
static LLVMValueRef ir_render_float_op(CodeGen *g, IrExecutable *executable, IrInstructionFloatOp *instruction) {
LLVMValueRef op = ir_llvm_value(g, instruction->op1);
assert(instruction->base.value.type->id == ZigTypeIdFloat);
LLVMValueRef fn_val = get_float_fn(g, instruction->base.value.type, ZigLLVMFnIdSqrt);
LLVMValueRef fn_val = get_float_fn(g, instruction->base.value.type, ZigLLVMFnIdFloatOp, instruction->op);
return LLVMBuildCall(g->builder, fn_val, &op, 1, "");
}

Expand All @@ -5463,7 +5453,7 @@ static LLVMValueRef ir_render_mul_add(CodeGen *g, IrExecutable *executable, IrIn
LLVMValueRef op3 = ir_llvm_value(g, instruction->op3);
assert(instruction->base.value.type->id == ZigTypeIdFloat ||
instruction->base.value.type->id == ZigTypeIdVector);
LLVMValueRef fn_val = get_float_fn(g, instruction->base.value.type, ZigLLVMFnIdFMA);
LLVMValueRef fn_val = get_float_fn(g, instruction->base.value.type, ZigLLVMFnIdFMA, BuiltinFnIdMulAdd);
LLVMValueRef args[3] = {
op1,
op2,
Expand Down Expand Up @@ -5814,8 +5804,8 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable,
return ir_render_merge_err_ret_traces(g, executable, (IrInstructionMergeErrRetTraces *)instruction);
case IrInstructionIdMarkErrRetTracePtr:
return ir_render_mark_err_ret_trace_ptr(g, executable, (IrInstructionMarkErrRetTracePtr *)instruction);
case IrInstructionIdSqrt:
return ir_render_sqrt(g, executable, (IrInstructionSqrt *)instruction);
case IrInstructionIdFloatOp:
return ir_render_float_op(g, executable, (IrInstructionFloatOp *)instruction);
case IrInstructionIdMulAdd:
return ir_render_mul_add(g, executable, (IrInstructionMulAdd *)instruction);
case IrInstructionIdArrayToVector:
Expand Down Expand Up @@ -7435,6 +7425,20 @@ static void define_builtin_fns(CodeGen *g) {
create_builtin_fn(g, BuiltinFnIdRem, "rem", 2);
create_builtin_fn(g, BuiltinFnIdMod, "mod", 2);
create_builtin_fn(g, BuiltinFnIdSqrt, "sqrt", 2);
create_builtin_fn(g, BuiltinFnIdSin, "sin", 2);
create_builtin_fn(g, BuiltinFnIdCos, "cos", 2);
create_builtin_fn(g, BuiltinFnIdExp, "exp", 2);
create_builtin_fn(g, BuiltinFnIdExp2, "exp2", 2);
create_builtin_fn(g, BuiltinFnIdLn, "ln", 2);
create_builtin_fn(g, BuiltinFnIdLog2, "log2", 2);
create_builtin_fn(g, BuiltinFnIdLog10, "log10", 2);
create_builtin_fn(g, BuiltinFnIdFabs, "fabs", 2);
create_builtin_fn(g, BuiltinFnIdFloor, "floor", 2);
create_builtin_fn(g, BuiltinFnIdCeil, "ceil", 2);
create_builtin_fn(g, BuiltinFnIdTrunc, "trunc", 2);
//Needs library support on Windows
//create_builtin_fn(g, BuiltinFnIdNearbyInt, "nearbyInt", 2);
create_builtin_fn(g, BuiltinFnIdRound, "round", 2);
create_builtin_fn(g, BuiltinFnIdMulAdd, "mulAdd", 4);
create_builtin_fn(g, BuiltinFnIdInlineCall, "inlineCall", SIZE_MAX);
create_builtin_fn(g, BuiltinFnIdNoInlineCall, "noInlineCall", SIZE_MAX);
Expand Down
Loading

0 comments on commit 71e014c

Please sign in to comment.