From 58bde1851c9df1904791fce3c5edb9ef8581ca06 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 19 Apr 2021 05:34:16 -0400 Subject: [PATCH] faster reductions of Transpose, Adjoint, PermutedDimsArray (#39513) --- base/Base.jl | 5 +++-- base/permuteddimsarray.jl | 10 ++++++++++ stdlib/LinearAlgebra/src/adjtrans.jl | 21 +++++++++++++++++++++ stdlib/LinearAlgebra/test/adjtrans.jl | 20 ++++++++++++++++++++ test/arrayops.jl | 4 ++++ 5 files changed, 58 insertions(+), 2 deletions(-) diff --git a/base/Base.jl b/base/Base.jl index 5a4826fe5df26..f92cd4e1c3c08 100644 --- a/base/Base.jl +++ b/base/Base.jl @@ -213,8 +213,6 @@ include("methodshow.jl") include("cartesian.jl") using .Cartesian include("multidimensional.jl") -include("permuteddimsarray.jl") -using .PermutedDimsArrays include("broadcast.jl") using .Broadcast @@ -293,6 +291,9 @@ end include("reducedim.jl") # macros in this file relies on string.jl include("accumulate.jl") +include("permuteddimsarray.jl") +using .PermutedDimsArrays + # basic data structures include("ordering.jl") using .Order diff --git a/base/permuteddimsarray.jl b/base/permuteddimsarray.jl index 429fa67b2a3ab..3fc2a2340efdc 100644 --- a/base/permuteddimsarray.jl +++ b/base/permuteddimsarray.jl @@ -253,6 +253,16 @@ end P end +function Base._mapreduce_dim(f, op, init::Base._InitialValue, A::PermutedDimsArray, dims::Colon) + Base._mapreduce_dim(f, op, init, parent(A), dims) +end + +function Base.mapreducedim!(f, op, B::AbstractArray{T,N}, A::PermutedDimsArray{T,N,perm,iperm}) where {T,N,perm,iperm} + C = PermutedDimsArray{T,N,iperm,perm,typeof(B)}(B) # make the inverse permutation for the output + Base.mapreducedim!(f, op, C, parent(A)) + B +end + function Base.showarg(io::IO, A::PermutedDimsArray{T,N,perm}, toplevel) where {T,N,perm} print(io, "PermutedDimsArray(") Base.showarg(io, parent(A), false) diff --git a/stdlib/LinearAlgebra/src/adjtrans.jl b/stdlib/LinearAlgebra/src/adjtrans.jl index f3d08abc54b0e..494429a4fca2d 100644 --- a/stdlib/LinearAlgebra/src/adjtrans.jl +++ b/stdlib/LinearAlgebra/src/adjtrans.jl @@ -185,7 +185,9 @@ end # some aliases for internal convenience use const AdjOrTrans{T,S} = Union{Adjoint{T,S},Transpose{T,S}} where {T,S} const AdjointAbsVec{T} = Adjoint{T,<:AbstractVector} +const AdjointAbsMat{T} = Adjoint{T,<:AbstractMatrix} const TransposeAbsVec{T} = Transpose{T,<:AbstractVector} +const TransposeAbsMat{T} = Transpose{T,<:AbstractMatrix} const AdjOrTransAbsVec{T} = AdjOrTrans{T,<:AbstractVector} const AdjOrTransAbsMat{T} = AdjOrTrans{T,<:AbstractMatrix} @@ -275,6 +277,25 @@ Broadcast.broadcast_preserving_zero_d(f, avs::Union{Number,AdjointAbsVec}...) = Broadcast.broadcast_preserving_zero_d(f, tvs::Union{Number,TransposeAbsVec}...) = transpose(broadcast((xs...) -> transpose(f(transpose.(xs)...)), quasiparentt.(tvs)...)) # TODO unify and allow mixed combinations with a broadcast style + +### reductions +# faster to sum the Array than to work through the wrapper +Base._mapreduce_dim(f, op, init::Base._InitialValue, A::Transpose, dims::Colon) = + transpose(Base._mapreduce_dim(_sandwich(transpose, f), _sandwich(transpose, op), init, parent(A), dims)) +Base._mapreduce_dim(f, op, init::Base._InitialValue, A::Adjoint, dims::Colon) = + adjoint(Base._mapreduce_dim(_sandwich(adjoint, f), _sandwich(adjoint, op), init, parent(A), dims)) +# sum(A'; dims) +Base.mapreducedim!(f, op, B::AbstractArray, A::TransposeAbsMat) = + transpose(Base.mapreducedim!(_sandwich(transpose, f), _sandwich(transpose, op), transpose(B), parent(A))) +Base.mapreducedim!(f, op, B::AbstractArray, A::AdjointAbsMat) = + adjoint(Base.mapreducedim!(_sandwich(adjoint, f), _sandwich(adjoint, op), adjoint(B), parent(A))) + +_sandwich(adj::Function, fun) = (xs...,) -> adj(fun(map(adj, xs)...)) +for fun in [:identity, :add_sum, :mul_prod] #, :max, :min] + @eval _sandwich(::Function, ::typeof(Base.$fun)) = Base.$fun +end + + ### linear algebra (-)(A::Adjoint) = Adjoint( -A.parent) diff --git a/stdlib/LinearAlgebra/test/adjtrans.jl b/stdlib/LinearAlgebra/test/adjtrans.jl index e6f271b2b4650..083c8dfa5cb29 100644 --- a/stdlib/LinearAlgebra/test/adjtrans.jl +++ b/stdlib/LinearAlgebra/test/adjtrans.jl @@ -576,4 +576,24 @@ end @test transpose(Int[]) * Int[] == 0 end +@testset "reductions: $adjtrans" for adjtrans in [transpose, adjoint] + mat = rand(ComplexF64, 3,5) + @test sum(adjtrans(mat)) ≈ sum(collect(adjtrans(mat))) + @test sum(adjtrans(mat), dims=1) ≈ sum(collect(adjtrans(mat)), dims=1) + @test sum(adjtrans(mat), dims=(1,2)) ≈ sum(collect(adjtrans(mat)), dims=(1,2)) + + @test sum(imag, adjtrans(mat)) ≈ sum(imag, collect(adjtrans(mat))) + @test sum(imag, adjtrans(mat), dims=1) ≈ sum(imag, collect(adjtrans(mat)), dims=1) + + mat = [rand(ComplexF64,2,2) for _ in 1:3, _ in 1:5] + @test sum(adjtrans(mat)) ≈ sum(collect(adjtrans(mat))) + @test sum(adjtrans(mat), dims=1) ≈ sum(collect(adjtrans(mat)), dims=1) + @test sum(adjtrans(mat), dims=(1,2)) ≈ sum(collect(adjtrans(mat)), dims=(1,2)) + + @test sum(imag, adjtrans(mat)) ≈ sum(imag, collect(adjtrans(mat))) + @test sum(x -> x[1,2], adjtrans(mat)) ≈ sum(x -> x[1,2], collect(adjtrans(mat))) + @test sum(imag, adjtrans(mat), dims=1) ≈ sum(imag, collect(adjtrans(mat)), dims=1) + @test sum(x -> x[1,2], adjtrans(mat), dims=1) ≈ sum(x -> x[1,2], collect(adjtrans(mat)), dims=1) +end + end # module TestAdjointTranspose diff --git a/test/arrayops.jl b/test/arrayops.jl index 80a37c9d2f26a..27e366f1ce3cc 100644 --- a/test/arrayops.jl +++ b/test/arrayops.jl @@ -701,6 +701,10 @@ end perm = randperm(4) @test isequal(A,permutedims(permutedims(A,perm),invperm(perm))) @test isequal(A,permutedims(permutedims(A,invperm(perm)),perm)) + + @test sum(permutedims(A,perm)) ≈ sum(PermutedDimsArray(A,perm)) + @test sum(permutedims(A,perm), dims=2) ≈ sum(PermutedDimsArray(A,perm), dims=2) + @test sum(permutedims(A,perm), dims=(2,4)) ≈ sum(PermutedDimsArray(A,perm), dims=(2,4)) end m = [1 2; 3 4]