From 7f364f4dcc329e34314199372cbfb9b5d55a0fcf Mon Sep 17 00:00:00 2001 From: Milan Bouchet-Valat Date: Wed, 8 May 2019 16:44:09 +0200 Subject: [PATCH] Cleanup --- base/reducedim.jl | 13 ++++++++++--- stdlib/Statistics/src/Statistics.jl | 4 ++-- test/reduce.jl | 4 ++++ test/reducedim.jl | 3 +++ 4 files changed, 19 insertions(+), 5 deletions(-) diff --git a/base/reducedim.jl b/base/reducedim.jl index 97b70b6d6fcd4..27c5e05e40c46 100644 --- a/base/reducedim.jl +++ b/base/reducedim.jl @@ -724,9 +724,12 @@ sum!(f::Function, r::AbstractArray, A::AbstractArray; _sum!(f, r, A, weights; init=init) _sum!(f, r::AbstractArray, A::AbstractArray, ::Nothing; init::Bool=true) = mapreducedim!(f, add_sum, initarray!(r, add_sum, init, A), A) -_sum(f, A, dims, ::Nothing) = mapreduce(f, add_sum, A, dims=dims) -_sum(A::AbstractArray, dims, w::AbstractArray) = +_sum(A::AbstractArray, dims, weights) = _sum(identity, A, dims, weights) +_sum(f, A::AbstractArray, dims, ::Nothing) = mapreduce(f, add_sum, A, dims=dims) +_sum(::typeof(identity), A::AbstractArray, dims, w::AbstractArray) = _sum!(identity, reducedim_init(t -> t*zero(eltype(w)), add_sum, A, dims), A, w) +_sum(f, A::AbstractArray, dims, w::AbstractArray) = + throw(ArgumentError("Passing a function is not supported with `weights`")) # Weighted sum @@ -832,8 +835,11 @@ _wsum!(R::AbstractArray, A::AbstractArray, w::AbstractVector, dim::Int, init::Bool) = _wsum_general!(R, A, w, dim, init) -function _sum!(::typeof(identity), R::AbstractArray, A::AbstractArray{T,N}, w::AbstractVector; +function _sum!(f, R::AbstractArray, A::AbstractArray{T,N}, w::AbstractArray; init::Bool=true) where {T,N} + f === identity || throw(ArgumentError("Passing a function is not supported with `weights`")) + w isa AbstractVector || throw(ArgumentError("Only vector `weights` are supported")) + check_reducedims(R,A) reddims = size(R) .!= size(A) dim = something(findfirst(reddims), ndims(R)+1) @@ -849,6 +855,7 @@ function _sum!(::typeof(identity), R::AbstractArray, A::AbstractArray{T,N}, w::A _wsum!(R, A, w, dim, init) end + ##### findmin & findmax ##### # The initial values of Rval are not used if the corresponding indices in Rind are 0. # diff --git a/stdlib/Statistics/src/Statistics.jl b/stdlib/Statistics/src/Statistics.jl index 044f5aba67dc2..f3756e1b0a01d 100644 --- a/stdlib/Statistics/src/Statistics.jl +++ b/stdlib/Statistics/src/Statistics.jl @@ -195,14 +195,14 @@ function _mean(r::AbstractRange{<:Real}, dims::Colon, weights::Nothing) end _mean(A::AbstractArray, dims, weights::Nothing) = - _mean!(Base.reducedim_init(t -> t/2, +, A, dims), A, nothing) + _mean!(Base.reducedim_init(t -> t/2, Base.add_sum, A, dims), A, nothing) _mean(A::AbstractArray, dims::Colon, weights::Nothing) = sum(A) / length(A) _mean(A::AbstractArray, dims::Colon, w::AbstractArray) = sum(A, weights=w) / sum(w) _mean(A::AbstractArray, dims, w::AbstractArray) = - _mean!(Base.reducedim_init(t -> (t*zero(eltype(w)))/2, +, A, dims), A, w) + _mean!(Base.reducedim_init(t -> (t*zero(eltype(w)))/2, Base.add_sum, A, dims), A, w) ##### variances ##### diff --git a/test/reduce.jl b/test/reduce.jl index d3ba20c4091ef..9b89e8f6741b6 100644 --- a/test/reduce.jl +++ b/test/reduce.jl @@ -563,4 +563,8 @@ x = [j+7 for j in i] @test typeof(res) == typeof(expected) end end + + @test_throws ArgumentError sum(exp, [1], weights=[1]) + @test_throws ArgumentError sum!(exp, [0 0], [1 2], weights=[1, 10]) + @test_throws ArgumentError sum!([0 0], [1 2], weights=[1 10]) end \ No newline at end of file diff --git a/test/reducedim.jl b/test/reducedim.jl index 38d8289051963..a90c2c1f96afb 100644 --- a/test/reducedim.jl +++ b/test/reducedim.jl @@ -481,4 +481,7 @@ end @test_throws DimensionMismatch sum(a, weights=w, dims=4) end end + + @test_throws ArgumentError sum(exp, [1 2], weights=[1, 10], dims=1) + @test_throws ArgumentError sum([1 2], weights=[1 10], dims=1) end \ No newline at end of file