From bbfb53d52411cc5b1f560293c757bff252e1e06f Mon Sep 17 00:00:00 2001
From: Shawn Landden
{#syntax#}@mulAdd(comptime T: type, a: T, b: T, c: T) T{#endsyntax#}+
+ Fused multiply add (for floats), similar to {#syntax#}(a * b) + c{#endsyntax#}, except + only rounds once, and is thus more accurate. +
{#header_close#} {#header_open|@byteSwap#} diff --git a/src/all_types.hpp b/src/all_types.hpp index 5aa1c78ea1ef..83df71b95f48 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -1406,6 +1406,7 @@ enum BuiltinFnId { BuiltinFnIdSubWithOverflow, BuiltinFnIdMulWithOverflow, BuiltinFnIdShlWithOverflow, + BuiltinFnIdMulAdd, BuiltinFnIdCInclude, BuiltinFnIdCDefine, BuiltinFnIdCUndef, @@ -1554,6 +1555,7 @@ enum ZigLLVMFnId { ZigLLVMFnIdClz, ZigLLVMFnIdPopCount, ZigLLVMFnIdOverflowArithmetic, + ZigLLVMFnIdFMA, ZigLLVMFnIdFloor, ZigLLVMFnIdCeil, ZigLLVMFnIdSqrt, @@ -1584,6 +1586,7 @@ struct ZigLLVMFnKey { } pop_count; struct { uint32_t bit_count; + uint32_t vector_len; // 0 means not a vector } floating; struct { AddSubMul add_sub_mul; @@ -2235,6 +2238,7 @@ enum IrInstructionId { IrInstructionIdHandle, IrInstructionIdAlignOf, IrInstructionIdOverflowOp, + IrInstructionIdMulAdd, IrInstructionIdTestErr, IrInstructionIdUnwrapErrCode, IrInstructionIdUnwrapErrPayload, @@ -3038,6 +3042,15 @@ struct IrInstructionOverflowOp { ZigType *result_ptr_type; }; +struct IrInstructionMulAdd { + IrInstruction base; + + IrInstruction *type_value; + IrInstruction *op1; + IrInstruction *op2; + IrInstruction *op3; +}; + struct IrInstructionAlignOf { IrInstruction base; diff --git a/src/analyze.cpp b/src/analyze.cpp index c7e35367c33e..bff740cd52cf 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -5737,11 +5737,11 @@ uint32_t zig_llvm_fn_key_hash(ZigLLVMFnKey x) { case ZigLLVMFnIdPopCount: return (uint32_t)(x.data.clz.bit_count) * (uint32_t)101195049; case ZigLLVMFnIdFloor: - return (uint32_t)(x.data.floating.bit_count) * (uint32_t)1899859168; case ZigLLVMFnIdCeil: - return (uint32_t)(x.data.floating.bit_count) * (uint32_t)1953839089; case ZigLLVMFnIdSqrt: - return (uint32_t)(x.data.floating.bit_count) * (uint32_t)2225366385; + case ZigLLVMFnIdFMA: + return (uint32_t)(x.data.floating.bit_count) * ((uint32_t)x.id + 1025) + + (uint32_t)(x.data.floating.vector_len) * (((uint32_t)x.id << 5) + 1025); case ZigLLVMFnIdBswap: return (uint32_t)(x.data.bswap.bit_count) * (uint32_t)3661994335; case ZigLLVMFnIdBitReverse: @@ -5772,6 +5772,7 @@ bool zig_llvm_fn_key_eql(ZigLLVMFnKey a, ZigLLVMFnKey b) { case ZigLLVMFnIdFloor: case ZigLLVMFnIdCeil: case ZigLLVMFnIdSqrt: + case ZigLLVMFnIdFMA: return a.data.floating.bit_count == b.data.floating.bit_count; case ZigLLVMFnIdOverflowArithmetic: return (a.data.overflow_arithmetic.bit_count == b.data.overflow_arithmetic.bit_count) && diff --git a/src/codegen.cpp b/src/codegen.cpp index 3dd6995c61cc..6691652a5e04 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -807,31 +807,51 @@ static LLVMValueRef get_int_overflow_fn(CodeGen *g, ZigType *operand_type, AddSu } static LLVMValueRef get_float_fn(CodeGen *g, ZigType *type_entry, ZigLLVMFnId fn_id) { - assert(type_entry->id == ZigTypeIdFloat); + assert(type_entry->id == ZigTypeIdFloat || + type_entry->id == ZigTypeIdVector); + + bool is_vector = (type_entry->id == ZigTypeIdVector); + ZigType *float_type = is_vector ? type_entry->data.vector.elem_type : type_entry; ZigLLVMFnKey key = {}; key.id = fn_id; - key.data.floating.bit_count = (uint32_t)type_entry->data.floating.bit_count; + key.data.floating.bit_count = (uint32_t)float_type->data.floating.bit_count; + key.data.floating.vector_len = is_vector ? (uint32_t)type_entry->data.vector.len : 0; auto existing_entry = g->llvm_fn_table.maybe_get(key); if (existing_entry) return existing_entry->value; const char *name; + uint32_t num_args; if (fn_id == ZigLLVMFnIdFloor) { name = "floor"; + num_args = 1; } else if (fn_id == ZigLLVMFnIdCeil) { name = "ceil"; + num_args = 1; } else if (fn_id == ZigLLVMFnIdSqrt) { name = "sqrt"; + num_args = 1; + } else if (fn_id == ZigLLVMFnIdFMA) { + name = "fma"; + num_args = 3; } else { zig_unreachable(); } char fn_name[64]; - sprintf(fn_name, "llvm.%s.f%" ZIG_PRI_usize "", name, type_entry->data.floating.bit_count); + if (is_vector) + sprintf(fn_name, "llvm.%s.v%" PRIu32 "f%" PRIu32, name, key.data.floating.vector_len, key.data.floating.bit_count); + else + sprintf(fn_name, "llvm.%s.f%" PRIu32, name, key.data.floating.bit_count); LLVMTypeRef float_type_ref = get_llvm_type(g, type_entry); - LLVMTypeRef fn_type = LLVMFunctionType(float_type_ref, &float_type_ref, 1, false); + LLVMTypeRef return_elem_types[3] = { + float_type_ref, + float_type_ref, + float_type_ref, + }; + LLVMTypeRef fn_type = LLVMFunctionType(float_type_ref, return_elem_types, num_args, false); LLVMValueRef fn_val = LLVMAddFunction(g->module, fn_name, fn_type); assert(LLVMGetIntrinsicID(fn_val)); @@ -5437,6 +5457,21 @@ static LLVMValueRef ir_render_sqrt(CodeGen *g, IrExecutable *executable, IrInstr return LLVMBuildCall(g->builder, fn_val, &op, 1, ""); } +static LLVMValueRef ir_render_mul_add(CodeGen *g, IrExecutable *executable, IrInstructionMulAdd *instruction) { + LLVMValueRef op1 = ir_llvm_value(g, instruction->op1); + LLVMValueRef op2 = ir_llvm_value(g, instruction->op2); + LLVMValueRef op3 = ir_llvm_value(g, instruction->op3); + assert(instruction->base.value.type->id == ZigTypeIdFloat || + instruction->base.value.type->id == ZigTypeIdVector); + LLVMValueRef fn_val = get_float_fn(g, instruction->base.value.type, ZigLLVMFnIdFMA); + LLVMValueRef args[3] = { + op1, + op2, + op3, + }; + return LLVMBuildCall(g->builder, fn_val, args, 3, ""); +} + static LLVMValueRef ir_render_bswap(CodeGen *g, IrExecutable *executable, IrInstructionBswap *instruction) { LLVMValueRef op = ir_llvm_value(g, instruction->op); ZigType *int_type = instruction->base.value.type; @@ -5781,6 +5816,8 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable, return ir_render_mark_err_ret_trace_ptr(g, executable, (IrInstructionMarkErrRetTracePtr *)instruction); case IrInstructionIdSqrt: return ir_render_sqrt(g, executable, (IrInstructionSqrt *)instruction); + case IrInstructionIdMulAdd: + return ir_render_mul_add(g, executable, (IrInstructionMulAdd *)instruction); case IrInstructionIdArrayToVector: return ir_render_array_to_vector(g, executable, (IrInstructionArrayToVector *)instruction); case IrInstructionIdVectorToArray: @@ -7398,6 +7435,7 @@ static void define_builtin_fns(CodeGen *g) { create_builtin_fn(g, BuiltinFnIdRem, "rem", 2); create_builtin_fn(g, BuiltinFnIdMod, "mod", 2); create_builtin_fn(g, BuiltinFnIdSqrt, "sqrt", 2); + create_builtin_fn(g, BuiltinFnIdMulAdd, "mulAdd", 4); create_builtin_fn(g, BuiltinFnIdInlineCall, "inlineCall", SIZE_MAX); create_builtin_fn(g, BuiltinFnIdNoInlineCall, "noInlineCall", SIZE_MAX); create_builtin_fn(g, BuiltinFnIdNewStackCall, "newStackCall", SIZE_MAX); diff --git a/src/ir.cpp b/src/ir.cpp index 5c09e48b2d7c..c2c6cb615416 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -747,6 +747,10 @@ static constexpr IrInstructionId ir_instruction_id(IrInstructionTestErr *) { return IrInstructionIdTestErr; } +static constexpr IrInstructionId ir_instruction_id(IrInstructionMulAdd *) { + return IrInstructionIdMulAdd; +} + static constexpr IrInstructionId ir_instruction_id(IrInstructionUnwrapErrCode *) { return IrInstructionIdUnwrapErrCode; } @@ -2308,6 +2312,22 @@ static IrInstruction *ir_build_overflow_op(IrBuilder *irb, Scope *scope, AstNode return &instruction->base; } +static IrInstruction *ir_build_mul_add(IrBuilder *irb, Scope *scope, AstNode *source_node, + IrInstruction *type_value, IrInstruction *op1, IrInstruction *op2, IrInstruction *op3) { + IrInstructionMulAdd *instruction = ir_build_instruction{#syntax#}@sqrt(comptime T: type, value: T) T{#endsyntax#}
Performs the square root of a floating point number. Uses a dedicated hardware instruction - when available. Currently only supports f32 and f64 at runtime. f128 at runtime is TODO. + when available. Supports f16, f32, f64, and f128, as well as vectors.
+ {#header_close#} + {#header_open|@sin#} +{#syntax#}@sin(comptime T: type, value: T) T{#endsyntax#}+
+ Sine trigometric function on a floating point number. Uses a dedicated hardware instruction + when available. Currently supports f32 and f64. +
+ {#header_close#} + {#header_open|@cos#} +{#syntax#}@cos(comptime T: type, value: T) T{#endsyntax#}+
+ Cosine trigometric function on a floating point number. Uses a dedicated hardware instruction + when available. Currently supports f32 and f64. +
+ {#header_close#} + {#header_open|@exp#} +{#syntax#}@exp(comptime T: type, value: T) T{#endsyntax#}+
+ Base-e exponential function on a floating point number. Uses a dedicated hardware instruction + when available. Currently supports f32 and f64. +
+ {#header_close#} + {#header_open|@exp2#} +{#syntax#}@exp2(comptime T: type, value: T) T{#endsyntax#}+
+ Base-2 exponential function on a floating point number. Uses a dedicated hardware instruction + when available. Currently supports f32 and f64. +
+ {#header_close#} + {#header_open|@ln#} +{#syntax#}@ln(comptime T: type, value: T) T{#endsyntax#}+
+ Returns the natural logarithm of a floating point number. Uses a dedicated hardware instruction + when available. Currently supports f32 and f64. +
+ {#header_close#} + {#header_open|@log2#} +{#syntax#}@log2(comptime T: type, value: T) T{#endsyntax#}+
+ Returns the logarithm to the base 2 of a floating point number. Uses a dedicated hardware instruction + when available. Currently supports f32 and f64. +
+ {#header_close#} + {#header_open|@log10#} +{#syntax#}@log10(comptime T: type, value: T) T{#endsyntax#}+
+ Returns the logarithm to the base 10 of a floating point number. Uses a dedicated hardware instruction + when available. Currently supports f32 and f64. +
+ {#header_close#} + {#header_open|@fabs#} +{#syntax#}@fabs(comptime T: type, value: T) T{#endsyntax#}+
+ Returns the absolute value of a floating point number. Uses a dedicated hardware instruction + when available. Currently supports f32 and f64. +
+ {#header_close#} + {#header_open|@floor#} +{#syntax#}@floor(comptime T: type, value: T) T{#endsyntax#}+
+ Returns the largest integral value not greater than the given floating point number. Uses a dedicated hardware instruction + when available. Currently supports f32 and f64. +
+ {#header_close#} + {#header_open|@ceil#} +{#syntax#}@ceil(comptime T: type, value: T) T{#endsyntax#}+
+ Returns the largest integral value not less than the given floating point number. Uses a dedicated hardware instruction + when available. Currently supports f32 and f64. +
+ {#header_close#} + {#header_open|@trunc#} +{#syntax#}@trunc(comptime T: type, value: T) T{#endsyntax#}+
+ Rounds the given floating point number to an integer, towards zero. Uses a dedicated hardware instruction + when available. Currently supports f32 and f64. +
+ {#header_close#} + {#header_open|@round#} +{#syntax#}@round(comptime T: type, value: T) T{#endsyntax#}
- This is a low-level intrinsic. Most code can use {#syntax#}std.math.sqrt{#endsyntax#} instead. + Rounds the given floating point number to an integer, away from zero. Uses a dedicated hardware instruction + when available. Currently supports f32 and f64.
{#header_close#} diff --git a/src/all_types.hpp b/src/all_types.hpp index 83df71b95f48..6595218bcf4e 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -1434,6 +1434,19 @@ enum BuiltinFnId { BuiltinFnIdRem, BuiltinFnIdMod, BuiltinFnIdSqrt, + BuiltinFnIdSin, + BuiltinFnIdCos, + BuiltinFnIdExp, + BuiltinFnIdExp2, + BuiltinFnIdLn, + BuiltinFnIdLog2, + BuiltinFnIdLog10, + BuiltinFnIdFabs, + BuiltinFnIdFloor, + BuiltinFnIdCeil, + BuiltinFnIdTrunc, + BuiltinFnIdNearbyInt, + BuiltinFnIdRound, BuiltinFnIdTruncate, BuiltinFnIdIntCast, BuiltinFnIdFloatCast, @@ -1556,9 +1569,7 @@ enum ZigLLVMFnId { ZigLLVMFnIdPopCount, ZigLLVMFnIdOverflowArithmetic, ZigLLVMFnIdFMA, - ZigLLVMFnIdFloor, - ZigLLVMFnIdCeil, - ZigLLVMFnIdSqrt, + ZigLLVMFnIdFloatOp, ZigLLVMFnIdBswap, ZigLLVMFnIdBitReverse, }; @@ -1585,6 +1596,7 @@ struct ZigLLVMFnKey { uint32_t bit_count; } pop_count; struct { + BuiltinFnId op; uint32_t bit_count; uint32_t vector_len; // 0 means not a vector } floating; @@ -2239,6 +2251,7 @@ enum IrInstructionId { IrInstructionIdAlignOf, IrInstructionIdOverflowOp, IrInstructionIdMulAdd, + IrInstructionIdFloatOp, IrInstructionIdTestErr, IrInstructionIdUnwrapErrCode, IrInstructionIdUnwrapErrPayload, @@ -2300,7 +2313,6 @@ enum IrInstructionId { IrInstructionIdAddImplicitReturnType, IrInstructionIdMergeErrRetTraces, IrInstructionIdMarkErrRetTracePtr, - IrInstructionIdSqrt, IrInstructionIdErrSetCast, IrInstructionIdToBytes, IrInstructionIdFromBytes, @@ -3474,11 +3486,13 @@ struct IrInstructionMarkErrRetTracePtr { IrInstruction *err_ret_trace_ptr; }; -struct IrInstructionSqrt { +// For float ops which take a single argument +struct IrInstructionFloatOp { IrInstruction base; + BuiltinFnId op; IrInstruction *type; - IrInstruction *op; + IrInstruction *op1; }; struct IrInstructionCheckRuntimeScope { diff --git a/src/analyze.cpp b/src/analyze.cpp index 15b42c7f9dad..13b35e0aff43 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -5736,9 +5736,10 @@ uint32_t zig_llvm_fn_key_hash(ZigLLVMFnKey x) { return (uint32_t)(x.data.clz.bit_count) * (uint32_t)2428952817; case ZigLLVMFnIdPopCount: return (uint32_t)(x.data.clz.bit_count) * (uint32_t)101195049; - case ZigLLVMFnIdFloor: - case ZigLLVMFnIdCeil: - case ZigLLVMFnIdSqrt: + case ZigLLVMFnIdFloatOp: + return (uint32_t)(x.data.floating.bit_count) * ((uint32_t)x.id + 1025) + + (uint32_t)(x.data.floating.vector_len) * (((uint32_t)x.id << 5) + 1025) + + (uint32_t)(x.data.floating.op) * (uint32_t)43789879; case ZigLLVMFnIdFMA: return (uint32_t)(x.data.floating.bit_count) * ((uint32_t)x.id + 1025) + (uint32_t)(x.data.floating.vector_len) * (((uint32_t)x.id << 5) + 1025); @@ -5769,10 +5770,10 @@ bool zig_llvm_fn_key_eql(ZigLLVMFnKey a, ZigLLVMFnKey b) { return a.data.bswap.bit_count == b.data.bswap.bit_count; case ZigLLVMFnIdBitReverse: return a.data.bit_reverse.bit_count == b.data.bit_reverse.bit_count; - case ZigLLVMFnIdFloor: - case ZigLLVMFnIdCeil: - case ZigLLVMFnIdSqrt: - return a.data.floating.bit_count == b.data.floating.bit_count; + case ZigLLVMFnIdFloatOp: + return a.data.floating.bit_count == b.data.floating.bit_count && + a.data.floating.vector_len == b.data.floating.vector_len && + a.data.floating.op == b.data.floating.op; case ZigLLVMFnIdFMA: return a.data.floating.bit_count == b.data.floating.bit_count && a.data.floating.vector_len == b.data.floating.vector_len; diff --git a/src/codegen.cpp b/src/codegen.cpp index 6691652a5e04..41caa29dbd55 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -806,7 +806,7 @@ static LLVMValueRef get_int_overflow_fn(CodeGen *g, ZigType *operand_type, AddSu return fn_val; } -static LLVMValueRef get_float_fn(CodeGen *g, ZigType *type_entry, ZigLLVMFnId fn_id) { +static LLVMValueRef get_float_fn(CodeGen *g, ZigType *type_entry, ZigLLVMFnId fn_id, BuiltinFnId op) { assert(type_entry->id == ZigTypeIdFloat || type_entry->id == ZigTypeIdVector); @@ -817,6 +817,7 @@ static LLVMValueRef get_float_fn(CodeGen *g, ZigType *type_entry, ZigLLVMFnId fn key.id = fn_id; key.data.floating.bit_count = (uint32_t)float_type->data.floating.bit_count; key.data.floating.vector_len = is_vector ? (uint32_t)type_entry->data.vector.len : 0; + key.data.floating.op = op; auto existing_entry = g->llvm_fn_table.maybe_get(key); if (existing_entry) @@ -824,18 +825,12 @@ static LLVMValueRef get_float_fn(CodeGen *g, ZigType *type_entry, ZigLLVMFnId fn const char *name; uint32_t num_args; - if (fn_id == ZigLLVMFnIdFloor) { - name = "floor"; - num_args = 1; - } else if (fn_id == ZigLLVMFnIdCeil) { - name = "ceil"; - num_args = 1; - } else if (fn_id == ZigLLVMFnIdSqrt) { - name = "sqrt"; - num_args = 1; - } else if (fn_id == ZigLLVMFnIdFMA) { + if (fn_id == ZigLLVMFnIdFMA) { name = "fma"; num_args = 3; + } else if (fn_id == ZigLLVMFnIdFloatOp) { + name = float_op_to_name(op, true); + num_args = 1; } else { zig_unreachable(); } @@ -2480,22 +2475,17 @@ static LLVMValueRef gen_overflow_shr_op(CodeGen *g, ZigType *type_entry, return result; } -static LLVMValueRef gen_floor(CodeGen *g, LLVMValueRef val, ZigType *type_entry) { - if (type_entry->id == ZigTypeIdInt) +static LLVMValueRef gen_float_op(CodeGen *g, LLVMValueRef val, ZigType *type_entry, BuiltinFnId op) { + if ((op == BuiltinFnIdCeil || + op == BuiltinFnIdFloor) && + type_entry->id == ZigTypeIdInt) return val; + assert(type_entry->id == ZigTypeIdFloat); - LLVMValueRef floor_fn = get_float_fn(g, type_entry, ZigLLVMFnIdFloor); + LLVMValueRef floor_fn = get_float_fn(g, type_entry, ZigLLVMFnIdFloatOp, op); return LLVMBuildCall(g->builder, floor_fn, &val, 1, ""); } -static LLVMValueRef gen_ceil(CodeGen *g, LLVMValueRef val, ZigType *type_entry) { - if (type_entry->id == ZigTypeIdInt) - return val; - - LLVMValueRef ceil_fn = get_float_fn(g, type_entry, ZigLLVMFnIdCeil); - return LLVMBuildCall(g->builder, ceil_fn, &val, 1, ""); -} - enum DivKind { DivKindFloat, DivKindTrunc, @@ -2571,7 +2561,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast return result; case DivKindExact: if (want_runtime_safety) { - LLVMValueRef floored = gen_floor(g, result, type_entry); + LLVMValueRef floored = gen_float_op(g, result, type_entry, BuiltinFnIdFloor); LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivExactOk"); LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivExactFail"); LLVMValueRef ok_bit = LLVMBuildFCmp(g->builder, LLVMRealOEQ, floored, result, ""); @@ -2593,12 +2583,12 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast LLVMBuildCondBr(g->builder, ltz, ltz_block, gez_block); LLVMPositionBuilderAtEnd(g->builder, ltz_block); - LLVMValueRef ceiled = gen_ceil(g, result, type_entry); + LLVMValueRef ceiled = gen_float_op(g, result, type_entry, BuiltinFnIdCeil); LLVMBasicBlockRef ceiled_end_block = LLVMGetInsertBlock(g->builder); LLVMBuildBr(g->builder, end_block); LLVMPositionBuilderAtEnd(g->builder, gez_block); - LLVMValueRef floored = gen_floor(g, result, type_entry); + LLVMValueRef floored = gen_float_op(g, result, type_entry, BuiltinFnIdFloor); LLVMBasicBlockRef floored_end_block = LLVMGetInsertBlock(g->builder); LLVMBuildBr(g->builder, end_block); @@ -2610,7 +2600,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast return phi; } case DivKindFloor: - return gen_floor(g, result, type_entry); + return gen_float_op(g, result, type_entry, BuiltinFnIdFloor); } zig_unreachable(); } @@ -5450,10 +5440,10 @@ static LLVMValueRef ir_render_mark_err_ret_trace_ptr(CodeGen *g, IrExecutable *e return nullptr; } -static LLVMValueRef ir_render_sqrt(CodeGen *g, IrExecutable *executable, IrInstructionSqrt *instruction) { - LLVMValueRef op = ir_llvm_value(g, instruction->op); +static LLVMValueRef ir_render_float_op(CodeGen *g, IrExecutable *executable, IrInstructionFloatOp *instruction) { + LLVMValueRef op = ir_llvm_value(g, instruction->op1); assert(instruction->base.value.type->id == ZigTypeIdFloat); - LLVMValueRef fn_val = get_float_fn(g, instruction->base.value.type, ZigLLVMFnIdSqrt); + LLVMValueRef fn_val = get_float_fn(g, instruction->base.value.type, ZigLLVMFnIdFloatOp, instruction->op); return LLVMBuildCall(g->builder, fn_val, &op, 1, ""); } @@ -5463,7 +5453,7 @@ static LLVMValueRef ir_render_mul_add(CodeGen *g, IrExecutable *executable, IrIn LLVMValueRef op3 = ir_llvm_value(g, instruction->op3); assert(instruction->base.value.type->id == ZigTypeIdFloat || instruction->base.value.type->id == ZigTypeIdVector); - LLVMValueRef fn_val = get_float_fn(g, instruction->base.value.type, ZigLLVMFnIdFMA); + LLVMValueRef fn_val = get_float_fn(g, instruction->base.value.type, ZigLLVMFnIdFMA, BuiltinFnIdMulAdd); LLVMValueRef args[3] = { op1, op2, @@ -5814,8 +5804,8 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable, return ir_render_merge_err_ret_traces(g, executable, (IrInstructionMergeErrRetTraces *)instruction); case IrInstructionIdMarkErrRetTracePtr: return ir_render_mark_err_ret_trace_ptr(g, executable, (IrInstructionMarkErrRetTracePtr *)instruction); - case IrInstructionIdSqrt: - return ir_render_sqrt(g, executable, (IrInstructionSqrt *)instruction); + case IrInstructionIdFloatOp: + return ir_render_float_op(g, executable, (IrInstructionFloatOp *)instruction); case IrInstructionIdMulAdd: return ir_render_mul_add(g, executable, (IrInstructionMulAdd *)instruction); case IrInstructionIdArrayToVector: @@ -7435,6 +7425,20 @@ static void define_builtin_fns(CodeGen *g) { create_builtin_fn(g, BuiltinFnIdRem, "rem", 2); create_builtin_fn(g, BuiltinFnIdMod, "mod", 2); create_builtin_fn(g, BuiltinFnIdSqrt, "sqrt", 2); + create_builtin_fn(g, BuiltinFnIdSin, "sin", 2); + create_builtin_fn(g, BuiltinFnIdCos, "cos", 2); + create_builtin_fn(g, BuiltinFnIdExp, "exp", 2); + create_builtin_fn(g, BuiltinFnIdExp2, "exp2", 2); + create_builtin_fn(g, BuiltinFnIdLn, "ln", 2); + create_builtin_fn(g, BuiltinFnIdLog2, "log2", 2); + create_builtin_fn(g, BuiltinFnIdLog10, "log10", 2); + create_builtin_fn(g, BuiltinFnIdFabs, "fabs", 2); + create_builtin_fn(g, BuiltinFnIdFloor, "floor", 2); + create_builtin_fn(g, BuiltinFnIdCeil, "ceil", 2); + create_builtin_fn(g, BuiltinFnIdTrunc, "trunc", 2); + //Needs library support on Windows + //create_builtin_fn(g, BuiltinFnIdNearbyInt, "nearbyInt", 2); + create_builtin_fn(g, BuiltinFnIdRound, "round", 2); create_builtin_fn(g, BuiltinFnIdMulAdd, "mulAdd", 4); create_builtin_fn(g, BuiltinFnIdInlineCall, "inlineCall", SIZE_MAX); create_builtin_fn(g, BuiltinFnIdNoInlineCall, "noInlineCall", SIZE_MAX); diff --git a/src/ir.cpp b/src/ir.cpp index c2c6cb615416..50d2a0686809 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -991,8 +991,8 @@ static constexpr IrInstructionId ir_instruction_id(IrInstructionMarkErrRetTraceP return IrInstructionIdMarkErrRetTracePtr; } -static constexpr IrInstructionId ir_instruction_id(IrInstructionSqrt *) { - return IrInstructionIdSqrt; +static constexpr IrInstructionId ir_instruction_id(IrInstructionFloatOp *) { + return IrInstructionIdFloatOp; } static constexpr IrInstructionId ir_instruction_id(IrInstructionCheckRuntimeScope *) { @@ -2312,6 +2312,59 @@ static IrInstruction *ir_build_overflow_op(IrBuilder *irb, Scope *scope, AstNode return &instruction->base; } + +//TODO Powi, Pow, minnum, maxnum, maximum, minimum, copysign, +// lround, llround, lrint, llrint +// So far this is only non-complicated type functions. +const char *float_op_to_name(BuiltinFnId op, bool llvm_name) { + const bool b = llvm_name; + + switch (op) { + case BuiltinFnIdSqrt: + return "sqrt"; + case BuiltinFnIdSin: + return "sin"; + case BuiltinFnIdCos: + return "cos"; + case BuiltinFnIdExp: + return "exp"; + case BuiltinFnIdExp2: + return "exp2"; + case BuiltinFnIdLn: + return b ? "log" : "ln"; + case BuiltinFnIdLog10: + return "log10"; + case BuiltinFnIdLog2: + return "log2"; + case BuiltinFnIdFabs: + return "fabs"; + case BuiltinFnIdFloor: + return "floor"; + case BuiltinFnIdCeil: + return "ceil"; + case BuiltinFnIdTrunc: + return "trunc"; + case BuiltinFnIdNearbyInt: + return b ? "nearbyint" : "nearbyInt"; + case BuiltinFnIdRound: + return "round"; + default: + zig_unreachable(); + } +} + +static IrInstruction *ir_build_float_op(IrBuilder *irb, Scope *scope, AstNode *source_node, IrInstruction *type, IrInstruction *op1, BuiltinFnId op) { + IrInstructionFloatOp *instruction = ir_build_instruction