Skip to content

Commit

Permalink
Fix Vector'Diagonal (ans Transpose as well) to avoid infinite recursi…
Browse files Browse the repository at this point in the history
…on. (#26924)

Also add optimized methods for x'D*y to avoid allocating temporary vector
  • Loading branch information
andreasnoack authored and JeffBezanson committed May 3, 2018
1 parent fcbfc5d commit fdf1682
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 18 deletions.
22 changes: 8 additions & 14 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,9 @@ end
(*)(D::Diagonal, B::AbstractTriangular) = lmul!(D, copy(B))

(*)(A::AbstractMatrix, D::Diagonal) =
mul!(similar(A, promote_op(*, eltype(A), eltype(D.diag)), size(A)), A, D)
rmul!(copyto!(similar(A, promote_op(*, eltype(A), eltype(D.diag)), size(A)), A), D)
(*)(D::Diagonal, A::AbstractMatrix) =
mul!(similar(A, promote_op(*, eltype(A), eltype(D.diag)), size(A)), D, A)
lmul!(D, copyto!(similar(A, promote_op(*, eltype(A), eltype(D.diag)), size(A)), A))

function rmul!(A::AbstractMatrix, D::Diagonal)
A .= A .* transpose(D.diag)
Expand Down Expand Up @@ -271,14 +271,6 @@ mul!(out::AbstractMatrix, A::Diagonal, in::AbstractMatrix) = out .= A.diag .* in
mul!(out::AbstractMatrix, A::Adjoint{<:Any,<:Diagonal}, in::AbstractMatrix) = out .= adjoint.(A.parent.diag) .* in
mul!(out::AbstractMatrix, A::Transpose{<:Any,<:Diagonal}, in::AbstractMatrix) = out .= transpose.(A.parent.diag) .* in

mul!(C::AbstractMatrix, A::Diagonal, B::Adjoint{<:Any,<:AbstractVecOrMat}) = mul!(C, A, copy(B))
mul!(C::AbstractMatrix, A::Diagonal, B::Transpose{<:Any,<:AbstractVecOrMat}) = mul!(C, A, copy(B))
mul!(C::AbstractMatrix, A::Adjoint{<:Any,<:Diagonal}, B::Adjoint{<:Any,<:AbstractVecOrMat}) = mul!(C, A, copy(B))
mul!(C::AbstractMatrix, A::Adjoint{<:Any,<:Diagonal}, B::Transpose{<:Any,<:AbstractVecOrMat}) = mul!(C, A, copy(B))
mul!(C::AbstractMatrix, A::Transpose{<:Any,<:Diagonal}, B::Adjoint{<:Any,<:AbstractVecOrMat}) = mul!(C, A, copy(B))
mul!(C::AbstractMatrix, A::Transpose{<:Any,<:Diagonal}, B::Transpose{<:Any,<:AbstractVecOrMat}) = mul!(C, A, copy(B))


# ambiguities with Symmetric/Hermitian
# RealHermSymComplex[Sym]/[Herm] only include Number; invariant to [c]transpose
*(A::Diagonal, transB::Transpose{<:Any,<:RealHermSymComplexSym}) = A * transB.parent
Expand Down Expand Up @@ -478,8 +470,10 @@ function svdfact(D::Diagonal)
end

# dismabiguation methods: * of Diagonal and Adj/Trans AbsVec
*(A::Diagonal, B::Adjoint{<:Any,<:AbstractVector}) = A * copy(B)
*(A::Diagonal, B::Transpose{<:Any,<:AbstractVector}) = A * copy(B)
*(A::Adjoint{<:Any,<:AbstractVector}, B::Diagonal) = copy(A) * B
*(A::Transpose{<:Any,<:AbstractVector}, B::Diagonal) = copy(A) * B
*(x::Adjoint{<:Any,<:AbstractVector}, D::Diagonal) = Adjoint(map((t,s) -> t'*s, D.diag, parent(x)))
*(x::Adjoint{<:Any,<:AbstractVector}, D::Diagonal, y::AbstractVector) =
mapreduce(t -> t[1]*t[2]*t[3], +, zip(x, D.diag, y))
*(x::Transpose{<:Any,<:AbstractVector}, D::Diagonal) = Transpose(map(*, D.diag, parent(x)))
*(x::Transpose{<:Any,<:AbstractVector}, D::Diagonal, y::AbstractVector) =
mapreduce(t -> t[1]*t[2]*t[3], +, zip(x, D.diag, y))
# TODO: these methods will yield row matrices, rather than adjoint/transpose vectors
4 changes: 0 additions & 4 deletions stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,8 @@ end
# these will throw a DimensionMismatch unless B has 1 row (or 1 col for transposed case):
*(a::AbstractVector, transB::Transpose{<:Any,<:AbstractMatrix}) =
(B = transB.parent; *(reshape(a,length(a),1), transpose(B)))
*(A::AbstractMatrix, transb::Transpose{<:Any,<:AbstractVector}) =
(b = transb.parent; *(A, transpose(reshape(b,length(b),1))))
*(a::AbstractVector, adjB::Adjoint{<:Any,<:AbstractMatrix}) =
(B = adjB.parent; *(reshape(a,length(a),1), adjoint(B)))
*(A::AbstractMatrix, adjb::Adjoint{<:Any,<:AbstractVector}) =
(b = adjb.parent; *(A, adjoint(reshape(b,length(b),1))))
(*)(a::AbstractVector, B::AbstractMatrix) = reshape(a,length(a),1)*B

mul!(y::StridedVector{T}, A::StridedVecOrMat{T}, x::StridedVector{T}) where {T<:BlasFloat} = gemv!(y, 'N', A, x)
Expand Down
7 changes: 7 additions & 0 deletions stdlib/LinearAlgebra/test/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -433,4 +433,11 @@ end
@test Diagonal(transpose([1, 2, 3])) == Diagonal([1 2 3])
end

@testset "Multiplication with Adjoint and Transpose vectors (#26863)" begin
x = rand(5)
D = Diagonal(rand(5))
@test x'*D*x == (x'*D)*x == (x'*Array(D))*x
@test Transpose(x)*D*x == (Transpose(x)*D)*x == (Transpose(x)*Array(D))*x
end

end # module TestDiagonal

0 comments on commit fdf1682

Please sign in to comment.