From 71e014caecaa54fdd8a0516710d2d9597da41398 Mon Sep 17 00:00:00 2001 From: Shawn Landden Date: Fri, 21 Jun 2019 16:18:59 -0500 Subject: [PATCH] stage1: add @sin @cos @exp @exp2 @ln @log2 @log10 @fabs @floor @ceil @trunc @round and expand @sqrt This revealed that the accuracy of ln is not as good as the current algorithm in musl and glibc, and should be ported again. v2: actually include tests v3: fix reversal of in and out arguments on f128M_sqrt() add test for @sqrt on comptime_float do not include @nearbyInt() until it works on all targets. --- doc/langref.html.in | 85 +++++++- src/all_types.hpp | 26 ++- src/analyze.cpp | 15 +- src/codegen.cpp | 68 +++--- src/ir.cpp | 355 +++++++++++++++++++++++++------ src/ir.hpp | 1 + src/ir_print.cpp | 11 +- src/util.cpp | 1 + std/special/c.zig | 44 ++-- test/stage1/behavior.zig | 1 + test/stage1/behavior/floatop.zig | 243 +++++++++++++++++++++ 11 files changed, 719 insertions(+), 131 deletions(-) create mode 100644 test/stage1/behavior/floatop.zig diff --git a/doc/langref.html.in b/doc/langref.html.in index 9b95946256da..30fe9a36485c 100644 --- a/doc/langref.html.in +++ b/doc/langref.html.in @@ -7354,10 +7354,91 @@ test "@setRuntimeSafety" {
{#syntax#}@sqrt(comptime T: type, value: T) T{#endsyntax#}

Performs the square root of a floating point number. Uses a dedicated hardware instruction - when available. Currently only supports f32 and f64 at runtime. f128 at runtime is TODO. + when available. Supports f16, f32, f64, and f128, as well as vectors.

+ {#header_close#} + {#header_open|@sin#} +
{#syntax#}@sin(comptime T: type, value: T) T{#endsyntax#}
+

+ Sine trigonometric function on a floating point number. Uses a dedicated hardware instruction + when available. Currently supports f32 and f64. +

+ {#header_close#} + {#header_open|@cos#} +
{#syntax#}@cos(comptime T: type, value: T) T{#endsyntax#}
+

+ Cosine trigonometric function on a floating point number. Uses a dedicated hardware instruction + when available. Currently supports f32 and f64. +

+ {#header_close#} + {#header_open|@exp#} +
{#syntax#}@exp(comptime T: type, value: T) T{#endsyntax#}
+

+ Base-e exponential function on a floating point number. Uses a dedicated hardware instruction + when available. Currently supports f32 and f64. +

+ {#header_close#} + {#header_open|@exp2#} +
{#syntax#}@exp2(comptime T: type, value: T) T{#endsyntax#}
+

+ Base-2 exponential function on a floating point number. Uses a dedicated hardware instruction + when available. Currently supports f32 and f64. +

+ {#header_close#} + {#header_open|@ln#} +
{#syntax#}@ln(comptime T: type, value: T) T{#endsyntax#}
+

+ Returns the natural logarithm of a floating point number. Uses a dedicated hardware instruction + when available. Currently supports f32 and f64. +

+ {#header_close#} + {#header_open|@log2#} +
{#syntax#}@log2(comptime T: type, value: T) T{#endsyntax#}
+

+ Returns the logarithm to the base 2 of a floating point number. Uses a dedicated hardware instruction + when available. Currently supports f32 and f64. +

+ {#header_close#} + {#header_open|@log10#} +
{#syntax#}@log10(comptime T: type, value: T) T{#endsyntax#}
+

+ Returns the logarithm to the base 10 of a floating point number. Uses a dedicated hardware instruction + when available. Currently supports f32 and f64. +

+ {#header_close#} + {#header_open|@fabs#} +
{#syntax#}@fabs(comptime T: type, value: T) T{#endsyntax#}
+

+ Returns the absolute value of a floating point number. Uses a dedicated hardware instruction + when available. Currently supports f32 and f64. +

+ {#header_close#} + {#header_open|@floor#} +
{#syntax#}@floor(comptime T: type, value: T) T{#endsyntax#}
+

+ Returns the largest integral value not greater than the given floating point number. Uses a dedicated hardware instruction + when available. Currently supports f32 and f64. +

+ {#header_close#} + {#header_open|@ceil#} +
{#syntax#}@ceil(comptime T: type, value: T) T{#endsyntax#}
+

+ Returns the smallest integral value not less than the given floating point number. Uses a dedicated hardware instruction + when available. Currently supports f32 and f64. +

+ {#header_close#} + {#header_open|@trunc#} +
{#syntax#}@trunc(comptime T: type, value: T) T{#endsyntax#}
+

+ Rounds the given floating point number to an integer, towards zero. Uses a dedicated hardware instruction + when available. Currently supports f32 and f64. +

+ {#header_close#} + {#header_open|@round#} +
{#syntax#}@round(comptime T: type, value: T) T{#endsyntax#}

- This is a low-level intrinsic. Most code can use {#syntax#}std.math.sqrt{#endsyntax#} instead. + Rounds the given floating point number to the nearest integer, with halfway cases rounded away from zero. Uses a dedicated hardware instruction + when available. Currently supports f32 and f64.

{#header_close#} diff --git a/src/all_types.hpp b/src/all_types.hpp index 83df71b95f48..6595218bcf4e 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -1434,6 +1434,19 @@ enum BuiltinFnId { BuiltinFnIdRem, BuiltinFnIdMod, BuiltinFnIdSqrt, + BuiltinFnIdSin, + BuiltinFnIdCos, + BuiltinFnIdExp, + BuiltinFnIdExp2, + BuiltinFnIdLn, + BuiltinFnIdLog2, + BuiltinFnIdLog10, + BuiltinFnIdFabs, + BuiltinFnIdFloor, + BuiltinFnIdCeil, + BuiltinFnIdTrunc, + BuiltinFnIdNearbyInt, + BuiltinFnIdRound, BuiltinFnIdTruncate, BuiltinFnIdIntCast, BuiltinFnIdFloatCast, @@ -1556,9 +1569,7 @@ enum ZigLLVMFnId { ZigLLVMFnIdPopCount, ZigLLVMFnIdOverflowArithmetic, ZigLLVMFnIdFMA, - ZigLLVMFnIdFloor, - ZigLLVMFnIdCeil, - ZigLLVMFnIdSqrt, + ZigLLVMFnIdFloatOp, ZigLLVMFnIdBswap, ZigLLVMFnIdBitReverse, }; @@ -1585,6 +1596,7 @@ struct ZigLLVMFnKey { uint32_t bit_count; } pop_count; struct { + BuiltinFnId op; uint32_t bit_count; uint32_t vector_len; // 0 means not a vector } floating; @@ -2239,6 +2251,7 @@ enum IrInstructionId { IrInstructionIdAlignOf, IrInstructionIdOverflowOp, IrInstructionIdMulAdd, + IrInstructionIdFloatOp, IrInstructionIdTestErr, IrInstructionIdUnwrapErrCode, IrInstructionIdUnwrapErrPayload, @@ -2300,7 +2313,6 @@ enum IrInstructionId { IrInstructionIdAddImplicitReturnType, IrInstructionIdMergeErrRetTraces, IrInstructionIdMarkErrRetTracePtr, - IrInstructionIdSqrt, IrInstructionIdErrSetCast, IrInstructionIdToBytes, IrInstructionIdFromBytes, @@ -3474,11 +3486,13 @@ struct IrInstructionMarkErrRetTracePtr { IrInstruction *err_ret_trace_ptr; }; -struct IrInstructionSqrt { +// For float ops which take a single argument +struct IrInstructionFloatOp { IrInstruction base; + BuiltinFnId op; IrInstruction *type; - IrInstruction *op; + IrInstruction *op1; }; struct IrInstructionCheckRuntimeScope { diff --git a/src/analyze.cpp b/src/analyze.cpp index 15b42c7f9dad..13b35e0aff43 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -5736,9 +5736,10 @@ uint32_t 
zig_llvm_fn_key_hash(ZigLLVMFnKey x) { return (uint32_t)(x.data.clz.bit_count) * (uint32_t)2428952817; case ZigLLVMFnIdPopCount: return (uint32_t)(x.data.clz.bit_count) * (uint32_t)101195049; - case ZigLLVMFnIdFloor: - case ZigLLVMFnIdCeil: - case ZigLLVMFnIdSqrt: + case ZigLLVMFnIdFloatOp: + return (uint32_t)(x.data.floating.bit_count) * ((uint32_t)x.id + 1025) + + (uint32_t)(x.data.floating.vector_len) * (((uint32_t)x.id << 5) + 1025) + + (uint32_t)(x.data.floating.op) * (uint32_t)43789879; case ZigLLVMFnIdFMA: return (uint32_t)(x.data.floating.bit_count) * ((uint32_t)x.id + 1025) + (uint32_t)(x.data.floating.vector_len) * (((uint32_t)x.id << 5) + 1025); @@ -5769,10 +5770,10 @@ bool zig_llvm_fn_key_eql(ZigLLVMFnKey a, ZigLLVMFnKey b) { return a.data.bswap.bit_count == b.data.bswap.bit_count; case ZigLLVMFnIdBitReverse: return a.data.bit_reverse.bit_count == b.data.bit_reverse.bit_count; - case ZigLLVMFnIdFloor: - case ZigLLVMFnIdCeil: - case ZigLLVMFnIdSqrt: - return a.data.floating.bit_count == b.data.floating.bit_count; + case ZigLLVMFnIdFloatOp: + return a.data.floating.bit_count == b.data.floating.bit_count && + a.data.floating.vector_len == b.data.floating.vector_len && + a.data.floating.op == b.data.floating.op; case ZigLLVMFnIdFMA: return a.data.floating.bit_count == b.data.floating.bit_count && a.data.floating.vector_len == b.data.floating.vector_len; diff --git a/src/codegen.cpp b/src/codegen.cpp index 6691652a5e04..41caa29dbd55 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -806,7 +806,7 @@ static LLVMValueRef get_int_overflow_fn(CodeGen *g, ZigType *operand_type, AddSu return fn_val; } -static LLVMValueRef get_float_fn(CodeGen *g, ZigType *type_entry, ZigLLVMFnId fn_id) { +static LLVMValueRef get_float_fn(CodeGen *g, ZigType *type_entry, ZigLLVMFnId fn_id, BuiltinFnId op) { assert(type_entry->id == ZigTypeIdFloat || type_entry->id == ZigTypeIdVector); @@ -817,6 +817,7 @@ static LLVMValueRef get_float_fn(CodeGen *g, ZigType *type_entry, 
ZigLLVMFnId fn key.id = fn_id; key.data.floating.bit_count = (uint32_t)float_type->data.floating.bit_count; key.data.floating.vector_len = is_vector ? (uint32_t)type_entry->data.vector.len : 0; + key.data.floating.op = op; auto existing_entry = g->llvm_fn_table.maybe_get(key); if (existing_entry) @@ -824,18 +825,12 @@ static LLVMValueRef get_float_fn(CodeGen *g, ZigType *type_entry, ZigLLVMFnId fn const char *name; uint32_t num_args; - if (fn_id == ZigLLVMFnIdFloor) { - name = "floor"; - num_args = 1; - } else if (fn_id == ZigLLVMFnIdCeil) { - name = "ceil"; - num_args = 1; - } else if (fn_id == ZigLLVMFnIdSqrt) { - name = "sqrt"; - num_args = 1; - } else if (fn_id == ZigLLVMFnIdFMA) { + if (fn_id == ZigLLVMFnIdFMA) { name = "fma"; num_args = 3; + } else if (fn_id == ZigLLVMFnIdFloatOp) { + name = float_op_to_name(op, true); + num_args = 1; } else { zig_unreachable(); } @@ -2480,22 +2475,17 @@ static LLVMValueRef gen_overflow_shr_op(CodeGen *g, ZigType *type_entry, return result; } -static LLVMValueRef gen_floor(CodeGen *g, LLVMValueRef val, ZigType *type_entry) { - if (type_entry->id == ZigTypeIdInt) +static LLVMValueRef gen_float_op(CodeGen *g, LLVMValueRef val, ZigType *type_entry, BuiltinFnId op) { + if ((op == BuiltinFnIdCeil || + op == BuiltinFnIdFloor) && + type_entry->id == ZigTypeIdInt) return val; + assert(type_entry->id == ZigTypeIdFloat); - LLVMValueRef floor_fn = get_float_fn(g, type_entry, ZigLLVMFnIdFloor); + LLVMValueRef floor_fn = get_float_fn(g, type_entry, ZigLLVMFnIdFloatOp, op); return LLVMBuildCall(g->builder, floor_fn, &val, 1, ""); } -static LLVMValueRef gen_ceil(CodeGen *g, LLVMValueRef val, ZigType *type_entry) { - if (type_entry->id == ZigTypeIdInt) - return val; - - LLVMValueRef ceil_fn = get_float_fn(g, type_entry, ZigLLVMFnIdCeil); - return LLVMBuildCall(g->builder, ceil_fn, &val, 1, ""); -} - enum DivKind { DivKindFloat, DivKindTrunc, @@ -2571,7 +2561,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool 
want_fast return result; case DivKindExact: if (want_runtime_safety) { - LLVMValueRef floored = gen_floor(g, result, type_entry); + LLVMValueRef floored = gen_float_op(g, result, type_entry, BuiltinFnIdFloor); LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivExactOk"); LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivExactFail"); LLVMValueRef ok_bit = LLVMBuildFCmp(g->builder, LLVMRealOEQ, floored, result, ""); @@ -2593,12 +2583,12 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast LLVMBuildCondBr(g->builder, ltz, ltz_block, gez_block); LLVMPositionBuilderAtEnd(g->builder, ltz_block); - LLVMValueRef ceiled = gen_ceil(g, result, type_entry); + LLVMValueRef ceiled = gen_float_op(g, result, type_entry, BuiltinFnIdCeil); LLVMBasicBlockRef ceiled_end_block = LLVMGetInsertBlock(g->builder); LLVMBuildBr(g->builder, end_block); LLVMPositionBuilderAtEnd(g->builder, gez_block); - LLVMValueRef floored = gen_floor(g, result, type_entry); + LLVMValueRef floored = gen_float_op(g, result, type_entry, BuiltinFnIdFloor); LLVMBasicBlockRef floored_end_block = LLVMGetInsertBlock(g->builder); LLVMBuildBr(g->builder, end_block); @@ -2610,7 +2600,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast return phi; } case DivKindFloor: - return gen_floor(g, result, type_entry); + return gen_float_op(g, result, type_entry, BuiltinFnIdFloor); } zig_unreachable(); } @@ -5450,10 +5440,10 @@ static LLVMValueRef ir_render_mark_err_ret_trace_ptr(CodeGen *g, IrExecutable *e return nullptr; } -static LLVMValueRef ir_render_sqrt(CodeGen *g, IrExecutable *executable, IrInstructionSqrt *instruction) { - LLVMValueRef op = ir_llvm_value(g, instruction->op); +static LLVMValueRef ir_render_float_op(CodeGen *g, IrExecutable *executable, IrInstructionFloatOp *instruction) { + LLVMValueRef op = ir_llvm_value(g, instruction->op1); assert(instruction->base.value.type->id == ZigTypeIdFloat); - LLVMValueRef 
fn_val = get_float_fn(g, instruction->base.value.type, ZigLLVMFnIdSqrt); + LLVMValueRef fn_val = get_float_fn(g, instruction->base.value.type, ZigLLVMFnIdFloatOp, instruction->op); return LLVMBuildCall(g->builder, fn_val, &op, 1, ""); } @@ -5463,7 +5453,7 @@ static LLVMValueRef ir_render_mul_add(CodeGen *g, IrExecutable *executable, IrIn LLVMValueRef op3 = ir_llvm_value(g, instruction->op3); assert(instruction->base.value.type->id == ZigTypeIdFloat || instruction->base.value.type->id == ZigTypeIdVector); - LLVMValueRef fn_val = get_float_fn(g, instruction->base.value.type, ZigLLVMFnIdFMA); + LLVMValueRef fn_val = get_float_fn(g, instruction->base.value.type, ZigLLVMFnIdFMA, BuiltinFnIdMulAdd); LLVMValueRef args[3] = { op1, op2, @@ -5814,8 +5804,8 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable, return ir_render_merge_err_ret_traces(g, executable, (IrInstructionMergeErrRetTraces *)instruction); case IrInstructionIdMarkErrRetTracePtr: return ir_render_mark_err_ret_trace_ptr(g, executable, (IrInstructionMarkErrRetTracePtr *)instruction); - case IrInstructionIdSqrt: - return ir_render_sqrt(g, executable, (IrInstructionSqrt *)instruction); + case IrInstructionIdFloatOp: + return ir_render_float_op(g, executable, (IrInstructionFloatOp *)instruction); case IrInstructionIdMulAdd: return ir_render_mul_add(g, executable, (IrInstructionMulAdd *)instruction); case IrInstructionIdArrayToVector: @@ -7435,6 +7425,20 @@ static void define_builtin_fns(CodeGen *g) { create_builtin_fn(g, BuiltinFnIdRem, "rem", 2); create_builtin_fn(g, BuiltinFnIdMod, "mod", 2); create_builtin_fn(g, BuiltinFnIdSqrt, "sqrt", 2); + create_builtin_fn(g, BuiltinFnIdSin, "sin", 2); + create_builtin_fn(g, BuiltinFnIdCos, "cos", 2); + create_builtin_fn(g, BuiltinFnIdExp, "exp", 2); + create_builtin_fn(g, BuiltinFnIdExp2, "exp2", 2); + create_builtin_fn(g, BuiltinFnIdLn, "ln", 2); + create_builtin_fn(g, BuiltinFnIdLog2, "log2", 2); + create_builtin_fn(g, BuiltinFnIdLog10, 
"log10", 2); + create_builtin_fn(g, BuiltinFnIdFabs, "fabs", 2); + create_builtin_fn(g, BuiltinFnIdFloor, "floor", 2); + create_builtin_fn(g, BuiltinFnIdCeil, "ceil", 2); + create_builtin_fn(g, BuiltinFnIdTrunc, "trunc", 2); + //Needs library support on Windows + //create_builtin_fn(g, BuiltinFnIdNearbyInt, "nearbyInt", 2); + create_builtin_fn(g, BuiltinFnIdRound, "round", 2); create_builtin_fn(g, BuiltinFnIdMulAdd, "mulAdd", 4); create_builtin_fn(g, BuiltinFnIdInlineCall, "inlineCall", SIZE_MAX); create_builtin_fn(g, BuiltinFnIdNoInlineCall, "noInlineCall", SIZE_MAX); diff --git a/src/ir.cpp b/src/ir.cpp index c2c6cb615416..50d2a0686809 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -991,8 +991,8 @@ static constexpr IrInstructionId ir_instruction_id(IrInstructionMarkErrRetTraceP return IrInstructionIdMarkErrRetTracePtr; } -static constexpr IrInstructionId ir_instruction_id(IrInstructionSqrt *) { - return IrInstructionIdSqrt; +static constexpr IrInstructionId ir_instruction_id(IrInstructionFloatOp *) { + return IrInstructionIdFloatOp; } static constexpr IrInstructionId ir_instruction_id(IrInstructionCheckRuntimeScope *) { @@ -2312,6 +2312,59 @@ static IrInstruction *ir_build_overflow_op(IrBuilder *irb, Scope *scope, AstNode return &instruction->base; } + +//TODO Powi, Pow, minnum, maxnum, maximum, minimum, copysign, +// lround, llround, lrint, llrint +// So far this is only non-complicated type functions. +const char *float_op_to_name(BuiltinFnId op, bool llvm_name) { + const bool b = llvm_name; + + switch (op) { + case BuiltinFnIdSqrt: + return "sqrt"; + case BuiltinFnIdSin: + return "sin"; + case BuiltinFnIdCos: + return "cos"; + case BuiltinFnIdExp: + return "exp"; + case BuiltinFnIdExp2: + return "exp2"; + case BuiltinFnIdLn: + return b ? 
"log" : "ln"; + case BuiltinFnIdLog10: + return "log10"; + case BuiltinFnIdLog2: + return "log2"; + case BuiltinFnIdFabs: + return "fabs"; + case BuiltinFnIdFloor: + return "floor"; + case BuiltinFnIdCeil: + return "ceil"; + case BuiltinFnIdTrunc: + return "trunc"; + case BuiltinFnIdNearbyInt: + return b ? "nearbyint" : "nearbyInt"; + case BuiltinFnIdRound: + return "round"; + default: + zig_unreachable(); + } +} + +static IrInstruction *ir_build_float_op(IrBuilder *irb, Scope *scope, AstNode *source_node, IrInstruction *type, IrInstruction *op1, BuiltinFnId op) { + IrInstructionFloatOp *instruction = ir_build_instruction(irb, scope, source_node); + instruction->type = type; + instruction->op1 = op1; + instruction->op = op; + + if (type != nullptr) ir_ref_instruction(type, irb->current_basic_block); + ir_ref_instruction(op1, irb->current_basic_block); + + return &instruction->base; +} + static IrInstruction *ir_build_mul_add(IrBuilder *irb, Scope *scope, AstNode *source_node, IrInstruction *type_value, IrInstruction *op1, IrInstruction *op2, IrInstruction *op3) { IrInstructionMulAdd *instruction = ir_build_instruction(irb, scope, source_node); @@ -3033,17 +3086,6 @@ static IrInstruction *ir_build_mark_err_ret_trace_ptr(IrBuilder *irb, Scope *sco return &instruction->base; } -static IrInstruction *ir_build_sqrt(IrBuilder *irb, Scope *scope, AstNode *source_node, IrInstruction *type, IrInstruction *op) { - IrInstructionSqrt *instruction = ir_build_instruction(irb, scope, source_node); - instruction->type = type; - instruction->op = op; - - if (type != nullptr) ir_ref_instruction(type, irb->current_basic_block); - ir_ref_instruction(op, irb->current_basic_block); - - return &instruction->base; -} - static IrInstruction *ir_build_has_decl(IrBuilder *irb, Scope *scope, AstNode *source_node, IrInstruction *container, IrInstruction *name) { @@ -4400,6 +4442,19 @@ static IrInstruction *ir_gen_builtin_fn_call(IrBuilder *irb, Scope *scope, AstNo return ir_lval_wrap(irb, 
scope, bin_op, lval); } case BuiltinFnIdSqrt: + case BuiltinFnIdSin: + case BuiltinFnIdCos: + case BuiltinFnIdExp: + case BuiltinFnIdExp2: + case BuiltinFnIdLn: + case BuiltinFnIdLog2: + case BuiltinFnIdLog10: + case BuiltinFnIdFabs: + case BuiltinFnIdFloor: + case BuiltinFnIdCeil: + case BuiltinFnIdTrunc: + case BuiltinFnIdNearbyInt: + case BuiltinFnIdRound: { AstNode *arg0_node = node->data.fn_call_expr.params.at(0); IrInstruction *arg0_value = ir_gen_node(irb, arg0_node, scope); @@ -4411,7 +4466,7 @@ static IrInstruction *ir_gen_builtin_fn_call(IrBuilder *irb, Scope *scope, AstNo if (arg1_value == irb->codegen->invalid_instruction) return arg1_value; - IrInstruction *ir_sqrt = ir_build_sqrt(irb, scope, node, arg0_value, arg1_value); + IrInstruction *ir_sqrt = ir_build_float_op(irb, scope, node, arg0_value, arg1_value, builtin_fn->id); return ir_lval_wrap(irb, scope, ir_sqrt, lval); } case BuiltinFnIdTruncate: @@ -23214,70 +23269,248 @@ static IrInstruction *ir_analyze_instruction_mark_err_ret_trace_ptr(IrAnalyze *i return result; } -static IrInstruction *ir_analyze_instruction_sqrt(IrAnalyze *ira, IrInstructionSqrt *instruction) { - ZigType *float_type = ir_resolve_type(ira, instruction->type->child); - if (type_is_invalid(float_type)) - return ira->codegen->invalid_instruction; +static void ir_eval_float_op(IrAnalyze *ira, IrInstructionFloatOp *source_instr, ZigType *float_type, + ConstExprValue *op, ConstExprValue *out_val) { + assert(ira && source_instr && float_type && out_val && op); + assert(float_type->id == ZigTypeIdFloat || + float_type->id == ZigTypeIdComptimeFloat); - IrInstruction *op = instruction->op->child; - if (type_is_invalid(op->value.type)) + BuiltinFnId fop = source_instr->op; + unsigned bits; + + if (float_type->id == ZigTypeIdComptimeFloat) { + bits = 128; + } else if (float_type->id == ZigTypeIdFloat) + bits = float_type->data.floating.bit_count; + + switch (bits) { + case 16: { + switch (fop) { + case BuiltinFnIdSqrt: + 
out_val->data.x_f16 = f16_sqrt(op->data.x_f16); + break; + case BuiltinFnIdSin: + case BuiltinFnIdCos: + case BuiltinFnIdExp: + case BuiltinFnIdExp2: + case BuiltinFnIdLn: + case BuiltinFnIdLog10: + case BuiltinFnIdLog2: + case BuiltinFnIdFabs: + case BuiltinFnIdFloor: + case BuiltinFnIdCeil: + case BuiltinFnIdTrunc: + case BuiltinFnIdNearbyInt: + case BuiltinFnIdRound: + zig_panic("unimplemented f16 builtin"); + default: + zig_unreachable(); + }; + break; + }; + case 32: { + switch (fop) { + case BuiltinFnIdSqrt: + out_val->data.x_f32 = sqrtf(op->data.x_f32); + break; + case BuiltinFnIdSin: + out_val->data.x_f32 = sinf(op->data.x_f32); + break; + case BuiltinFnIdCos: + out_val->data.x_f32 = cosf(op->data.x_f32); + break; + case BuiltinFnIdExp: + out_val->data.x_f32 = expf(op->data.x_f32); + break; + case BuiltinFnIdExp2: + out_val->data.x_f32 = exp2f(op->data.x_f32); + break; + case BuiltinFnIdLn: + out_val->data.x_f32 = logf(op->data.x_f32); + break; + case BuiltinFnIdLog10: + out_val->data.x_f32 = log10f(op->data.x_f32); + break; + case BuiltinFnIdLog2: + out_val->data.x_f32 = log2f(op->data.x_f32); + break; + case BuiltinFnIdFabs: + out_val->data.x_f32 = fabsf(op->data.x_f32); + break; + case BuiltinFnIdFloor: + out_val->data.x_f32 = floorf(op->data.x_f32); + break; + case BuiltinFnIdCeil: + out_val->data.x_f32 = ceilf(op->data.x_f32); + break; + case BuiltinFnIdTrunc: + out_val->data.x_f32 = truncf(op->data.x_f32); + break; + case BuiltinFnIdNearbyInt: + out_val->data.x_f32 = nearbyintf(op->data.x_f32); + break; + case BuiltinFnIdRound: + out_val->data.x_f32 = roundf(op->data.x_f32); + break; + default: + zig_unreachable(); + }; + break; + }; + case 64: { + switch (fop) { + case BuiltinFnIdSqrt: + out_val->data.x_f64 = sqrt(op->data.x_f64); + break; + case BuiltinFnIdSin: + out_val->data.x_f64 = sin(op->data.x_f64); + break; + case BuiltinFnIdCos: + out_val->data.x_f64 = cos(op->data.x_f64); + break; + case BuiltinFnIdExp: + out_val->data.x_f64 = 
exp(op->data.x_f64); + break; + case BuiltinFnIdExp2: + out_val->data.x_f64 = exp2(op->data.x_f64); + break; + case BuiltinFnIdLn: + out_val->data.x_f64 = log(op->data.x_f64); + break; + case BuiltinFnIdLog10: + out_val->data.x_f64 = log10(op->data.x_f64); + break; + case BuiltinFnIdLog2: + out_val->data.x_f64 = log2(op->data.x_f64); + break; + case BuiltinFnIdFabs: + out_val->data.x_f64 = fabs(op->data.x_f64); + break; + case BuiltinFnIdFloor: + out_val->data.x_f64 = floor(op->data.x_f64); + break; + case BuiltinFnIdCeil: + out_val->data.x_f64 = ceil(op->data.x_f64); + break; + case BuiltinFnIdTrunc: + out_val->data.x_f64 = trunc(op->data.x_f64); + break; + case BuiltinFnIdNearbyInt: + out_val->data.x_f64 = nearbyint(op->data.x_f64); + break; + case BuiltinFnIdRound: + out_val->data.x_f64 = round(op->data.x_f64); + break; + default: + zig_unreachable(); + } + break; + }; + case 128: { + float128_t *out, *in; + if (float_type->id == ZigTypeIdComptimeFloat) { + out = &out_val->data.x_bigfloat.value; + in = &op->data.x_bigfloat.value; + } else { + out = &out_val->data.x_f128; + in = &op->data.x_f128; + } + switch (fop) { + case BuiltinFnIdSqrt: + f128M_sqrt(in, out); + break; + case BuiltinFnIdNearbyInt: + case BuiltinFnIdSin: + case BuiltinFnIdCos: + case BuiltinFnIdExp: + case BuiltinFnIdExp2: + case BuiltinFnIdLn: + case BuiltinFnIdLog10: + case BuiltinFnIdLog2: + case BuiltinFnIdFabs: + case BuiltinFnIdFloor: + case BuiltinFnIdCeil: + case BuiltinFnIdTrunc: + case BuiltinFnIdRound: + zig_panic("unimplemented f128 builtin"); + default: + zig_unreachable(); + } + break; + }; + default: + zig_unreachable(); + } +} + +static IrInstruction *ir_analyze_instruction_float_op(IrAnalyze *ira, IrInstructionFloatOp *instruction) { + IrInstruction *type = instruction->type->child; + if (type_is_invalid(type->value.type)) + return ira->codegen->invalid_instruction; + + ZigType *expr_type = ir_resolve_type(ira, type); + if (type_is_invalid(expr_type)) return 
ira->codegen->invalid_instruction; - bool ok_type = float_type->id == ZigTypeIdComptimeFloat || float_type->id == ZigTypeIdFloat; - if (!ok_type) { - ir_add_error(ira, instruction->type, buf_sprintf("@sqrt does not support type '%s'", buf_ptr(&float_type->name))); + // Only allow float types, and vectors of floats. + ZigType *float_type = (expr_type->id == ZigTypeIdVector) ? expr_type->data.vector.elem_type : expr_type; + if (float_type->id != ZigTypeIdFloat && float_type->id != ZigTypeIdComptimeFloat) { + ir_add_error(ira, instruction->type, buf_sprintf("@%s does not support type '%s'", float_op_to_name(instruction->op, false), buf_ptr(&float_type->name))); return ira->codegen->invalid_instruction; } - IrInstruction *casted_op = ir_implicit_cast(ira, op, float_type); - if (type_is_invalid(casted_op->value.type)) + IrInstruction *op1 = instruction->op1->child; + if (type_is_invalid(op1->value.type)) return ira->codegen->invalid_instruction; - if (instr_is_comptime(casted_op)) { - ConstExprValue *val = ir_resolve_const(ira, casted_op, UndefBad); - if (!val) + IrInstruction *casted_op1 = ir_implicit_cast(ira, op1, float_type); + if (type_is_invalid(casted_op1->value.type)) + return ira->codegen->invalid_instruction; + + if (instr_is_comptime(casted_op1)) { + // Our comptime 16-bit and 128-bit support is quite limited. 
+ if ((float_type->id == ZigTypeIdComptimeFloat || + float_type->data.floating.bit_count == 16 || + float_type->data.floating.bit_count == 128) && + instruction->op != BuiltinFnIdSqrt) { + ir_add_error(ira, instruction->type, buf_sprintf("@%s does not support type '%s'", float_op_to_name(instruction->op, false), buf_ptr(&float_type->name))); return ira->codegen->invalid_instruction; + } - IrInstruction *result = ir_const(ira, &instruction->base, float_type); + ConstExprValue *op1_const = ir_resolve_const(ira, casted_op1, UndefBad); + if (!op1_const) + return ira->codegen->invalid_instruction; + + IrInstruction *result = ir_const(ira, &instruction->base, expr_type); ConstExprValue *out_val = &result->value; - if (float_type->id == ZigTypeIdComptimeFloat) { - bigfloat_sqrt(&out_val->data.x_bigfloat, &val->data.x_bigfloat); - } else if (float_type->id == ZigTypeIdFloat) { - switch (float_type->data.floating.bit_count) { - case 16: - out_val->data.x_f16 = f16_sqrt(val->data.x_f16); - break; - case 32: - out_val->data.x_f32 = sqrtf(val->data.x_f32); - break; - case 64: - out_val->data.x_f64 = sqrt(val->data.x_f64); - break; - case 128: - f128M_sqrt(&val->data.x_f128, &out_val->data.x_f128); - break; - default: - zig_unreachable(); + if (expr_type->id == ZigTypeIdVector) { + expand_undef_array(ira->codegen, op1_const); + out_val->special = ConstValSpecialUndef; + expand_undef_array(ira->codegen, out_val); + size_t len = expr_type->data.vector.len; + for (size_t i = 0; i < len; i += 1) { + ConstExprValue *float_operand_op1 = &op1_const->data.x_array.data.s_none.elements[i]; + ConstExprValue *float_out_val = &out_val->data.x_array.data.s_none.elements[i]; + assert(float_operand_op1->type == float_type); + assert(float_out_val->type == float_type); + ir_eval_float_op(ira, instruction, float_type, + op1_const, float_out_val); + float_out_val->type = float_type; } + out_val->type = expr_type; + out_val->special = ConstValSpecialStatic; } else { - zig_unreachable(); + 
ir_eval_float_op(ira, instruction, float_type, op1_const, out_val); } - return result; } ir_assert(float_type->id == ZigTypeIdFloat, &instruction->base); - if (float_type->data.floating.bit_count != 16 && - float_type->data.floating.bit_count != 32 && - float_type->data.floating.bit_count != 64) { - ir_add_error(ira, instruction->type, buf_sprintf("compiler TODO: add implementation of sqrt for '%s'", buf_ptr(&float_type->name))); - return ira->codegen->invalid_instruction; - } - IrInstruction *result = ir_build_sqrt(&ira->new_irb, instruction->base.scope, - instruction->base.source_node, nullptr, casted_op); - result->value.type = float_type; + IrInstruction *result = ir_build_float_op(&ira->new_irb, instruction->base.scope, + instruction->base.source_node, nullptr, casted_op1, instruction->op); + result->value.type = expr_type; return result; } @@ -23762,8 +23995,8 @@ static IrInstruction *ir_analyze_instruction_nocast(IrAnalyze *ira, IrInstructio return ir_analyze_instruction_merge_err_ret_traces(ira, (IrInstructionMergeErrRetTraces *)instruction); case IrInstructionIdMarkErrRetTracePtr: return ir_analyze_instruction_mark_err_ret_trace_ptr(ira, (IrInstructionMarkErrRetTracePtr *)instruction); - case IrInstructionIdSqrt: - return ir_analyze_instruction_sqrt(ira, (IrInstructionSqrt *)instruction); + case IrInstructionIdFloatOp: + return ir_analyze_instruction_float_op(ira, (IrInstructionFloatOp *)instruction); case IrInstructionIdMulAdd: return ir_analyze_instruction_mul_add(ira, (IrInstructionMulAdd *)instruction); case IrInstructionIdIntToErr: @@ -24004,7 +24237,7 @@ bool ir_has_side_effects(IrInstruction *instruction) { case IrInstructionIdCoroFree: case IrInstructionIdCoroPromise: case IrInstructionIdPromiseResultType: - case IrInstructionIdSqrt: + case IrInstructionIdFloatOp: case IrInstructionIdMulAdd: case IrInstructionIdAtomicLoad: case IrInstructionIdIntCast: diff --git a/src/ir.hpp b/src/ir.hpp index 4fb75522122f..597624e2e674 100644 --- a/src/ir.hpp +++ 
b/src/ir.hpp @@ -26,5 +26,6 @@ bool ir_has_side_effects(IrInstruction *instruction); struct IrAnalyze; ConstExprValue *const_ptr_pointee(IrAnalyze *ira, CodeGen *codegen, ConstExprValue *const_val, AstNode *source_node); +const char *float_op_to_name(BuiltinFnId op, bool llvm_name); #endif diff --git a/src/ir_print.cpp b/src/ir_print.cpp index e205c8e067c0..165d9b473946 100644 --- a/src/ir_print.cpp +++ b/src/ir_print.cpp @@ -1427,15 +1427,16 @@ static void ir_print_mark_err_ret_trace_ptr(IrPrint *irp, IrInstructionMarkErrRe fprintf(irp->f, ")"); } -static void ir_print_sqrt(IrPrint *irp, IrInstructionSqrt *instruction) { - fprintf(irp->f, "@sqrt("); +static void ir_print_float_op(IrPrint *irp, IrInstructionFloatOp *instruction) { + + fprintf(irp->f, "@%s(", float_op_to_name(instruction->op, false)); if (instruction->type != nullptr) { ir_print_other_instruction(irp, instruction->type); } else { fprintf(irp->f, "null"); } fprintf(irp->f, ","); - ir_print_other_instruction(irp, instruction->op); + ir_print_other_instruction(irp, instruction->op1); fprintf(irp->f, ")"); } @@ -1918,8 +1919,8 @@ static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) { case IrInstructionIdMarkErrRetTracePtr: ir_print_mark_err_ret_trace_ptr(irp, (IrInstructionMarkErrRetTracePtr *)instruction); break; - case IrInstructionIdSqrt: - ir_print_sqrt(irp, (IrInstructionSqrt *)instruction); + case IrInstructionIdFloatOp: + ir_print_float_op(irp, (IrInstructionFloatOp *)instruction); break; case IrInstructionIdMulAdd: ir_print_mul_add(irp, (IrInstructionMulAdd *)instruction); diff --git a/src/util.cpp b/src/util.cpp index 9a6a3829934d..f85565806f2b 100644 --- a/src/util.cpp +++ b/src/util.cpp @@ -13,6 +13,7 @@ #include "userland.h" void zig_panic(const char *format, ...) 
{ + abort(); va_list ap; va_start(ap, format); vfprintf(stderr, format, ap); diff --git a/std/special/c.zig b/std/special/c.zig index b3cf54619fec..15cefbd2a097 100644 --- a/std/special/c.zig +++ b/std/special/c.zig @@ -254,24 +254,32 @@ export fn fmod(x: f64, y: f64) f64 { // TODO add intrinsics for these (and probably the double version too) // and have the math stuff use the intrinsic. same as @mod and @rem -export fn floorf(x: f32) f32 { - return math.floor(x); -} -export fn ceilf(x: f32) f32 { - return math.ceil(x); -} -export fn floor(x: f64) f64 { - return math.floor(x); -} -export fn ceil(x: f64) f64 { - return math.ceil(x); -} -export fn fma(a: f64, b: f64, c: f64) f64 { - return math.fma(f64, a, b, c); -} -export fn fmaf(a: f32, b: f32, c: f32) f32 { - return math.fma(f32, a, b, c); -} +export fn floorf(x: f32) f32 {return math.floor(x);} +export fn ceilf(x: f32) f32 {return math.ceil(x);} +export fn floor(x: f64) f64 {return math.floor(x);} +export fn ceil(x: f64) f64 {return math.ceil(x);} +export fn fma(a: f64, b: f64, c: f64) f64 {return math.fma(f64, a, b, c);} +export fn fmaf(a: f32, b: f32, c: f32) f32 {return math.fma(f32, a, b, c);} +export fn sin(a: f64) f64 {return math.sin(a);} +export fn sinf(a: f32) f32 {return math.sin(a);} +export fn cos(a: f64) f64 {return math.cos(a);} +export fn cosf(a: f32) f32 {return math.cos(a);} +export fn exp(a: f64) f64 {return math.exp(a);} +export fn expf(a: f32) f32 {return math.exp(a);} +export fn exp2(a: f64) f64 {return math.exp2(a);} +export fn exp2f(a: f32) f32 {return math.exp2(a);} +export fn log(a: f64) f64 {return math.ln(a);} +export fn logf(a: f32) f32 {return math.ln(a);} +export fn log2(a: f64) f64 {return math.log2(a);} +export fn log2f(a: f32) f32 {return math.log2(a);} +export fn log10(a: f64) f64 {return math.log10(a);} +export fn log10f(a: f32) f32 {return math.log10(a);} +export fn fabs(a: f64) f64 {return math.fabs(a);} +export fn fabsf(a: f32) f32 {return math.fabs(a);} +export fn trunc(a: 
f64) f64 {return math.trunc(a);} +export fn truncf(a: f32) f32 {return math.trunc(a);} +export fn round(a: f64) f64 {return math.round(a);} +export fn roundf(a: f32) f32 {return math.round(a);} fn generic_fmod(comptime T: type, x: T, y: T) T { @setRuntimeSafety(false); diff --git a/test/stage1/behavior.zig b/test/stage1/behavior.zig index 10e7c1a09beb..efefed33ba34 100644 --- a/test/stage1/behavior.zig +++ b/test/stage1/behavior.zig @@ -71,6 +71,7 @@ comptime { _ = @import("behavior/pointers.zig"); _ = @import("behavior/popcount.zig"); _ = @import("behavior/muladd.zig"); + _ = @import("behavior/floatop.zig"); _ = @import("behavior/ptrcast.zig"); _ = @import("behavior/pub_enum.zig"); _ = @import("behavior/ref_var_in_if_after_if_2nd_switch_prong.zig"); diff --git a/test/stage1/behavior/floatop.zig b/test/stage1/behavior/floatop.zig new file mode 100644 index 000000000000..de2f6815a623 --- /dev/null +++ b/test/stage1/behavior/floatop.zig @@ -0,0 +1,243 @@ +const expect = @import("std").testing.expect; +const pi = @import("std").math.pi; +const e = @import("std").math.e; + +test "@sqrt" { + comptime testSqrt(); + testSqrt(); +} + +fn testSqrt() void { + { + var a: f16 = 4; + expect(@sqrt(f16, a) == 2); + } + { + var a: f32 = 9; + expect(@sqrt(f32, a) == 3); + } + { + var a: f64 = 25; + expect(@sqrt(f64, a) == 5); + } + { + const a: comptime_float = 25.0; + expect(@sqrt(comptime_float, a) == 5.0); + } + // Waiting on a c.zig implementation + //{ + // var a: f128 = 49; + // expect(@sqrt(f128, a) == 7); + //} +} + +test "@sin" { + comptime testSin(); + testSin(); +} + +fn testSin() void { + // TODO - this is actually useful and should be implemented + // (all the trig functions for f16) + // but will probably wait till self-hosted + //{ + // var a: f16 = pi; + // expect(@sin(f16, a/2) == 1); + //} + { + var a: f32 = 0; + expect(@sin(f32, a) == 0); + } + { + var a: f64 = 0; + expect(@sin(f64, a) == 0); + } + // TODO + //{ + // var a: f128 = pi; + // expect(@sin(f128, a/2) 
== 1); + //} +} + +test "@cos" { + comptime testCos(); + testCos(); +} + +fn testCos() void { + { + var a: f32 = 0; + expect(@cos(f32, a) == 1); + } + { + var a: f64 = 0; + expect(@cos(f64, a) == 1); + } +} + +test "@exp" { + comptime testExp(); + testExp(); +} + +fn testExp() void { + { + var a: f32 = 0; + expect(@exp(f32, a) == 1); + } + { + var a: f64 = 0; + expect(@exp(f64, a) == 1); + } +} + +test "@exp2" { + comptime testExp2(); + testExp2(); +} + +fn testExp2() void { + { + var a: f32 = 2; + expect(@exp2(f32, a) == 4); + } + { + var a: f64 = 2; + expect(@exp2(f64, a) == 4); + } +} + +test "@ln" { + // Old musl (and glibc?), and our current math.ln implementation do not return exactly 1 + // for ln(e), so also accept the nearest representable value below 1 (1 - 1ulp). + comptime testLn(); + testLn(); +} + +fn testLn() void { + { + var a: f32 = e; + expect(@ln(f32, a) == 1 or @ln(f32, a) == @bitCast(f32, u32(0x3f7fffff))); + } + { + var a: f64 = e; + expect(@ln(f64, a) == 1 or @ln(f64, a) == @bitCast(f64, u64(0x3fefffffffffffff))); + } +} + +test "@log2" { + comptime testLog2(); + testLog2(); +} + +fn testLog2() void { + { + var a: f32 = 4; + expect(@log2(f32, a) == 2); + } + { + var a: f64 = 4; + expect(@log2(f64, a) == 2); + } +} + +test "@log10" { + comptime testLog10(); + testLog10(); +} + +fn testLog10() void { + { + var a: f32 = 100; + expect(@log10(f32, a) == 2); + } + { + var a: f64 = 1000; + expect(@log10(f64, a) == 3); + } +} + +test "@fabs" { + comptime testFabs(); + testFabs(); +} + +fn testFabs() void { + { + var a: f32 = -2.5; + var b: f32 = 2.5; + expect(@fabs(f32, a) == 2.5); + expect(@fabs(f32, b) == 2.5); + } + { + var a: f64 = -2.5; + var b: f64 = 2.5; + expect(@fabs(f64, a) == 2.5); + expect(@fabs(f64, b) == 2.5); + } +} + +test "@floor" { + comptime testFloor(); + testFloor(); +} + +fn testFloor() void { + { + var a: f32 = 2.1; + expect(@floor(f32, a) == 2); + } + { + var a: f64 = 3.5; + expect(@floor(f64, a) == 3); + } +} + +test "@ceil" { + comptime testCeil(); + testCeil(); +} + +fn 
testCeil() void { + { + var a: f32 = 2.1; + expect(@ceil(f32, a) == 3); + } + { + var a: f64 = 3.5; + expect(@ceil(f64, a) == 4); + } +} + +test "@trunc" { + comptime testTrunc(); + testTrunc(); +} + +fn testTrunc() void { + { + var a: f32 = 2.1; + expect(@trunc(f32, a) == 2); + } + { + var a: f64 = -3.5; + expect(@trunc(f64, a) == -3); + } +} + +// This is waiting on library support for the Windows build (not sure why the others don't need it) +//test "@nearbyInt" { +// comptime testNearbyInt(); +// testNearbyInt(); +//} + +//fn testNearbyInt() void { +// { +// var a: f32 = 2.1; +// expect(@nearbyInt(f32, a) == 2); +// } +// { +// var a: f64 = -3.75; +// expect(@nearbyInt(f64, a) == -4); +// } +//}