Skip to content

Commit

Permalink
Make Adjoint/Transpose behave like typical constructors.
Browse files Browse the repository at this point in the history
  • Loading branch information
Sacha0 committed Jan 9, 2018
1 parent 29dc306 commit e0c7d41
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 70 deletions.
104 changes: 50 additions & 54 deletions base/linalg/adjtrans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,46 +11,42 @@ import Base: length, size, axes, IndexStyle, getindex, setindex!, parent, vec, c
struct Adjoint{T,S} <: AbstractMatrix{T}
parent::S
function Adjoint{T,S}(A::S) where {T,S}
checkeltype(Adjoint, T, eltype(A))
checkeltype_adjoint(T, eltype(A))
new(A)
end
end
struct Transpose{T,S} <: AbstractMatrix{T}
parent::S
function Transpose{T,S}(A::S) where {T,S}
checkeltype(Transpose, T, eltype(A))
checkeltype_transpose(T, eltype(A))
new(A)
end
end

function checkeltype(::Type{Transform}, ::Type{ResultEltype}, ::Type{ParentEltype}) where {Transform, ResultEltype, ParentEltype}
if ResultEltype !== transformtype(Transform, ParentEltype)
error(string("Element type mismatch. Tried to create an `$Transform{$ResultEltype}` ",
"from an object with eltype `$ParentEltype`, but the element type of the ",
"`$Transform` of an object with eltype `$ParentEltype` must be ",
"`$(transformtype(Transform, ParentEltype))`"))
end
function checkeltype_adjoint(::Type{ResultEltype}, ::Type{ParentEltype}) where {ResultEltype,ParentEltype}
ResultEltype === Base.promote_op(adjoint, ParentEltype) || error(string(
"Element type mismatch. Tried to create an `Adjoint{$ResultEltype}` ",
"from an object with eltype `$ParentEltype`, but the element type of ",
"the adjoint of an object with eltype `$ParentEltype` must be ",
"`$(Base.promote_op(adjoint, ParentEltype))`."))
return nothing
end
function transformtype(::Type{O}, ::Type{S}) where {O,S}
# similar to promote_op(::Any, ::Type)
@_inline_meta
T = _return_type(O, Tuple{_default_type(S)})
_isleaftype(S) && return _isleaftype(T) ? T : Any
return typejoin(S, T)
function checkeltype_transpose(::Type{ResultEltype}, ::Type{ParentEltype}) where {ResultEltype,ParentEltype}
ResultEltype === Base.promote_op(transpose, ParentEltype) || error(string(
"Element type mismatch. Tried to create a `Transpose{$ResultEltype}` ",
"from an object with eltype `$ParentEltype`, but the element type of ",
"the transpose of an object with eltype `$ParentEltype` must be ",
"`$(Base.promote_op(transpose, ParentEltype))`."))
return nothing
end

# basic outer constructors
Adjoint(A) = Adjoint{transformtype(Adjoint,eltype(A)),typeof(A)}(A)
Transpose(A) = Transpose{transformtype(Transpose,eltype(A)),typeof(A)}(A)

# numbers are the end of the line
Adjoint(x::Number) = adjoint(x)
Transpose(x::Number) = transpose(x)
Adjoint(A) = Adjoint{Base.promote_op(adjoint,eltype(A)),typeof(A)}(A)
Transpose(A) = Transpose{Base.promote_op(transpose,eltype(A)),typeof(A)}(A)

# unwrapping constructors
Adjoint(A::Adjoint) = A.parent
Transpose(A::Transpose) = A.parent
# no-op constructors for already-wrapped objects
Adjoint(A::Adjoint) = A
Transpose(A::Transpose) = A

