Skip to content

Commit

Permalink
minor fixes in multiplication with Diagonals (#31443)
Browse files Browse the repository at this point in the history
* minor fixes in multiplication with Diagonals

* correct rmul!(A,D), revert changes in AdjTrans(x)*D

* [r/l]mul!: replace conj by adjoint, add transpose

* add tests

* fix typo

* relax some tests, added more tests

* simplify tests, strict equality
  • Loading branch information
dkarrasch authored and mbauman committed Apr 4, 2019
1 parent b471640 commit a93185f
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 12 deletions.
13 changes: 6 additions & 7 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ end

function rmul!(A::AbstractMatrix, D::Diagonal)
require_one_based_indexing(A)
A .= A .* transpose(D.diag)
A .= A .* permutedims(D.diag)
return A
end

Expand Down Expand Up @@ -260,20 +260,20 @@ lmul!(A::Diagonal, B::Diagonal) = Diagonal(B.diag .= A.diag .* B.diag)

function lmul!(adjA::Adjoint{<:Any,<:Diagonal}, B::AbstractMatrix)
A = adjA.parent
return lmul!(conj(A.diag), B)
return lmul!(adjoint(A), B)
end
function lmul!(transA::Transpose{<:Any,<:Diagonal}, B::AbstractMatrix)
A = transA.parent
return lmul!(A.diag, B)
return lmul!(transpose(A), B)
end

function rmul!(A::AbstractMatrix, adjB::Adjoint{<:Any,<:Diagonal})
B = adjB.parent
return rmul!(A, conj(B.diag))
return rmul!(A, adjoint(B))
end
function rmul!(A::AbstractMatrix, transB::Transpose{<:Any,<:Diagonal})
B = transB.parent
return rmul!(A, B.diag)
return rmul!(A, transpose(B))
end

# Get ambiguous method if try to unify AbstractVector/AbstractMatrix here using AbstractVecOrMat
Expand Down Expand Up @@ -552,10 +552,9 @@ end
*(x::Adjoint{<:Any,<:AbstractVector}, D::Diagonal) = Adjoint(map((t,s) -> t'*s, D.diag, parent(x)))
*(x::Adjoint{<:Any,<:AbstractVector}, D::Diagonal, y::AbstractVector) =
mapreduce(t -> t[1]*t[2]*t[3], +, zip(x, D.diag, y))
*(x::Transpose{<:Any,<:AbstractVector}, D::Diagonal) = Transpose(map(*, D.diag, parent(x)))
*(x::Transpose{<:Any,<:AbstractVector}, D::Diagonal) = Transpose(map((t,s) -> transpose(t)*s, D.diag, parent(x)))
*(x::Transpose{<:Any,<:AbstractVector}, D::Diagonal, y::AbstractVector) =
mapreduce(t -> t[1]*t[2]*t[3], +, zip(x, D.diag, y))
# TODO: these methods will yield row matrices, rather than adjoint/transpose vectors

function cholesky!(A::Diagonal, ::Val{false} = Val(false); check::Bool = true)
info = 0
Expand Down
26 changes: 21 additions & 5 deletions stdlib/LinearAlgebra/test/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -468,10 +468,20 @@ end
fullBB = copyto!(Matrix{Matrix{T}}(undef, 2, 2), BB)
for (transform1, transform2) in ((identity, identity),
(identity, adjoint ), (adjoint, identity ), (adjoint, adjoint ),
(identity, transpose), (transpose, identity ), (transpose, transpose) )
(identity, transpose), (transpose, identity ), (transpose, transpose),
(identity, Adjoint ), (Adjoint, identity ), (Adjoint, Adjoint ),
(identity, Transpose), (Transpose, identity ), (Transpose, Transpose))
@test *(transform1(D), transform2(B))::typeof(D) *(transform1(Matrix(D)), transform2(Matrix(B))) atol=2 * eps()
@test *(transform1(DD), transform2(BB))::typeof(DD) == *(transform1(fullDD), transform2(fullBB))
end
M = randn(T, 5, 5)
MM = [randn(T, 2, 2) for _ in 1:2, _ in 1:2]
for transform in (identity, adjoint, transpose, Adjoint, Transpose)
@test lmul!(transform(D), copy(M)) == *(transform(Matrix(D)), M)
@test rmul!(copy(M), transform(D)) == *(M, transform(Matrix(D)))
@test lmul!(transform(DD), copy(MM)) == *(transform(fullDD), MM)
@test rmul!(copy(MM), transform(DD)) == *(MM, transform(fullDD))
end
end
end

Expand All @@ -481,10 +491,16 @@ end
end

@testset "Multiplication with Adjoint and Transpose vectors (#26863)" begin
x = rand(5)
D = Diagonal(rand(5))
@test x'*D*x == (x'*D)*x == (x'*Array(D))*x
@test Transpose(x)*D*x == (Transpose(x)*D)*x == (Transpose(x)*Array(D))*x
x = collect(1:2)
xt = transpose(x)
A = reshape([[1 2; 3 4], zeros(Int,2,2), zeros(Int, 2, 2), [5 6; 7 8]], 2, 2)
D = Diagonal(A)
@test x'*D == x'*A == copy(x')*D == copy(x')*A
@test xt*D == xt*A == copy(xt)*D == copy(xt)*A
y = [x, x]
yt = transpose(y)
@test y'*D*y == (y'*D)*y == (y'*A)*y
@test yt*D*y == (yt*D)*y == (yt*A)*y
end

@testset "Triangular division by Diagonal #27989" begin
Expand Down

0 comments on commit a93185f

Please sign in to comment.