Skip to content

Commit

Permalink
Extend min/maximum optimization to much shorter length
Browse files Browse the repository at this point in the history
Update reduce.jl
  • Loading branch information
N5N3 committed Dec 31, 2021
1 parent 07d1f81 commit 19ce317
Showing 1 changed file with 26 additions and 34 deletions.
60 changes: 26 additions & 34 deletions base/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -625,49 +625,41 @@ isbadzero(op, x) = false
# Zero-sign helpers for `max`/`min`: a value is the "good" zero for one operator
# exactly when `isbadzero` reports it as the "bad" zero for the other operator.
function isgoodzero(::typeof(max), x)
    return isbadzero(min, x)
end

function isgoodzero(::typeof(min), x)
    return isbadzero(max, x)
end

function mapreduce_impl(f, op::Union{typeof(max), typeof(min)},
A::AbstractArrayOrBroadcasted, first::Int, last::Int)
# 1. This optimization gives a different result from the general fallback if the inputs `f.(A)`
# contain both `missing` and `NaN`.
# 2. For Integer cases, the general fallback seems faster.
# For the above reasons, only use this for AbstractFloat cases.
Eltype = _return_type(i -> f(A[i]), Tuple{Int})
function mapreduce_impl(f, op::Union{typeof(max),typeof(min)},
A::AbstractArrayOrBroadcasted, fi::Int, la::Int)
@inline elf(i) = @inbounds f(A[i])
# 1. If `f.(A)` contains both `missing` and `NaN`, this might return `NaN`.
# 2. For Integer input, the general fallback is about 2x faster.
# Thus, limit this optimization to AbstractFloat.
Eltype = _return_type(elf, Tuple{Int})
Eltype <: AbstractFloat ||
return invoke(mapreduce_impl,Tuple{Any,Any,AbstractArrayOrBroadcasted,Int,Int},f,op,A,first,last)
a1 = @inbounds A[first]
v1 = mapreduce_first(f, op, a1)
v2 = v3 = v4 = v1
chunk_len = 256
start = first + 1
simdstop = start + chunk_len - 4
while simdstop <= last - 3
# short circuit in case of NaN or missing
v1 == v1 || return v1
v2 == v2 || return v2
v3 == v3 || return v3
v4 == v4 || return v4
@inbounds for i in start:4:simdstop
v1 = _fast(op, v1, f(A[i+0]))
v2 = _fast(op, v2, f(A[i+1]))
v3 = _fast(op, v3, f(A[i+2]))
v4 = _fast(op, v4, f(A[i+3]))
return invoke(mapreduce_impl,Tuple{Any,Any,AbstractArrayOrBroadcasted,Int,Int},f,op,A,fi,la)
v1 = v2 = v3 = v4 = elf(fi)
len = (la - fi) >> 2
i = fi
for I in Iterators.partition(1:len, 64)
for _ in I
v1 = _fast(op, v1, elf(i+=1))
v2 = _fast(op, v2, elf(i+=1))
v3 = _fast(op, v3, elf(i+=1))
v4 = _fast(op, v4, elf(i+=1))
end
checkbounds(A, simdstop+3)
start += chunk_len
simdstop += chunk_len
# short circuit in case of NaN
isnan(v1) && return v1
isnan(v2) && return v2
isnan(v3) && return v3
isnan(v4) && return v4
end
v = op(op(v1,v2),op(v3,v4))
for i in start:last
@inbounds ai = A[i]
v = op(v, f(ai))
for i in i+1:la
v = op(v, elf(i))
end

# enforce correct order of 0.0 and -0.0
# e.g. maximum([0.0, -0.0]) === 0.0
# should hold
if isbadzero(op, v)
for i in first:last
x = @inbounds A[i]
for i in fi:la
x = elf(i)
isgoodzero(op,x) && return x
end
end
Expand Down

0 comments on commit 19ce317

Please sign in to comment.