Skip to content

Commit

Permalink
faster reductions of Transpose, Adjoint, PermutedDimsArray (#39513)
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott authored Apr 19, 2021
1 parent 28a3312 commit 58bde18
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 2 deletions.
5 changes: 3 additions & 2 deletions base/Base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,6 @@ include("methodshow.jl")
include("cartesian.jl")
using .Cartesian
include("multidimensional.jl")
include("permuteddimsarray.jl")
using .PermutedDimsArrays

include("broadcast.jl")
using .Broadcast
Expand Down Expand Up @@ -293,6 +291,9 @@ end
include("reducedim.jl") # macros in this file relies on string.jl
include("accumulate.jl")

include("permuteddimsarray.jl")
using .PermutedDimsArrays

# basic data structures
include("ordering.jl")
using .Order
Expand Down
10 changes: 10 additions & 0 deletions base/permuteddimsarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,16 @@ end
P
end

function Base._mapreduce_dim(f, op, init::Base._InitialValue, A::PermutedDimsArray, dims::Colon)
Base._mapreduce_dim(f, op, init, parent(A), dims)
end

function Base.mapreducedim!(f, op, B::AbstractArray{T,N}, A::PermutedDimsArray{T,N,perm,iperm}) where {T,N,perm,iperm}
C = PermutedDimsArray{T,N,iperm,perm,typeof(B)}(B) # make the inverse permutation for the output
Base.mapreducedim!(f, op, C, parent(A))
B
end

function Base.showarg(io::IO, A::PermutedDimsArray{T,N,perm}, toplevel) where {T,N,perm}
print(io, "PermutedDimsArray(")
Base.showarg(io, parent(A), false)
Expand Down
21 changes: 21 additions & 0 deletions stdlib/LinearAlgebra/src/adjtrans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,9 @@ end
# some aliases for internal convenience use
const AdjOrTrans{T,S} = Union{Adjoint{T,S},Transpose{T,S}} where {T,S}
const AdjointAbsVec{T} = Adjoint{T,<:AbstractVector}
const AdjointAbsMat{T} = Adjoint{T,<:AbstractMatrix}
const TransposeAbsVec{T} = Transpose{T,<:AbstractVector}
const TransposeAbsMat{T} = Transpose{T,<:AbstractMatrix}
const AdjOrTransAbsVec{T} = AdjOrTrans{T,<:AbstractVector}
const AdjOrTransAbsMat{T} = AdjOrTrans{T,<:AbstractMatrix}

Expand Down Expand Up @@ -275,6 +277,25 @@ Broadcast.broadcast_preserving_zero_d(f, avs::Union{Number,AdjointAbsVec}...) =
Broadcast.broadcast_preserving_zero_d(f, tvs::Union{Number,TransposeAbsVec}...) = transpose(broadcast((xs...) -> transpose(f(transpose.(xs)...)), quasiparentt.(tvs)...))
# TODO unify and allow mixed combinations with a broadcast style


### reductions
# faster to sum the Array than to work through the wrapper
Base._mapreduce_dim(f, op, init::Base._InitialValue, A::Transpose, dims::Colon) =
transpose(Base._mapreduce_dim(_sandwich(transpose, f), _sandwich(transpose, op), init, parent(A), dims))
Base._mapreduce_dim(f, op, init::Base._InitialValue, A::Adjoint, dims::Colon) =
adjoint(Base._mapreduce_dim(_sandwich(adjoint, f), _sandwich(adjoint, op), init, parent(A), dims))
# sum(A'; dims)
Base.mapreducedim!(f, op, B::AbstractArray, A::TransposeAbsMat) =
transpose(Base.mapreducedim!(_sandwich(transpose, f), _sandwich(transpose, op), transpose(B), parent(A)))
Base.mapreducedim!(f, op, B::AbstractArray, A::AdjointAbsMat) =
adjoint(Base.mapreducedim!(_sandwich(adjoint, f), _sandwich(adjoint, op), adjoint(B), parent(A)))

_sandwich(adj::Function, fun) = (xs...,) -> adj(fun(map(adj, xs)...))
for fun in [:identity, :add_sum, :mul_prod] #, :max, :min]
@eval _sandwich(::Function, ::typeof(Base.$fun)) = Base.$fun
end


### linear algebra

(-)(A::Adjoint) = Adjoint( -A.parent)
Expand Down
20 changes: 20 additions & 0 deletions stdlib/LinearAlgebra/test/adjtrans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -576,4 +576,24 @@ end
@test transpose(Int[]) * Int[] == 0
end

@testset "reductions: $adjtrans" for adjtrans in [transpose, adjoint]
mat = rand(ComplexF64, 3,5)
@test sum(adjtrans(mat)) sum(collect(adjtrans(mat)))
@test sum(adjtrans(mat), dims=1) sum(collect(adjtrans(mat)), dims=1)
@test sum(adjtrans(mat), dims=(1,2)) sum(collect(adjtrans(mat)), dims=(1,2))

@test sum(imag, adjtrans(mat)) sum(imag, collect(adjtrans(mat)))
@test sum(imag, adjtrans(mat), dims=1) sum(imag, collect(adjtrans(mat)), dims=1)

mat = [rand(ComplexF64,2,2) for _ in 1:3, _ in 1:5]
@test sum(adjtrans(mat)) sum(collect(adjtrans(mat)))
@test sum(adjtrans(mat), dims=1) sum(collect(adjtrans(mat)), dims=1)
@test sum(adjtrans(mat), dims=(1,2)) sum(collect(adjtrans(mat)), dims=(1,2))

@test sum(imag, adjtrans(mat)) sum(imag, collect(adjtrans(mat)))
@test sum(x -> x[1,2], adjtrans(mat)) sum(x -> x[1,2], collect(adjtrans(mat)))
@test sum(imag, adjtrans(mat), dims=1) sum(imag, collect(adjtrans(mat)), dims=1)
@test sum(x -> x[1,2], adjtrans(mat), dims=1) sum(x -> x[1,2], collect(adjtrans(mat)), dims=1)
end

end # module TestAdjointTranspose
4 changes: 4 additions & 0 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,10 @@ end
perm = randperm(4)
@test isequal(A,permutedims(permutedims(A,perm),invperm(perm)))
@test isequal(A,permutedims(permutedims(A,invperm(perm)),perm))

@test sum(permutedims(A,perm)) sum(PermutedDimsArray(A,perm))
@test sum(permutedims(A,perm), dims=2) sum(PermutedDimsArray(A,perm), dims=2)
@test sum(permutedims(A,perm), dims=(2,4)) sum(PermutedDimsArray(A,perm), dims=(2,4))
end

m = [1 2; 3 4]
Expand Down

0 comments on commit 58bde18

Please sign in to comment.