Skip to content

Commit

Permalink
add fast math options to intrinsics and hook into fastmath macro (#1)
Browse files Browse the repository at this point in the history
  • Loading branch information
Kristoffer Carlsson authored Mar 4, 2020
1 parent 8204863 commit bdfd585
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 42 deletions.
128 changes: 92 additions & 36 deletions src/LLVM_intrinsics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ module Intrinsics
# when passed to LLVM. It is up to the caller to make sure that the correct
# intrinsic is called (e.g uitofp vs sitofp).

# TODO: fastmath flags

import ..SIMD: SIMD, VE, LVec, FloatingTypes
# Include Bool in IntegerTypes
const IntegerTypes = Union{SIMD.IntegerTypes, Bool}
Expand Down Expand Up @@ -53,6 +51,39 @@ llvm_name(llvmf, ::Type{T}) where {T} = string("llvm", ".", dotit(llv
llvm_type(::Type{T}) where {T} = d[T]
llvm_type(::Type{LVec{N, T}}) where {N,T} = "< $N x $(d[T])>"

############
# FastMath #
############

# Bit-flag constants for LLVM's fast-math keywords; combine them with `|`
# and lift the result into the type domain via `FastMathFlags`.
module FastMath
const nnan     = 1
const ninf     = 2
const nsz      = 4
const arcp     = 8
const contract = 16
const afn      = 32
const reassoc  = 64
const fast     = 128
end

# Carries a fast-math bitmask `T` as a type parameter so @generated
# functions can splice the corresponding IR keywords in at compile time.
struct FastMathFlags{T} end
Base.@pure FastMathFlags(T::Int) = FastMathFlags{T}()

# Render the flags set in `T` as the space-separated LLVM IR keyword
# string (e.g. "nnan nsz"); an empty mask yields "".
function fp_str(::Type{FastMathFlags{T}}) where {T}
    flagtable = (
        (FastMath.nnan,     "nnan"),
        (FastMath.ninf,     "ninf"),
        (FastMath.nsz,      "nsz"),
        (FastMath.arcp,     "arcp"),
        (FastMath.contract, "contract"),
        (FastMath.afn,      "afn"),
        (FastMath.reassoc,  "reassoc"),
        (FastMath.fast,     "fast"),
    )
    return join((name for (bit, name) in flagtable if T & bit != 0), " ")
end
fp_str(::Type{Nothing}) = ""

# Either "no flags" (Nothing) or a concrete compile-time flag set.
const FPFlags{T} = Union{Nothing, FastMathFlags{T}}

####################
# Unary operators #
Expand Down Expand Up @@ -101,9 +132,10 @@ for (fs, c) in zip([UNARY_INTRINSICS_FLOAT, UNARY_INTRINSICS_INT],
end

# fneg (not an intrinsic so cannot use `ccall`)
@generated function fneg(x::T) where T<:LT{<:FloatingTypes}
@generated function fneg(x::T, ::F=nothing) where {T<:LT{<:FloatingTypes}, F<:FPFlags}
fpflags = fp_str(F)
s = """
%2 = fneg $(llvm_type(T)) %0
%2 = fneg $fpflags $(llvm_type(T)) %0
ret $(llvm_type(T)) %2
"""
return :(
Expand Down Expand Up @@ -140,20 +172,32 @@ const BINARY_OPS_INT = [
:xor
]

for (fs, c) in zip([BINARY_OPS_FLOAT, BINARY_OPS_INT],
[FloatingTypes, IntegerTypes])
for f in fs
@eval @generated function $f(x::T, y::T) where T<:LT{<:$c}
ff = $(QuoteNode(f))
s = """
%3 = $ff $(llvm_type(T)) %0, %1
ret $(llvm_type(T)) %3
"""
return :(
$(Expr(:meta, :inline));
Base.llvmcall($s, T, Tuple{T, T}, x, y)
)
end
# Emit one @generated method per floating-point binary op.  The trailing
# `::F=nothing` argument optionally carries a fast-math flag set in the
# type domain; `fp_str(F)` renders it into the IR opcode line (it is ""
# for Nothing, so plain calls generate the same IR as before).
for f in BINARY_OPS_FLOAT
    @eval @generated function $f(x::T, y::T, ::F=nothing) where {T<:LT{<:FloatingTypes}, F<:FPFlags}
        fpflags = fp_str(F)
        # QuoteNode so the op name is spliced in at @eval time as a symbol.
        ff = $(QuoteNode(f))
        s = """
        %3 = $ff $fpflags $(llvm_type(T)) %0, %1
        ret $(llvm_type(T)) %3
        """
        return :(
            $(Expr(:meta, :inline));
            Base.llvmcall($s, T, Tuple{T, T}, x, y)
        )
    end
end

# Integer binary ops keep the plain two-argument form: only the floating
# point variants above take fast-math flags.
for f in BINARY_OPS_INT
    @eval @generated function $f(x::T, y::T) where T<:LT{<:IntegerTypes}
        # QuoteNode so the op name is spliced in at @eval time as a symbol.
        ff = $(QuoteNode(f))
        s = """
        %3 = $ff $(llvm_type(T)) %0, %1
        ret $(llvm_type(T)) %3
        """
        return :(
            $(Expr(:meta, :inline));
            Base.llvmcall($s, T, Tuple{T, T}, x, y)
        )
    end
end

Expand Down Expand Up @@ -279,24 +323,36 @@ const CMP_FLAGS_INT = [
:ule
]

for (f, c, flags) in zip(["fcmp", "icmp"],
[FloatingTypes, IntegerTypes],
[CMP_FLAGS_FLOAT, CMP_FLAGS_INT])
for flag in flags
ftot = Symbol(string(f, "_", flag))
@eval @generated function $ftot(x::LVec{N, T}, y::LVec{N, T}) where {N, T <: $c}
fflag = $(QuoteNode(flag))
ff = $(QuoteNode(f))
s = """
%res = $ff $(fflag) <$(N) x $(d[T])> %0, %1
%resb = zext <$(N) x i1> %res to <$(N) x i8>
ret <$(N) x i8> %resb
"""
return :(
$(Expr(:meta, :inline));
Base.llvmcall($s, LVec{N, Bool}, Tuple{LVec{N, T}, LVec{N, T}}, x, y)
)
end
# Emit one fcmp_<flag> method (e.g. fcmp_oeq) per float comparison flag.
# Like the float binary ops, the optional `::F=nothing` argument carries
# fast-math flags that get rendered into the fcmp instruction.
for flag in CMP_FLAGS_FLOAT
    ftot = Symbol(string("fcmp_", flag))
    @eval @generated function $ftot(x::LVec{N, T}, y::LVec{N, T}, ::F=nothing) where {N, T <: FloatingTypes, F<:FPFlags}
        fpflags = fp_str(F)
        fflag = $(QuoteNode(flag))
        # fcmp yields <N x i1>; zext to i8 so the result maps to LVec{N, Bool}.
        s = """
        %res = fcmp $(fpflags) $(fflag) <$(N) x $(d[T])> %0, %1
        %resb = zext <$(N) x i1> %res to <$(N) x i8>
        ret <$(N) x i8> %resb
        """
        return :(
            $(Expr(:meta, :inline));
            Base.llvmcall($s, LVec{N, Bool}, Tuple{LVec{N, T}, LVec{N, T}}, x, y)
        )
    end
end

# Emit one icmp_<flag> method (e.g. icmp_slt) per integer comparison
# flag.  No fast-math argument: flags apply only to the fcmp variants.
for flag in CMP_FLAGS_INT
    ftot = Symbol(string("icmp_", flag))
    @eval @generated function $ftot(x::LVec{N, T}, y::LVec{N, T}) where {N, T <: IntegerTypes}
        fflag = $(QuoteNode(flag))
        # icmp yields <N x i1>; zext to i8 so the result maps to LVec{N, Bool}.
        s = """
        %res = icmp $(fflag) <$(N) x $(d[T])> %0, %1
        %resb = zext <$(N) x i1> %res to <$(N) x i8>
        ret <$(N) x i8> %resb
        """
        return :(
            $(Expr(:meta, :inline));
            Base.llvmcall($s, LVec{N, Bool}, Tuple{LVec{N, T}, LVec{N, T}}, x, y)
        )
    end
end

Expand Down
33 changes: 27 additions & 6 deletions src/simdvec.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ Base.reinterpret(::Type{Vec{N, T}}, v::Vec) where {T, N} = Vec(Intrinsics.bitcas
Base.reinterpret(::Type{Vec{N, T}}, v::ScalarTypes) where {T, N} = Vec(Intrinsics.bitcast(Intrinsics.LVec{N, T}, v))
Base.reinterpret(::Type{T}, v::Vec) where {T} = Intrinsics.bitcast(T, v.data)

# Flag set passed to the intrinsics by the Base.FastMath method hooks:
# the aggregate `fast` flag (rendered as "fast" in the generated IR).
const FASTMATH = Intrinsics.FastMathFlags(Intrinsics.FastMath.fast)

###################
# Unary operators #
Expand Down Expand Up @@ -156,6 +157,7 @@ end
# Unary plus is the identity on vectors.
Base.:+(v::Vec{<:Any, <:ScalarTypes}) = v
# Integer negation via subtraction; float negation via the fneg intrinsic.
Base.:-(v::Vec{<:Any, <:IntegerTypes}) = zero(v) - v
Base.:-(v::Vec{<:Any, <:FloatingTypes}) = Vec(Intrinsics.fneg(v.data))
# Fast-math unary minus (used by @fastmath): same fneg with FASTMATH flags.
Base.FastMath.sub_fast(v::Vec{<:Any, <:FloatingTypes}) = Vec(Intrinsics.fneg(v.data, FASTMATH))
# Bitwise not as xor with an all-ones (or all-true) vector.
Base.:~(v::Vec{N, T}) where {N, T<:IntegerTypes} = Vec(Intrinsics.xor(v.data, Vec{N, T}(-1).data))
Base.:~(v::Vec{N, Bool}) where {N} = Vec(Intrinsics.xor(v.data, Vec{N, Bool}(true).data))
# Elementwise absolute value via a lane-wise select.
Base.abs(v::Vec{N, T}) where {N, T} = Vec(vifelse(v < zero(T), -v, v))
Expand Down Expand Up @@ -238,14 +240,28 @@ const BINARY_OPS = [
(:(Base.:<=) , FloatingTypes , Intrinsics.fcmp_ole)
]

# Map an operator expression such as `:(Base.:+)` to its Base.FastMath
# counterpart (e.g. `:(Base.FastMath.add_fast)`).  Returns `nothing` when
# `op` is not a `Base.<op>` expression or has no fast variant.
function get_fastmath_function(op)
    if op isa Expr && op.head === Symbol(".") && op.args[1] === :Base
        opname = op.args[2].value
        if haskey(Base.FastMath.fast_op, opname)
            return :(Base.FastMath.$(Base.FastMath.fast_op[opname]))
        end
    end
    return nothing
end

# Wire each (operator, element-type constraint, intrinsic) triple from
# BINARY_OPS into a Vec method forwarding to the LLVM intrinsic.
for (op, constraint, llvmop) in BINARY_OPS
    @eval @inline function $op(x::Vec{N, T}, y::Vec{N, T}) where {N, T <: $constraint}
        Vec($(llvmop)(x.data, y.data))
    end

    # Add a fast math version if applicable
    # NOTE(review): this also fires for integer-constrained rows whose op
    # has a fast counterpart, passing FASTMATH to intrinsics that appear
    # to take only two arguments — verify those rows cannot be reached or
    # that the integer intrinsics accept a flags argument.
    if (fast_op = get_fastmath_function(op)) !== nothing
        @eval @inline function $(fast_op)(x::Vec{N, T}, y::Vec{N, T}) where {N, T <: $constraint}
            Vec($(llvmop)(x.data, y.data, FASTMATH))
        end
    end
end

# overflow

const OVERFLOW_INTRINSICS = [
(:(Base.Checked.add_with_overflow) , IntTypes , Intrinsics.sadd_with_overflow)
(:(Base.Checked.add_with_overflow) , UIntTypes , Intrinsics.uadd_with_overflow)
Expand All @@ -261,7 +277,6 @@ for (op, constraint, llvmop) in OVERFLOW_INTRINSICS
end
end


# max min
@inline Base.max(v1::Vec{N,T}, v2::Vec{N,T}) where {N,T<:IntegerTypes} =
Vec(vifelse(v1 >= v2, v1, v2))
Expand Down Expand Up @@ -372,11 +387,17 @@ for (op, constraint) in [BINARY_OPS;
(:(Base.Checked.mul_with_overflow) , IntTypes)
(:(Base.Checked.mul_with_overflow) , UIntTypes)
]
@eval @inline function $op(x::T2, y::Vec{N, T}) where {N, T2<:ScalarTypes, T <: $constraint}
$op(Vec{N, T}(x), y)
ops = [op]
if (fast_op = get_fastmath_function(op)) !== nothing
push!(ops, fast_op)
end
@eval @inline function $op(x::Vec{N, T}, y::T2) where {N, T2 <:ScalarTypes, T <: $constraint}
$op(x, Vec{N, T}(y))
for op in ops
@eval @inline function $op(x::T2, y::Vec{N, T}) where {N, T2<:ScalarTypes, T <: $constraint}
$op(Vec{N, T}(x), y)
end
@eval @inline function $op(x::Vec{N, T}, y::T2) where {N, T2 <:ScalarTypes, T <: $constraint}
$op(x, Vec{N, T}(y))
end
end
end

Expand Down
8 changes: 8 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,14 @@ llvm_ir(f, args) = sprint(code_llvm, f, Base.typesof(args...))
end
end

# The @fastmath-rewritten vector ops should agree (≈) with the plain
# versions, for vector-vector, vector-scalar, scalar-vector, and unary minus.
@testset "fastmath" begin
    v = Vec(1.0,2.0,3.0,4.0)
    @test all(Tuple(@fastmath v+v) .≈ Tuple(v+v))
    @test all(Tuple(@fastmath v+1.0) .≈ Tuple(v+1.0))
    @test all(Tuple(@fastmath 1.0+v) .≈ Tuple(1.0+v))
    @test all(Tuple(@fastmath -v) .≈ Tuple(-v))
end

@testset "Gather and scatter function" begin
for (arr, VT) in [(arri32, V8I32), (arrf64, V4F64)]
arr .= 1:length(arr)
Expand Down

0 comments on commit bdfd585

Please sign in to comment.