From d9cbea92934bbcb08f15b365a0906950a180135a Mon Sep 17 00:00:00 2001 From: Matt Bauman Date: Tue, 29 Aug 2017 15:11:16 -0400 Subject: [PATCH 1/3] Add squeeze(f, A, dims) for reductions to drop dims This simple definition makes it easier to write reductions that drops the dimensions over which they reduce. Fixes #16606, addresses part of the root issue in #22000. --- base/abstractarraymath.jl | 8 ++++++++ test/arrayops.jl | 11 +++++++++++ 2 files changed, 19 insertions(+) diff --git a/base/abstractarraymath.jl b/base/abstractarraymath.jl index 9c7b098ff0d42..b7acf141bc4e3 100644 --- a/base/abstractarraymath.jl +++ b/base/abstractarraymath.jl @@ -84,8 +84,16 @@ function _dropdims(A::AbstractArray, dims::Dims) end reshape(A, d::typeof(_sub(axes(A), dims))) end + _dropdims(A::AbstractArray, dim::Integer) = _dropdims(A, (Int(dim),)) +""" + squeeze(f, A, dims) + +Compute reduction `f` over dimensions `dims` in array `A` and drop those dimensions from the result. +""" +squeeze(f, A::AbstractArray, dims::Union{Dims, Integer}) = squeeze(f(A, dims), dims) + ## Unary operators ## conj(x::AbstractArray{<:Real}) = x diff --git a/test/arrayops.jl b/test/arrayops.jl index 7a2fa864f543c..65e31028b644e 100644 --- a/test/arrayops.jl +++ b/test/arrayops.jl @@ -303,6 +303,17 @@ end @test_throws ArgumentError dropdims(a, dims=4) @test_throws ArgumentError dropdims(a, dims=6) + @test @inferred(squeeze(sum, a, 1)) == @inferred(squeeze(sum, a, (1,))) == reshape(sum(a, 1), (1, 8, 8, 1)) + @test @inferred(squeeze(sum, a, 3)) == @inferred(squeeze(sum, a, (3,))) == reshape(sum(a, 3), (1, 1, 8, 1)) + @test @inferred(squeeze(sum, a, 4)) == @inferred(squeeze(sum, a, (4,))) == reshape(sum(a, 4), (1, 1, 8, 1)) + @test @inferred(squeeze(sum, a, (1, 5))) == squeeze(sum, a, (5, 1)) == reshape(sum(a, (5, 1)), (1, 8, 8)) + @test @inferred(squeeze(sum, a, (1, 2, 5))) == squeeze(sum, a, (5, 2, 1)) == reshape(sum(a, (5, 2, 1)), (8, 8)) + @test_throws ArgumentError squeeze(sum, a, 0) + @test_throws ArgumentError squeeze(sum, a, (1, 1)) + @test_throws ArgumentError squeeze(sum, a, (1, 2, 1)) + @test_throws ArgumentError squeeze(sum, a, (1, 1, 2)) + @test_throws ArgumentError squeeze(sum, a, 6) + sz = (5,8,7) A = reshape(1:prod(sz),sz...) @test A[2:6] == [2:6;] From 012d91074b9d601236cdc15acb18c597df294fb8 Mon Sep 17 00:00:00 2001 From: Matt Bauman Date: Thu, 12 Oct 2017 15:10:11 -0500 Subject: [PATCH 2/3] Add NEWS --- NEWS.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/NEWS.md b/NEWS.md index 2ff018ce96ce4..8c84fb1b59be8 100644 --- a/NEWS.md +++ b/NEWS.md @@ -32,8 +32,10 @@ Standard library changes * The methods of `mktemp` and `mktempdir` which take a function body to pass temporary paths to no longer throw errors if the path is already deleted when the function body returns ([#33091]). -#### Libdl +* A new `squeeze(f, A, dims)` method computes the reduction `f` over the region in +`A` described by `dims` and then drops those dimensions from the result ([#23500]). +#### Libdl #### LinearAlgebra From 29945f324d43ee9814a403ce8b9d49dfcb4b44e8 Mon Sep 17 00:00:00 2001 From: Nick Robinson Date: Sat, 31 Aug 2019 13:31:20 +0100 Subject: [PATCH 3/3] Allow `dropdims` with reduction to take mutliple args and kwargs --- NEWS.md | 4 +--- base/abstractarraymath.jl | 26 ++++++++++++++++++++++--- test/arrayops.jl | 40 +++++++++++++++++++++++++++++---------- 3 files changed, 54 insertions(+), 16 deletions(-) diff --git a/NEWS.md b/NEWS.md index 8c84fb1b59be8..97ceb39b5a93f 100644 --- a/NEWS.md +++ b/NEWS.md @@ -26,14 +26,12 @@ New library functions * The `tempname` function now takes a `cleanup::Bool` keyword argument defaulting to `true`, which causes the process to try to ensure that any file or directory at the path returned by `tempname` is deleted upon process exit ([#33090]). * The `readdir` function now takes a `join::Bool` keyword argument defaulting to `false`, which when set causes `readdir` to join its directory argument with each listed name ([#33113]). - Standard library changes ------------------------ * The methods of `mktemp` and `mktempdir` which take a function body to pass temporary paths to no longer throw errors if the path is already deleted when the function body returns ([#33091]). -* A new `squeeze(f, A, dims)` method computes the reduction `f` over the region in -`A` described by `dims` and then drops those dimensions from the result ([#23500]). +* A new `dropdims(f, args...; dims, kwargs...)` method computes the reduction `f` over the region described by `dims` and then drops those dimensions from the result ([#23500]). #### Libdl diff --git a/base/abstractarraymath.jl b/base/abstractarraymath.jl index b7acf141bc4e3..6885fdf52a08e 100644 --- a/base/abstractarraymath.jl +++ b/base/abstractarraymath.jl @@ -88,11 +88,31 @@ end _dropdims(A::AbstractArray, dim::Integer) = _dropdims(A, (Int(dim),)) """ - squeeze(f, A, dims) + dropdims(f, args...; dims, kwargs...) -Compute reduction `f` over dimensions `dims` in array `A` and drop those dimensions from the result. +Compute reduction `f` over dimensions `dims` and drop those dimensions from the result. + +# Examples +```jldoctest +julia> a = [3.0 2.0 6.0 8.0 + 6.0 1.0 4.0 2.0 + 3.0 0.0 7.0 6.0]; + +julia> dropdims(sum, a, dims=1) +4-element Array{Float64,1}: + 12.0 + 3.0 + 17.0 + 16.0 + +julia> dropdims(sum, abs2, a, dims=2) +3-element Array{Float64,1}: + 113.0 + 57.0 + 94.0 +``` """ -squeeze(f, A::AbstractArray, dims::Union{Dims, Integer}) = squeeze(f(A, dims), dims) +dropdims(f, args...; dims, kwargs...) = dropdims(f(args...; kwargs..., dims=dims), dims=dims) ## Unary operators ## diff --git a/test/arrayops.jl b/test/arrayops.jl index 65e31028b644e..c4cebe2c31659 100644 --- a/test/arrayops.jl +++ b/test/arrayops.jl @@ -303,16 +303,36 @@ end @test_throws ArgumentError dropdims(a, dims=4) @test_throws ArgumentError dropdims(a, dims=6) - @test @inferred(squeeze(sum, a, 1)) == @inferred(squeeze(sum, a, (1,))) == reshape(sum(a, 1), (1, 8, 8, 1)) - @test @inferred(squeeze(sum, a, 3)) == @inferred(squeeze(sum, a, (3,))) == reshape(sum(a, 3), (1, 1, 8, 1)) - @test @inferred(squeeze(sum, a, 4)) == @inferred(squeeze(sum, a, (4,))) == reshape(sum(a, 4), (1, 1, 8, 1)) - @test @inferred(squeeze(sum, a, (1, 5))) == squeeze(sum, a, (5, 1)) == reshape(sum(a, (5, 1)), (1, 8, 8)) - @test @inferred(squeeze(sum, a, (1, 2, 5))) == squeeze(sum, a, (5, 2, 1)) == reshape(sum(a, (5, 2, 1)), (8, 8)) - @test_throws ArgumentError squeeze(sum, a, 0) - @test_throws ArgumentError squeeze(sum, a, (1, 1)) - @test_throws ArgumentError squeeze(sum, a, (1, 2, 1)) - @test_throws ArgumentError squeeze(sum, a, (1, 1, 2)) - @test_throws ArgumentError squeeze(sum, a, 6) + # dropdims with reductions. issue #16606 + @test (@inferred(dropdims(sum, a, dims=1)) == + @inferred(dropdims(sum, a, dims=(1,))) == + reshape(sum(a, dims=1), (1, 8, 8, 1))) + @test (@inferred(dropdims(sum, a, dims=3)) == + @inferred(dropdims(sum, a, dims=(3,))) == + reshape(sum(a, dims=3), (1, 1, 8, 1))) + @test (@inferred(dropdims(sum, a, dims=4)) == + @inferred(dropdims(sum, a, dims=(4,))) == + reshape(sum(a, dims=4), (1, 1, 8, 1))) + @test (@inferred(dropdims(sum, a, dims=(1, 5))) == + dropdims(sum, a, dims=(5, 1)) == + reshape(sum(a, (5, 1)), (1, 8, 8))) + @test (@inferred(dropdims(sum, a, dims=(1, 2, 5))) == + dropdims(sum, a, dims=(5, 2, 1)) == + reshape(sum(a, dims=(5, 2, 1)), (8, 8))) + @test (@inferred(dropdims(sum, abs2, a, dims=1)) == + @inferred(dropdims(sum, abs2, a, dims=(1,))) == + reshape(sum(a, dims=1), (1, 8, 8, 1))) + _sumplus(x; dims, plus) = sum(x; dims=dims) .+ plus # reduction with keywords + @test (@inferred(dropdims(_sumplus, a, dims=4, plus=1)) == + @inferred(dropdims(_sumplus, a, dims=(4,), plus=1)) == + reshape(sum(a, dims=4) .+ 1, (1, 1, 8, 1))) + @test_throws UndefKeywordError dropdims(sum, a) + @test_throws UndefKeywordError dropdims(sum, a, 1) + @test_throws ArgumentError dropdims(sum, a, dims=0) + @test_throws ArgumentError dropdims(sum, a, dims=(1, 1)) + @test_throws ArgumentError dropdims(sum, a, dims=(1, 2, 1)) + @test_throws ArgumentError dropdims(sum, a, dims=(1, 1, 2)) + @test_throws ArgumentError dropdims(sum, a, dims=6) sz = (5,8,7) A = reshape(1:prod(sz),sz...)