Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dropdims(f, args..; dims, kwargs..) for reductions to drop dims #33130

Closed
wants to merge 10 commits into from
4 changes: 2 additions & 2 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@ New library functions
* The `tempname` function now takes a `cleanup::Bool` keyword argument defaulting to `true`, which causes the process to try to ensure that any file or directory at the path returned by `tempname` is deleted upon process exit ([#33090]).
* The `readdir` function now takes a `join::Bool` keyword argument defaulting to `false`, which when set causes `readdir` to join its directory argument with each listed name ([#33113]).


Standard library changes
------------------------

* The methods of `mktemp` and `mktempdir` which take a function body to pass temporary paths to no longer throw errors if the path is already deleted when the function body returns ([#33091]).

#### Libdl
* A new `dropdims(f, args...; dims, kwargs...)` method computes the reduction `f` over the region described by `dims` and then drops those dimensions from the result ([#33130]).
nickrobinson251 marked this conversation as resolved.
Show resolved Hide resolved

#### Libdl

#### LinearAlgebra

Expand Down
31 changes: 31 additions & 0 deletions base/abstractarraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,39 @@ function _dropdims(A::AbstractArray, dims::Dims)
end
reshape(A, d::typeof(_sub(axes(A), dims)))
end

_dropdims(A::AbstractArray, dim::Integer) = _dropdims(A, (Int(dim),))

"""
dropdims(f, args...; dims, kwargs...)

Compute reduction `f` over dimensions `dims` and drop those dimensions from the result.

nickrobinson251 marked this conversation as resolved.
Show resolved Hide resolved
!!! compat "Julia 1.4"
This method requires at least Julia 1.4.

# Examples
```jldoctest
julia> a = [3.0 2.0 6.0 8.0
6.0 1.0 4.0 2.0
3.0 0.0 7.0 6.0];

julia> dropdims(sum, a, dims=1)
4-element Array{Float64,1}:
12.0
3.0
17.0
16.0

julia> dropdims(sum, abs2, a, dims=2)
3-element Array{Float64,1}:
113.0
57.0
94.0
```
nickrobinson251 marked this conversation as resolved.
Show resolved Hide resolved
"""
dropdims(f, args...; dims, kwargs...) = _dropdims(f(args...; kwargs..., dims=dims), dims)

## Unary operators ##

conj(x::AbstractArray{<:Real}) = x
Expand Down
31 changes: 31 additions & 0 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,37 @@ end
@test_throws ArgumentError dropdims(a, dims=4)
@test_throws ArgumentError dropdims(a, dims=6)

# dropdims with reductions. issue #16606
@test (@inferred(dropdims(sum, a, dims=1)) ==
@inferred(dropdims(sum, a, dims=(1,))) ==
reshape(sum(a, dims=1), (1, 8, 8, 1)))
@test (@inferred(dropdims(sum, a, dims=3)) ==
@inferred(dropdims(sum, a, dims=(3,))) ==
reshape(sum(a, dims=3), (1, 1, 8, 1)))
@test (@inferred(dropdims(sum, a, dims=4)) ==
@inferred(dropdims(sum, a, dims=(4,))) ==
reshape(sum(a, dims=4), (1, 1, 8, 1)))
@test (@inferred(dropdims(sum, a, dims=(1, 5))) ==
dropdims(sum, a, dims=(5, 1)) ==
reshape(sum(a, dims=(5, 1)), (1, 8, 8)))
@test (@inferred(dropdims(sum, a, dims=(1, 2, 5))) ==
dropdims(sum, a, dims=(5, 2, 1)) ==
reshape(sum(a, dims=(5, 2, 1)), (8, 8)))
@test (@inferred(dropdims(sum, abs2, a, dims=1)) ==
@inferred(dropdims(sum, abs2, a, dims=(1,))) ==
reshape(sum(abs2, a, dims=1), (1, 8, 8, 1)))
_sumplus(x; dims, plus) = sum(x; dims=dims) .+ plus # reduction with keywords
@test (@inferred(dropdims(_sumplus, a, dims=4, plus=1)) ==
@inferred(dropdims(_sumplus, a, dims=(4,), plus=1)) ==
reshape(sum(a, dims=4) .+ 1, (1, 1, 8, 1)))
@test_throws UndefKeywordError dropdims(sum, a)
@test_throws UndefKeywordError dropdims(sum, a, 1)
@test_throws ArgumentError dropdims(sum, a, dims=0)
@test_throws ArgumentError dropdims(sum, a, dims=(1, 1))
@test_throws ArgumentError dropdims(sum, a, dims=(1, 2, 1))
@test_throws ArgumentError dropdims(sum, a, dims=(1, 1, 2))
@test_throws ArgumentError dropdims(sum, a, dims=6)

sz = (5,8,7)
A = reshape(1:prod(sz),sz...)
@test A[2:6] == [2:6;]
Expand Down