From 523da544311439ef12c9a77eb83c662a4f8d68c1 Mon Sep 17 00:00:00 2001 From: Andreas Noack Date: Fri, 27 Apr 2018 20:58:03 +0200 Subject: [PATCH] Fix Vector'Diagonal (ans Transpose as well) to avoid infinite recursion. Also add optimized methods for x'D*y to avoid allocating temporary vector --- stdlib/LinearAlgebra/src/diagonal.jl | 22 ++++++++-------------- stdlib/LinearAlgebra/src/matmul.jl | 4 ---- stdlib/LinearAlgebra/test/diagonal.jl | 7 +++++++ 3 files changed, 15 insertions(+), 18 deletions(-) diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index ed01cadd9f92f..95406701e3048 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -154,9 +154,9 @@ end (*)(D::Diagonal, B::AbstractTriangular) = lmul!(D, copy(B)) (*)(A::AbstractMatrix, D::Diagonal) = - mul!(similar(A, promote_op(*, eltype(A), eltype(D.diag)), size(A)), A, D) + rmul!(copyto!(similar(A, promote_op(*, eltype(A), eltype(D.diag)), size(A)), A), D) (*)(D::Diagonal, A::AbstractMatrix) = - mul!(similar(A, promote_op(*, eltype(A), eltype(D.diag)), size(A)), D, A) + lmul!(D, copyto!(similar(A, promote_op(*, eltype(A), eltype(D.diag)), size(A)), A)) function rmul!(A::AbstractMatrix, D::Diagonal) A .= A .* transpose(D.diag) @@ -271,14 +271,6 @@ mul!(out::AbstractMatrix, A::Diagonal, in::AbstractMatrix) = out .= A.diag .* in mul!(out::AbstractMatrix, A::Adjoint{<:Any,<:Diagonal}, in::AbstractMatrix) = out .= adjoint.(A.parent.diag) .* in mul!(out::AbstractMatrix, A::Transpose{<:Any,<:Diagonal}, in::AbstractMatrix) = out .= transpose.(A.parent.diag) .* in -mul!(C::AbstractMatrix, A::Diagonal, B::Adjoint{<:Any,<:AbstractVecOrMat}) = mul!(C, A, copy(B)) -mul!(C::AbstractMatrix, A::Diagonal, B::Transpose{<:Any,<:AbstractVecOrMat}) = mul!(C, A, copy(B)) -mul!(C::AbstractMatrix, A::Adjoint{<:Any,<:Diagonal}, B::Adjoint{<:Any,<:AbstractVecOrMat}) = mul!(C, A, copy(B)) -mul!(C::AbstractMatrix, A::Adjoint{<:Any,<:Diagonal}, B::Transpose{<:Any,<:AbstractVecOrMat}) = mul!(C, A, copy(B)) -mul!(C::AbstractMatrix, A::Transpose{<:Any,<:Diagonal}, B::Adjoint{<:Any,<:AbstractVecOrMat}) = mul!(C, A, copy(B)) -mul!(C::AbstractMatrix, A::Transpose{<:Any,<:Diagonal}, B::Transpose{<:Any,<:AbstractVecOrMat}) = mul!(C, A, copy(B)) - - # ambiguities with Symmetric/Hermitian # RealHermSymComplex[Sym]/[Herm] only include Number; invariant to [c]transpose *(A::Diagonal, transB::Transpose{<:Any,<:RealHermSymComplexSym}) = A * transB.parent @@ -478,8 +470,10 @@ function svdfact(D::Diagonal) end # dismabiguation methods: * of Diagonal and Adj/Trans AbsVec -*(A::Diagonal, B::Adjoint{<:Any,<:AbstractVector}) = A * copy(B) -*(A::Diagonal, B::Transpose{<:Any,<:AbstractVector}) = A * copy(B) -*(A::Adjoint{<:Any,<:AbstractVector}, B::Diagonal) = copy(A) * B -*(A::Transpose{<:Any,<:AbstractVector}, B::Diagonal) = copy(A) * B +*(x::Adjoint{<:Any,<:AbstractVector}, D::Diagonal) = Adjoint(map((t,s) -> t'*s, D.diag, parent(x))) +*(x::Adjoint{<:Any,<:AbstractVector}, D::Diagonal, y::AbstractVector) = + mapreduce(t -> t[1]*t[2]*t[3], +, zip(x, D.diag, y)) +*(x::Transpose{<:Any,<:AbstractVector}, D::Diagonal) = Transpose(map(*, D.diag, parent(x))) +*(x::Transpose{<:Any,<:AbstractVector}, D::Diagonal, y::AbstractVector) = + mapreduce(t -> t[1]*t[2]*t[3], +, zip(x, D.diag, y)) # TODO: these methods will yield row matrices, rather than adjoint/transpose vectors diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index 9619a5b630987..1c5bcbf027322 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -51,12 +51,8 @@ end # these will throw a DimensionMismatch unless B has 1 row (or 1 col for transposed case): *(a::AbstractVector, transB::Transpose{<:Any,<:AbstractMatrix}) = (B = transB.parent; *(reshape(a,length(a),1), transpose(B))) -*(A::AbstractMatrix, transb::Transpose{<:Any,<:AbstractVector}) = - (b = transb.parent; *(A, transpose(reshape(b,length(b),1)))) *(a::AbstractVector, adjB::Adjoint{<:Any,<:AbstractMatrix}) = (B = adjB.parent; *(reshape(a,length(a),1), adjoint(B))) -*(A::AbstractMatrix, adjb::Adjoint{<:Any,<:AbstractVector}) = - (b = adjb.parent; *(A, adjoint(reshape(b,length(b),1)))) (*)(a::AbstractVector, B::AbstractMatrix) = reshape(a,length(a),1)*B mul!(y::StridedVector{T}, A::StridedVecOrMat{T}, x::StridedVector{T}) where {T<:BlasFloat} = gemv!(y, 'N', A, x) diff --git a/stdlib/LinearAlgebra/test/diagonal.jl b/stdlib/LinearAlgebra/test/diagonal.jl index 540111c0012a2..06a5d78ba2c84 100644 --- a/stdlib/LinearAlgebra/test/diagonal.jl +++ b/stdlib/LinearAlgebra/test/diagonal.jl @@ -433,4 +433,11 @@ end @test Diagonal(transpose([1, 2, 3])) == Diagonal([1 2 3]) end +@testset "Multiplication with Adjoint and Transpose vectors (#26863)" begin + x = rand(5) + D = Diagonal(rand(5)) + @test x'*D*x == (x'*D)*x == (x'*Array(D))*x + @test Transpose(x)*D*x == (Transpose(x)*D)*x == (Transpose(x)*Array(D))*x +end + end # module TestDiagonal