Skip to content

Commit

Permalink
Fix Vector'Diagonal (ans Transpose as well) to avoid infinite recursion.
Browse files Browse the repository at this point in the history
Also add optimized methods for x'D*y to avoid allocating temporary vector
  • Loading branch information
andreasnoack committed May 3, 2018
1 parent 7e2ce0e commit 523da54
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 18 deletions.
22 changes: 8 additions & 14 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,9 @@ end
(*)(D::Diagonal, B::AbstractTriangular) = lmul!(D, copy(B))

(*)(A::AbstractMatrix, D::Diagonal) =
mul!(similar(A, promote_op(*, eltype(A), eltype(D.diag)), size(A)), A, D)
rmul!(copyto!(similar(A, promote_op(*, eltype(A), eltype(D.diag)), size(A)), A), D)
(*)(D::Diagonal, A::AbstractMatrix) =
mul!(similar(A, promote_op(*, eltype(A), eltype(D.diag)), size(A)), D, A)
lmul!(D, copyto!(similar(A, promote_op(*, eltype(A), eltype(D.diag)), size(A)), A))

function rmul!(A::AbstractMatrix, D::Diagonal)
A .= A .* transpose(D.diag)
Expand Down Expand Up @@ -271,14 +271,6 @@ mul!(out::AbstractMatrix, A::Diagonal, in::AbstractMatrix) = out .= A.diag .* in
mul!(out::AbstractMatrix, A::Adjoint{<:Any,<:Diagonal}, in::AbstractMatrix) = out .= adjoint.(A.parent.diag) .* in
mul!(out::AbstractMatrix, A::Transpose{<:Any,<:Diagonal}, in::AbstractMatrix) = out .= transpose.(A.parent.diag) .* in

mul!(C::AbstractMatrix, A::Diagonal, B::Adjoint{<:Any,<:AbstractVecOrMat}) = mul!(C, A, copy(B))
mul!(C::AbstractMatrix, A::Diagonal, B::Transpose{<:Any,<:AbstractVecOrMat}) = mul!(C, A, copy(B))
mul!(C::AbstractMatrix, A::Adjoint{<:Any,<:Diagonal}, B::Adjoint{<:Any,<:AbstractVecOrMat}) = mul!(C, A, copy(B))
mul!(C::AbstractMatrix, A::Adjoint{<:Any,<:Diagonal}, B::Transpose{<:Any,<:AbstractVecOrMat}) = mul!(C, A, copy(B))
mul!(C::AbstractMatrix, A::Transpose{<:Any,<:Diagonal}, B::Adjoint{<:Any,<:AbstractVecOrMat}) = mul!(C, A, copy(B))
mul!(C::AbstractMatrix, A::Transpose{<:Any,<:Diagonal}, B::Transpose{<:Any,<:AbstractVecOrMat}) = mul!(C, A, copy(B))


# ambiguities with Symmetric/Hermitian
# RealHermSymComplex[Sym]/[Herm] only include Number; invariant to [c]transpose
*(A::Diagonal, transB::Transpose{<:Any,<:RealHermSymComplexSym}) = A * transB.parent
Expand Down Expand Up @@ -478,8 +470,10 @@ function svdfact(D::Diagonal)
end

# dismabiguation methods: * of Diagonal and Adj/Trans AbsVec
*(A::Diagonal, B::Adjoint{<:Any,<:AbstractVector}) = A * copy(B)
*(A::Diagonal, B::Transpose{<:Any,<:AbstractVector}) = A * copy(B)
*(A::Adjoint{<:Any,<:AbstractVector}, B::Diagonal) = copy(A) * B
*(A::Transpose{<:Any,<:AbstractVector}, B::Diagonal) = copy(A) * B
*(x::Adjoint{<:Any,<:AbstractVector}, D::Diagonal) = Adjoint(map((t,s) -> t'*s, D.diag, parent(x)))
*(x::Adjoint{<:Any,<:AbstractVector}, D::Diagonal, y::AbstractVector) =
mapreduce(t -> t[1]*t[2]*t[3], +, zip(x, D.diag, y))
*(x::Transpose{<:Any,<:AbstractVector}, D::Diagonal) = Transpose(map(*, D.diag, parent(x)))
*(x::Transpose{<:Any,<:AbstractVector}, D::Diagonal, y::AbstractVector) =
mapreduce(t -> t[1]*t[2]*t[3], +, zip(x, D.diag, y))
# TODO: these methods will yield row matrices, rather than adjoint/transpose vectors
4 changes: 0 additions & 4 deletions stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,8 @@ end
# these will throw a DimensionMismatch unless B has 1 row (or 1 col for transposed case):
*(a::AbstractVector, transB::Transpose{<:Any,<:AbstractMatrix}) =
(B = transB.parent; *(reshape(a,length(a),1), transpose(B)))
*(A::AbstractMatrix, transb::Transpose{<:Any,<:AbstractVector}) =
(b = transb.parent; *(A, transpose(reshape(b,length(b),1))))
*(a::AbstractVector, adjB::Adjoint{<:Any,<:AbstractMatrix}) =
(B = adjB.parent; *(reshape(a,length(a),1), adjoint(B)))
*(A::AbstractMatrix, adjb::Adjoint{<:Any,<:AbstractVector}) =
(b = adjb.parent; *(A, adjoint(reshape(b,length(b),1))))
(*)(a::AbstractVector, B::AbstractMatrix) = reshape(a,length(a),1)*B

mul!(y::StridedVector{T}, A::StridedVecOrMat{T}, x::StridedVector{T}) where {T<:BlasFloat} = gemv!(y, 'N', A, x)
Expand Down
7 changes: 7 additions & 0 deletions stdlib/LinearAlgebra/test/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -433,4 +433,11 @@ end
@test Diagonal(transpose([1, 2, 3])) == Diagonal([1 2 3])
end

@testset "Multiplication with Adjoint and Transpose vectors (#26863)" begin
x = rand(5)
D = Diagonal(rand(5))
@test x'*D*x == (x'*D)*x == (x'*Array(D))*x
@test Transpose(x)*D*x == (Transpose(x)*D)*x == (Transpose(x)*Array(D))*x
end

end # module TestDiagonal

0 comments on commit 523da54

Please sign in to comment.