Skip to content

Commit

Permalink
findextrema: compute findmin and findmax in single pass
Browse files Browse the repository at this point in the history
This is a simple extension of extant `findmin` and `findmax` methods.
Depending on context (cost of `f`; whether reduction is over dims;
size of array) the speedup increase is somewhere between 1.0-1.6 (no regressions).
Interestingly, I noticed but could not locate a `findextrema`; there
is some [mention](JuliaLang#7327) of it,
but nothing in Base. If it was deemed unworthy, please excuse this
errant PR.
  • Loading branch information
andrewjradcliffe committed Nov 16, 2022
1 parent 5f256e7 commit b0dc058
Show file tree
Hide file tree
Showing 6 changed files with 294 additions and 0 deletions.
2 changes: 2 additions & 0 deletions base/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,8 @@ export
findmin,
findmin!,
findmax!,
findextrema,
findextrema!
findnext,
findprev,
match,
Expand Down
51 changes: 51 additions & 0 deletions base/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -988,6 +988,57 @@ julia> findmin([1, 7, 7, NaN])
findmin(itr) = _findmin(itr, :)
_findmin(a, ::Colon) = findmin(identity, a)

"""
findextrema(f, domain) -> ((f(x), index_mn), (f(x), index_mx))
Return the pair of pairs which would be returned by `(findmin(f, domain), findmax(f, domain))`,
but computed in a single pass.
# Examples
```jldoctest
julia> findextrema(identity, 5:9)
((5, 1), (9, 5))
julia> findextrema(-, 1:10)
((-10, 10), (-1, 1))
julia> findextrema(first, [(2, :a), (2, :b), (3, :c)])
((2, 1), (3, 3))
julia> findextrema(cos, 0:π/2:2π)
((-1.0, 3), (1.0, 1))
```
"""
findextrema(f, domain) = _findextrema(f, domain, :)
_findextrema(f, domain, ::Colon) = mapfoldl(((k, v),) -> ((f(v), k), (f(v), k)), _rf_findextrema, pairs(domain))
_rf_findextrema((((fm₁, im₁), (fx₁, ix₁))), (((fm₂, im₂), (fx₂, ix₂)))) =
(isgreater(fm₁, fm₂) ? (fm₂, im₂) : (fm₁, im₁)), (isless(fx₁, fx₂) ? (fx₂, ix₂) : (fx₁, ix₁))

"""
findextrema(itr) -> ((mn, index_mn), (mx, index_mx))
Return the pair of pairs which would be returned by `(findmin(itr), findmax(itr))`,
but computed in a single pass.
See also: [`findmin`](@ref), [`findmax`](@ref)
# Examples
```jldoctest
julia> findextrema([8, 0.1, -9, pi])
((-9.0, 3), (8.0, 1))
julia> findextrema([1, 7, 7, 6])
((1, 1), (7, 2))
julia> findextrema([1, 7, 7, NaN])
((NaN, 4), (NaN, 4))
```
"""
findextrema(itr) = _findextrema(itr, :)
_findextrema(a, ::Colon) = findextrema(identity, a)

"""
argmax(f, domain)
Expand Down
147 changes: 147 additions & 0 deletions base/reducedim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1231,6 +1231,153 @@ end

reducedim1(R, A) = length(axes1(R)) == 1


import Base: promote_op, reduced_indices0, reduced_indices, _findminmax_inittype, check_reducedims, safe_tail, reducedim1, _realtype, axes1, promote_union, isgreater

function findextrema!(f, op_mn, op_mx, Rval_mn, Rind_mn, Rval_mx, Rind_mx, A::AbstractArray{T,N}) where {T,N}
(isempty(Rval_mn) || isempty(Rval_mx) || isempty(A)) && return ((Rval_mn, Rind_mn), (Rval_mx, Rind_mx))
lsiz_mn = check_reducedims(Rval_mn, A)
lsiz_mx = check_reducedims(Rval_mx, A)
for i = 1:N
axes(Rval_mn, i) == axes(Rind_mn, i) == axes(Rval_mx, i) == axes(Rind_mx, i) || throw(DimensionMismatch("Find-reduction: outputs must have the same indices"))
end
# Same as findminmax! implementation
indsAt, indsRt = safe_tail(axes(A)), safe_tail(axes(Rval_mn))
keep, Idefault = Broadcast.shapeindexer(indsRt)
ks = keys(A)
y = iterate(ks)
zi = zero(eltype(ks))
if reducedim1(Rval_mn, A)
i1 = first(axes1(Rval_mn))
@inbounds for IA in CartesianIndices(indsAt)
IR = Broadcast.newindex(IA, keep, Idefault)
tmpRv_mn = Rval_mn[i1,IR]
tmpRi_mn = Rind_mn[i1,IR]
tmpRv_mx = Rval_mx[i1,IR]
tmpRi_mx = Rind_mx[i1,IR]
for i in axes(A,1)
k, kss = y::Tuple
tmpAv = f(A[i,IA])
if tmpRi_mn == zi || op_mn(tmpRv_mn, tmpAv)
tmpRv_mn = tmpAv
tmpRi_mn = k
end
if tmpRi_mx == zi || op_mx(tmpRv_mx, tmpAv)
tmpRv_mx = tmpAv
tmpRi_mx = k
end
y = iterate(ks, kss)
end
Rval_mn[i1,IR] = tmpRv_mn
Rind_mn[i1,IR] = tmpRi_mn
Rval_mx[i1,IR] = tmpRv_mx
Rind_mx[i1,IR] = tmpRi_mx
end
else
@inbounds for IA in CartesianIndices(indsAt)
IR = Broadcast.newindex(IA, keep, Idefault)
for i in axes(A, 1)
k, kss = y::Tuple
tmpAv = f(A[i,IA])
tmpRv_mn = Rval_mn[i,IR]
tmpRi_mn = Rind_mn[i,IR]
tmpRv_mx = Rval_mx[i,IR]
tmpRi_mx = Rind_mx[i,IR]
if tmpRi_mn == zi || op_mn(tmpRv_mn, tmpAv)
Rval_mn[i,IR] = tmpAv
Rind_mn[i,IR] = k
end
if tmpRi_mx == zi || op_mx(tmpRv_mx, tmpAv)
Rval_mx[i,IR] = tmpAv
Rind_mx[i,IR] = k
end
y = iterate(ks, kss)
end
end
end
((Rval_mn, Rind_mn), (Rval_mx, Rind_mx))
end

"""
findextrema!(rval_mn, rind_mn, rval_mx, rind_mx, A) -> ((minval, index), (maxval, index))
Find the minimum and maximum of `A` and the respective linear index along singleton
dimensions, storing the results in `((rval_mn , rind_mn), (rval_mn , rind_mn))`,
equivalent to `(findmin!(rval_mn, rind_mn, A), findmax!(rval_mx, rind_mx, A))`
but computed in a single pass.
"""
function findextrema!(rval_mn::AbstractArray, rind_mn::AbstractArray, rval_mx::AbstractArray, rind_mx::AbstractArray, A::AbstractArray; init::Bool=true)
init && !isempty(A) && (fill!(rval_mn, first(A)); fill!(rval_mx, first(A)))
Ti = eltype(keys(A))
findextrema!(identity, isgreater, isless, rval_mn, fill!(rind_mn,zero(Ti)), rval_mx, fill!(rind_mx,zero(Ti)), A)
end

"""
findextrema(A; dims) -> ((minval, index), (maxval, index))
For an array input, returns the value and index of the minimum and maximum over the
given dimensions. Equivalent to `(findmin(A; dims), findmax(A; dims))`, but computed
in a single pass.
`NaN` is treated as greater than all other values except `missing`.
# Examples
```jldoctest
julia> A = [1.0 2; 3 4]
2×2 Matrix{Float64}:
1.0 2.0
3.0 4.0
julia> findextrema(A, dims=1) == (findmin(A, dims=1), findmax(A, dims=1))
true
julia> findextrema(A, dims=2) == (findmin(A, dims=2), findmax(A, dims=2))
true
```
"""
findextrema(A::AbstractArray; dims=:) = _findextrema(A, dims)
_findextrema(A, dims) = _findextrema(identity, A, dims)

"""
findextrema(f, A; dims) -> ((f(x), index_mn), (f(x), index_mx))
For an array input, returns the value in the codomain and index of the corresponding value
which minimize and maximize `f` over the given dimensions. Equivalent to
`(findmin(f, A; dims), findmax(f, A; dims))`, but computed in a single pass.
# Examples
```jldoctest
julia> A = [-1.0 1; -0.5 2]
2×2 Matrix{Float64}:
-1.0 1.0
-0.5 2.0
julia> findextrema(abs2, A, dims=1) == (findmin(abs2, A, dims=1), findmax(abs2, A, dims=1))
true
julia> findextrema(abs2, A, dims=2) == (findmin(abs2, A, dims=2), findmax(abs2, A, dims=2))
true
```
"""
findextrema(f, A::AbstractArray; dims=:) = _findextrema(f, A, dims)

function _findextrema(f, A, region)
ri = reduced_indices0(A, region)
if isempty(A)
if prod(map(length, reduced_indices(A, region))) != 0
throw(ArgumentError("collection slices must be non-empty"))
end
Tr = promote_op(f, eltype(A))
Ti = eltype(keys(A))
(similar(A, Tr, ri), zeros(Ti, ri)), (similar(A, Tr, ri), zeros(Ti, ri))
else
fA = f(first(A))
Tr = _findminmax_inittype(f, A)
Ti = eltype(keys(A))
findextrema!(f, isgreater, isless, fill!(similar(A, Tr, ri), fA), zeros(Ti, ri),
fill!(similar(A, Tr, ri), fA), zeros(Ti, ri), A)
end
end

"""
argmin(A; dims) -> indices
Expand Down
2 changes: 2 additions & 0 deletions doc/src/base/collections.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,10 @@ Base.argmax
Base.argmin
Base.findmax
Base.findmin
Base.findextrema
Base.findmax!
Base.findmin!
Base.findextrema!
Base.sum
Base.sum!
Base.prod
Expand Down
11 changes: 11 additions & 0 deletions test/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,17 @@ end
@test argmax(sum, Iterators.product(1:5, 1:5)) == (5, 5)
end

# findextrema
@testset "findextrema(f, domain)" begin
@test findextrema(-, 1:10) == ((-10, 10), (-1, 1))
@test findextrema(identity, [1, 2, 3, missing]) === ((missing, 4), (missing, 4))
@test findextrema(identity, [1, NaN, 3, missing]) === ((missing, 4), (missing, 4))
@test findextrema(identity, [1, missing, NaN, 3]) === ((missing, 2), (missing, 2))
@test findextrema(identity, [1, NaN, 3]) === ((NaN, 2), (NaN, 2))
@test findextrema(identity, [1, 3, NaN]) === ((NaN, 3), (NaN, 3))
@test findextrema(cos, 0:π/2:2π) == ((-1.0, 3), (1.0, 1))
end

# any & all

@test @inferred any([]) == false
Expand Down
Loading

0 comments on commit b0dc058

Please sign in to comment.