# wrapping lowercase quasi-constructors
adjoint(A::AbstractVecOrMat) = Adjoint(A)
Expand Down Expand Up @@ -80,6 +76,7 @@ julia> transpose(A)
```
"""
transpose(A::AbstractVecOrMat) = Transpose(A)

# unwrapping lowercase quasi-constructors
adjoint(A::Adjoint) = A.parent
transpose(A::Transpose) = A.parent
Expand All @@ -95,10 +92,8 @@ const AdjOrTransAbsVec{T} = AdjOrTrans{T,<:AbstractVector}
const AdjOrTransAbsMat{T} = AdjOrTrans{T,<:AbstractMatrix}

# for internal use below
wrappertype(A::Adjoint) = Adjoint
wrappertype(A::Transpose) = Transpose
wrappertype(::Type{<:Adjoint}) = Adjoint
wrappertype(::Type{<:Transpose}) = Transpose
wrapperop(A::Adjoint) = adjoint
wrapperop(A::Transpose) = transpose

# AbstractArray interface, basic definitions
length(A::AdjOrTrans) = length(A.parent)
Expand All @@ -108,22 +103,22 @@ axes(v::AdjOrTransAbsVec) = (Base.OneTo(1), axes(v.parent)...)
axes(A::AdjOrTransAbsMat) = reverse(axes(A.parent))
IndexStyle(::Type{<:AdjOrTransAbsVec}) = IndexLinear()
IndexStyle(::Type{<:AdjOrTransAbsMat}) = IndexCartesian()
@propagate_inbounds getindex(v::AdjOrTransAbsVec, i::Int) = wrappertype(v)(v.parent[i])
@propagate_inbounds getindex(A::AdjOrTransAbsMat, i::Int, j::Int) = wrappertype(A)(A.parent[j, i])
@propagate_inbounds setindex!(v::AdjOrTransAbsVec, x, i::Int) = (setindex!(v.parent, wrappertype(v)(x), i); v)
@propagate_inbounds setindex!(A::AdjOrTransAbsMat, x, i::Int, j::Int) = (setindex!(A.parent, wrappertype(A)(x), j, i); A)
@propagate_inbounds getindex(v::AdjOrTransAbsVec, i::Int) = wrapperop(v)(v.parent[i])
@propagate_inbounds getindex(A::AdjOrTransAbsMat, i::Int, j::Int) = wrapperop(A)(A.parent[j, i])
@propagate_inbounds setindex!(v::AdjOrTransAbsVec, x, i::Int) = (setindex!(v.parent, wrapperop(v)(x), i); v)
@propagate_inbounds setindex!(A::AdjOrTransAbsMat, x, i::Int, j::Int) = (setindex!(A.parent, wrapperop(A)(x), j, i); A)
# AbstractArray interface, additional definitions to retain wrapper over vectors where appropriate
@propagate_inbounds getindex(v::AdjOrTransAbsVec, ::Colon, is::AbstractArray{Int}) = wrappertype(v)(v.parent[is])
@propagate_inbounds getindex(v::AdjOrTransAbsVec, ::Colon, ::Colon) = wrappertype(v)(v.parent[:])
@propagate_inbounds getindex(v::AdjOrTransAbsVec, ::Colon, is::AbstractArray{Int}) = wrapperop(v)(v.parent[is])
@propagate_inbounds getindex(v::AdjOrTransAbsVec, ::Colon, ::Colon) = wrapperop(v)(v.parent[:])

# conversion of underlying storage
convert(::Type{Adjoint{T,S}}, A::Adjoint) where {T,S} = Adjoint{T,S}(convert(S, A.parent))
convert(::Type{Transpose{T,S}}, A::Transpose) where {T,S} = Transpose{T,S}(convert(S, A.parent))

# for vectors, the semantics of the wrapped and unwrapped types differ
# so attempt to maintain both the parent and wrapper type insofar as possible
similar(A::AdjOrTransAbsVec) = wrappertype(A)(similar(A.parent))
similar(A::AdjOrTransAbsVec, ::Type{T}) where {T} = wrappertype(A)(similar(A.parent, transformtype(wrappertype(A), T)))
similar(A::AdjOrTransAbsVec) = wrapperop(A)(similar(A.parent))
similar(A::AdjOrTransAbsVec, ::Type{T}) where {T} = wrapperop(A)(similar(A.parent, Base.promote_op(wrapperop(A), T)))
# for matrices, the semantics of the wrapped and unwrapped types are generally the same
# and as you are allocating with similar anyway, you might as well get something unwrapped
similar(A::AdjOrTrans) = similar(A.parent, eltype(A), size(A))
Expand All @@ -142,30 +137,31 @@ isless(A::AdjOrTransAbsVec, B::AdjOrTransAbsVec) = isless(parent(A), parent(B))
# to retain the associated semantics post-concatenation
hcat(avs::Union{Number,AdjointAbsVec}...) = _adjoint_hcat(avs...)
hcat(tvs::Union{Number,TransposeAbsVec}...) = _transpose_hcat(tvs...)
_adjoint_hcat(avs::Union{Number,AdjointAbsVec}...) = Adjoint(vcat(map(Adjoint, avs)...))
_transpose_hcat(tvs::Union{Number,TransposeAbsVec}...) = Transpose(vcat(map(Transpose, tvs)...))
typed_hcat(::Type{T}, avs::Union{Number,AdjointAbsVec}...) where {T} = Adjoint(typed_vcat(T, map(Adjoint, avs)...))
typed_hcat(::Type{T}, tvs::Union{Number,TransposeAbsVec}...) where {T} = Transpose(typed_vcat(T, map(Transpose, tvs)...))
_adjoint_hcat(avs::Union{Number,AdjointAbsVec}...) = adjoint(vcat(map(adjoint, avs)...))
_transpose_hcat(tvs::Union{Number,TransposeAbsVec}...) = transpose(vcat(map(transpose, tvs)...))
typed_hcat(::Type{T}, avs::Union{Number,AdjointAbsVec}...) where {T} = adjoint(typed_vcat(T, map(adjoint, avs)...))
typed_hcat(::Type{T}, tvs::Union{Number,TransposeAbsVec}...) where {T} = transpose(typed_vcat(T, map(transpose, tvs)...))
# otherwise-redundant definitions necessary to prevent hitting the concat methods in sparse/sparsevector.jl
hcat(avs::Adjoint{<:Any,<:Vector}...) = _adjoint_hcat(avs...)
hcat(tvs::Transpose{<:Any,<:Vector}...) = _transpose_hcat(tvs...)
hcat(avs::Adjoint{T,Vector{T}}...) where {T} = _adjoint_hcat(avs...)
hcat(tvs::Transpose{T,Vector{T}}...) where {T} = _transpose_hcat(tvs...)
# TODO unify and allow mixed combinations


### higher order functions
# preserve Adjoint/Transpose wrapper around vectors
# to retain the associated semantics post-map/broadcast
#
# note that the caller's operation f operates in the domain of the wrapped vectors' entries.
# hence the Adjoint->f->Adjoint shenanigans applied to the parent vectors' entries.
map(f, avs::AdjointAbsVec...) = Adjoint(map((xs...) -> Adjoint(f(Adjoint.(xs)...)), parent.(avs)...))
map(f, tvs::TransposeAbsVec...) = Transpose(map((xs...) -> Transpose(f(Transpose.(xs)...)), parent.(tvs)...))
# hence the adjoint->f->adjoint shenanigans applied to the parent vectors' entries.
map(f, avs::AdjointAbsVec...) = adjoint(map((xs...) -> adjoint(f(adjoint.(xs)...)), parent.(avs)...))
map(f, tvs::TransposeAbsVec...) = transpose(map((xs...) -> transpose(f(transpose.(xs)...)), parent.(tvs)...))
quasiparentt(x) = parent(x); quasiparentt(x::Number) = x # to handle numbers in the defs below
quasiparenta(x) = parent(x); quasiparenta(x::Number) = conj(x) # to handle numbers in the defs below
broadcast(f, avs::Union{Number,AdjointAbsVec}...) = Adjoint(broadcast((xs...) -> Adjoint(f(Adjoint.(xs)...)), quasiparenta.(avs)...))
broadcast(f, tvs::Union{Number,TransposeAbsVec}...) = Transpose(broadcast((xs...) -> Transpose(f(Transpose.(xs)...)), quasiparentt.(tvs)...))

broadcast(f, avs::Union{Number,AdjointAbsVec}...) = adjoint(broadcast((xs...) -> adjoint(f(adjoint.(xs)...)), quasiparenta.(avs)...))
broadcast(f, tvs::Union{Number,TransposeAbsVec}...) = transpose(broadcast((xs...) -> transpose(f(transpose.(xs)...)), quasiparentt.(tvs)...))
# TODO unify and allow mixed combinations

### linear algebra

Expand All @@ -186,11 +182,11 @@ end
*(u::TransposeAbsVec, v::TransposeAbsVec) = throw(MethodError(*, (u, v)))

# Adjoint/Transpose-vector * matrix
*(u::AdjointAbsVec, A::AbstractMatrix) = Adjoint(Adjoint(A) * u.parent)
*(u::TransposeAbsVec, A::AbstractMatrix) = Transpose(Transpose(A) * u.parent)
*(u::AdjointAbsVec, A::AbstractMatrix) = adjoint(adjoint(A) * u.parent)
*(u::TransposeAbsVec, A::AbstractMatrix) = transpose(transpose(A) * u.parent)
# Adjoint/Transpose-vector * Adjoint/Transpose-matrix
*(u::AdjointAbsVec, A::Adjoint{<:Any,<:AbstractMatrix}) = Adjoint(A.parent * u.parent)
*(u::TransposeAbsVec, A::Transpose{<:Any,<:AbstractMatrix}) = Transpose(A.parent * u.parent)
*(u::AdjointAbsVec, A::Adjoint{<:Any,<:AbstractMatrix}) = adjoint(A.parent * u.parent)
*(u::TransposeAbsVec, A::Transpose{<:Any,<:AbstractMatrix}) = transpose(A.parent * u.parent)


## pseudoinversion
Expand All @@ -203,10 +199,10 @@ pinv(v::TransposeAbsVec, tol::Real = 0) = pinv(conj(v.parent)).parent


## right-division \
/(u::AdjointAbsVec, A::AbstractMatrix) = Adjoint(Adjoint(A) \ u.parent)
/(u::TransposeAbsVec, A::AbstractMatrix) = Transpose(Transpose(A) \ u.parent)
/(u::AdjointAbsVec, A::Transpose{<:Any,<:AbstractMatrix}) = Adjoint(conj(A.parent) \ u.parent) # technically should be Adjoint(copy(Adjoint(copy(A))) \ u.parent)
/(u::TransposeAbsVec, A::Adjoint{<:Any,<:AbstractMatrix}) = Transpose(conj(A.parent) \ u.parent) # technically should be Transpose(copy(Transpose(copy(A))) \ u.parent)
/(u::AdjointAbsVec, A::AbstractMatrix) = adjoint(adjoint(A) \ u.parent)
/(u::TransposeAbsVec, A::AbstractMatrix) = transpose(transpose(A) \ u.parent)
/(u::AdjointAbsVec, A::Transpose{<:Any,<:AbstractMatrix}) = adjoint(conj(A.parent) \ u.parent) # technically should be adjoint(copy(adjoint(copy(A))) \ u.parent)
/(u::TransposeAbsVec, A::Adjoint{<:Any,<:AbstractMatrix}) = transpose(conj(A.parent) \ u.parent) # technically should be transpose(copy(transpose(copy(A))) \ u.parent)

# dismabiguation methods
*(A::AdjointAbsVec, B::Transpose{<:Any,<:AbstractMatrix}) = A * copy(B)
Expand Down
32 changes: 16 additions & 16 deletions test/linalg/adjtrans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,23 +62,12 @@ end
# the tests for the inner constructors exercise abstract scalar and concrete array eltype, forgoing here
end

@testset "Adjoint and Transpose of Numbers" begin
@test Adjoint(1) == 1
@test Adjoint(1.0) == 1.0
@test Adjoint(1im) == -1im
@test Adjoint(1.0im) == -1.0im
@test Transpose(1) == 1
@test Transpose(1.0) == 1.0
@test Transpose(1im) == 1im
@test Transpose(1.0im) == 1.0im
end

@testset "Adjoint and Transpose unwrapping" begin
@testset "Adjoint and Transpose no-op on already-wrapped objects" begin
intvec, intmat = [1, 2], [1 2; 3 4]
@test Adjoint(Adjoint(intvec)) === intvec
@test Adjoint(Adjoint(intmat)) === intmat
@test Transpose(Transpose(intvec)) === intvec
@test Transpose(Transpose(intmat)) === intmat
@test (A = Adjoint(intvec); Adjoint(A) === A)
@test (A = Adjoint(intmat); Adjoint(A) === A)
@test (A = Transpose(intvec); Transpose(A) === A)
@test (A = Transpose(intmat); Transpose(A) === A)
end

@testset "Adjoint and Transpose basic AbstractArray functionality" begin
Expand Down Expand Up @@ -441,6 +430,17 @@ end
end
end

@testset "adjoint and transpose of Numbers" begin
@test adjoint(1) == 1
@test adjoint(1.0) == 1.0
@test adjoint(1im) == -1im
@test adjoint(1.0im) == -1.0im
@test transpose(1) == 1
@test transpose(1.0) == 1.0
@test transpose(1im) == 1im
@test transpose(1.0im) == 1.0im
end

@testset "adjoint!(a, b) return a" begin
a = fill(1.0+im, 5)
b = fill(1.0+im, 1, 5)
Expand Down

0 comments on commit e0c7d41

Please sign in to comment.