From 23b84f2421cd6c09d330b47bdf4aed7944c11a6c Mon Sep 17 00:00:00 2001 From: Zentrik Date: Sun, 20 Aug 2023 09:22:34 +0100 Subject: [PATCH] Add some more fastmath functions --- src/device/intrinsics/math.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/device/intrinsics/math.jl b/src/device/intrinsics/math.jl index 7e3d145d9d..6b07d1481c 100644 --- a/src/device/intrinsics/math.jl +++ b/src/device/intrinsics/math.jl @@ -131,6 +131,7 @@ end @device_override Base.exp2(x::Float64) = ccall("extern __nv_exp2", llvmcall, Cdouble, (Cdouble,), x) @device_override Base.exp2(x::Float32) = ccall("extern __nv_exp2f", llvmcall, Cfloat, (Cfloat,), x) +@device_override FastMath.exp2_fast(x::Union{Float32, Float64}) = exp2(x) # TODO: enable once PTX > 7.0 is supported # @device_override Base.exp2(x::Float16) = @asmcall("ex2.approx.f16 \$0, \$1", "=h,h", Float16, Tuple{Float16}, x) @@ -221,6 +222,7 @@ end @device_override Base.sqrt(x::Float64) = ccall("extern __nv_sqrt", llvmcall, Cdouble, (Cdouble,), x) @device_override Base.sqrt(x::Float32) = ccall("extern __nv_sqrtf", llvmcall, Cfloat, (Cfloat,), x) +@device_override FastMath.sqrt_fast(x::Union{Float32, Float64}) = sqrt(x) @device_function rsqrt(x::Float64) = ccall("extern __nv_rsqrt", llvmcall, Cdouble, (Cdouble,), x) @device_function rsqrt(x::Float32) = ccall("extern __nv_rsqrtf", llvmcall, Cfloat, (Cfloat,), x) @@ -306,6 +308,8 @@ end @device_override FastMath.div_fast(x::Float32, y::Float32) = ccall("extern __nv_fast_fdividef", llvmcall, Cfloat, (Cfloat, Cfloat), x, y) +@device_override Base.inv(x::Float32) = ccall("extern __nv_frcp_rn", llvmcall, Cfloat, (Cfloat,), x) +@device_override FastMath.inv_fast(x::Union{Float32, Float64}) = @fastmath one(x) / x ## distributions