Skip to content

Commit

Permalink
Allow dropdims with reduction to take mutliple args and kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
nickrobinson251 committed Aug 31, 2019
1 parent 012d910 commit 4d84d00
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 13 deletions.
26 changes: 23 additions & 3 deletions base/abstractarraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,31 @@ end
_dropdims(A::AbstractArray, dim::Integer) = _dropdims(A, (Int(dim),))

"""
squeeze(f, A, dims)
dropdims(f, args...; dims, kwargs...)
Compute reduction `f` over dimensions `dims` in array `A` and drop those dimensions from the result.
Compute reduction `f` over dimensions `dims` and drop those dimensions from the result.
# Examples
```jldoctest
julia> a = [3.0 2.0 6.0 8.0
6.0 1.0 4.0 2.0
3.0 0.0 7.0 6.0];
julia> dropdims(sum, a, dims=1)
4-element Array{Float64,1}:
12.0
3.0
17.0
16.0
julia> dropdims(sum, abs2, a, dims=2)
3-element Array{Float64,1}:
113.0
57.0
94.0
```
"""
squeeze(f, A::AbstractArray, dims::Union{Dims, Integer}) = squeeze(f(A, dims), dims)
dropdims(f, args...; dims, kwargs...) = dropdims(f(args...; kwargs..., dims=dims), dims=dims)

## Unary operators ##

Expand Down
40 changes: 30 additions & 10 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -303,16 +303,36 @@ end
@test_throws ArgumentError dropdims(a, dims=4)
@test_throws ArgumentError dropdims(a, dims=6)

@test @inferred(squeeze(sum, a, 1)) == @inferred(squeeze(sum, a, (1,))) == reshape(sum(a, 1), (1, 8, 8, 1))
@test @inferred(squeeze(sum, a, 3)) == @inferred(squeeze(sum, a, (3,))) == reshape(sum(a, 3), (1, 1, 8, 1))
@test @inferred(squeeze(sum, a, 4)) == @inferred(squeeze(sum, a, (4,))) == reshape(sum(a, 4), (1, 1, 8, 1))
@test @inferred(squeeze(sum, a, (1, 5))) == squeeze(sum, a, (5, 1)) == reshape(sum(a, (5, 1)), (1, 8, 8))
@test @inferred(squeeze(sum, a, (1, 2, 5))) == squeeze(sum, a, (5, 2, 1)) == reshape(sum(a, (5, 2, 1)), (8, 8))
@test_throws ArgumentError squeeze(sum, a, 0)
@test_throws ArgumentError squeeze(sum, a, (1, 1))
@test_throws ArgumentError squeeze(sum, a, (1, 2, 1))
@test_throws ArgumentError squeeze(sum, a, (1, 1, 2))
@test_throws ArgumentError squeeze(sum, a, 6)
# dropdims with reductions. issue #16606
@test (@inferred(dropdims(sum, a, dims=1)) ==
@inferred(dropdims(sum, a, dims=(1,))) ==
reshape(sum(a, dims=1), (1, 8, 8, 1)))
@test (@inferred(dropdims(sum, a, dims=3)) ==
@inferred(dropdims(sum, a, dims=(3,))) ==
reshape(sum(a, dims=3), (1, 1, 8, 1)))
@test (@inferred(dropdims(sum, a, dims=4)) ==
@inferred(dropdims(sum, a, dims=(4,))) ==
reshape(sum(a, dims=4), (1, 1, 8, 1)))
@test (@inferred(dropdims(sum, a, dims=(1, 5))) ==
dropdims(sum, a, dims=(5, 1)) ==
reshape(sum(a, (5, 1)), (1, 8, 8)))
@test (@inferred(dropdims(sum, a, dims=(1, 2, 5))) ==
dropdims(sum, a, dims=(5, 2, 1)) ==
reshape(sum(a, dims=(5, 2, 1)), (8, 8)))
@test (@inferred(dropdims(sum, abs2, a, dims=1)) ==
@inferred(dropdims(sum, abs2, a, dims=(1,))) ==
reshape(sum(a, dims=1), (1, 8, 8, 1)))
_sumplus(x; dims, plus) = sum(x; dims=dims) .+ plus # reduction with keywords
@test (@inferred(dropdims(_sumplus, a, dims=4, plus=1)) ==
@inferred(dropdims(_sumplus, a, dims=(4,), plus=1)) ==
reshape(sum(a, dims=4) .+ 1, (1, 1, 8, 1)))
@test_throws UndefKeywordError dropdims(sum, a)
@test_throws UndefKeywordError dropdims(sum, a, 1)
@test_throws ArgumentError dropdims(sum, a, dims=0)
@test_throws ArgumentError dropdims(sum, a, dims=(1, 1))
@test_throws ArgumentError dropdims(sum, a, dims=(1, 2, 1))
@test_throws ArgumentError dropdims(sum, a, dims=(1, 1, 2))
@test_throws ArgumentError dropdims(sum, a, dims=6)

sz = (5,8,7)
A = reshape(1:prod(sz),sz...)
Expand Down

0 comments on commit 4d84d00

Please sign in to comment.