Skip to content

Commit

Permalink
Use StaticArrays' reduce and mapreduce for iszero and count(f, a)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsshin committed Jul 31, 2017
1 parent da1a371 commit 8bd82c8
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/StaticArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import Base: getindex, setindex!, size, similar, vec, show,
hcat, vcat, ones, zeros, eye, one, cross, vecdot, reshape, fill,
fill!, det, logdet, inv, eig, eigvals, expm, logm, sqrtm, lyap, trace, diag, vecnorm, norm, dot, diagm, diag,
lu, svd, svdvals, svdfact, factorize, ishermitian, issymmetric, isposdef,
sum, diff, prod, count, any, all, minimum,
iszero, sum, diff, prod, count, any, all, minimum,
maximum, extrema, mean, copy, rand, randn, randexp, rand!, randn!,
randexp!, normalize, normalize!, read, read!, write

Expand Down
2 changes: 2 additions & 0 deletions src/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,12 @@ end
#######################

# These are all similar in Base but not @inline'd
@inline iszero(a::StaticArray{<:Any, T}) where {T} = reduce((x,y) -> x && (y==zero(T)), true, a)
@inline sum(a::StaticArray{<:Any, T}) where {T} = reduce(+, zero(T), a)
@inline sum(f::Base.Callable, a::StaticArray) = mapreduce(f, +, a)
@inline prod(a::StaticArray{<:Any, T}) where {T} = reduce(*, one(T), a)
@inline count(a::StaticArray{<:Any, Bool}) = reduce(+, 0, a)
@inline count(f::Base.Callable, a::StaticArray) = mapreduce(x->f(x)::Bool, +, 0, a)
@inline all(a::StaticArray{<:Any, Bool}) = reduce(&, true, a) # non-branching versions
@inline any(a::StaticArray{<:Any, Bool}) = reduce(|, false, a) # (benchmarking needed)
@inline mean(a::StaticArray) = sum(a) / length(a)
Expand Down
6 changes: 5 additions & 1 deletion test/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,20 @@

@testset "reduce" begin
v1 = @SVector [2,4,6,8]
vz = @SVector zeros(4)
vb = @SVector [true, false, true, false]
@test reduce(+, v1) === 20
@test reduce(+, 0, v1) === 20
@test iszero(vz)
@test !iszero(vb)
@test sum(v1) === 20
@test sum(abs2, v1) === 120
@test prod(v1) === 384
@test mean(v1) === 5.
@test maximum(v1) === 8
@test minimum(v1) === 2
vb = @SVector [true, false, true, false]
@test count(vb) === 2
@test count(x->x==0, vz) === 4
@test any(vb)
end

Expand Down

0 comments on commit 8bd82c8

Please sign in to comment.