From 5840fc61471124205cfd27853e128558c6e131cb Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Wed, 6 Dec 2023 22:26:01 +0530 Subject: [PATCH] Resolve some adj/trans and triangular matrix multiplication ambiguities (#325) * Remove ambiguity in transpose matrix * zeros vector * Resolve transpose vec * Zeros Matrix * disambiguate with transpose-adjoint wrapper * disambiguate against AbstractTriangular --- src/FillArrays.jl | 2 +- src/fillalgebra.jl | 38 ++++++++++++++++++++++++++++++-------- test/runtests.jl | 24 ++++++++++++++++++++++++ 3 files changed, 55 insertions(+), 9 deletions(-) diff --git a/src/FillArrays.jl b/src/FillArrays.jl index 6b68c41f..ac4f4ee6 100644 --- a/src/FillArrays.jl +++ b/src/FillArrays.jl @@ -11,7 +11,7 @@ import Base: size, getindex, setindex!, IndexStyle, checkbounds, convert, import LinearAlgebra: rank, svdvals!, tril, triu, tril!, triu!, diag, transpose, adjoint, fill!, dot, norm2, norm1, normInf, normMinusInf, normp, lmul!, rmul!, diagzero, AdjointAbsVec, TransposeAbsVec, - issymmetric, ishermitian, AdjOrTransAbsVec, checksquare, mul!, kron + issymmetric, ishermitian, AdjOrTransAbsVec, checksquare, mul!, kron, AbstractTriangular import Base.Broadcast: broadcasted, DefaultArrayStyle, broadcast_shape, BroadcastStyle, Broadcasted diff --git a/src/fillalgebra.jl b/src/fillalgebra.jl index c849ac22..043cc4eb 100644 --- a/src/fillalgebra.jl +++ b/src/fillalgebra.jl @@ -93,9 +93,18 @@ mult_ones(a, b) = mult_ones(a, b, mult_axes(a, b)) *(a::AbstractFillMatrix, b::AbstractZerosMatrix) = mult_zeros(a, b) *(a::AbstractFillMatrix, b::AbstractZerosVector) = mult_zeros(a, b) -*(a::AbstractZerosMatrix, b::AbstractMatrix) = mult_zeros(a, b) -*(a::AbstractMatrix, b::AbstractZerosVector) = mult_zeros(a, b) -*(a::AbstractMatrix, b::AbstractZerosMatrix) = mult_zeros(a, b) +for MT in (:AbstractMatrix, :AbstractTriangular) + @eval *(a::AbstractZerosMatrix, b::$MT) = mult_zeros(a, b) + @eval *(a::$MT, b::AbstractZerosMatrix) = mult_zeros(a, b) +end +# Odd way to deal with the type-parameters to avoid ambiguities +for MT in (:(AbstractMatrix{T}), :(Transpose{<:Any, <:AbstractMatrix{T}}), :(Adjoint{<:Any, <:AbstractMatrix{T}}), + :(AbstractTriangular{T})) + @eval *(a::$MT, b::AbstractZerosVector) where {T} = mult_zeros(a, b) +end +for MT in (:(Transpose{<:Any, <:AbstractVector}), :(Adjoint{<:Any, <:AbstractVector})) + @eval *(a::$MT, b::AbstractZerosMatrix) = mult_zeros(a, b) +end *(a::AbstractZerosMatrix, b::AbstractVector) = mult_zeros(a, b) function lmul_diag(a::Diagonal, b) @@ -290,13 +299,25 @@ function _adjvec_mul_zeros(a, b) return a1 * b[1] end -*(a::AdjointAbsVec{<:Any,<:AbstractZerosVector}, b::AbstractMatrix) = (b' * a')' +for MT in (:AbstractMatrix, :AbstractTriangular, :(Adjoint{<:Any,<:TransposeAbsVec})) + @eval *(a::AdjointAbsVec{<:Any,<:AbstractZerosVector}, b::$MT) = (b' * a')' +end +# ambiguity +function *(a::AdjointAbsVec{<:Any,<:AbstractZerosVector}, b::TransposeAbsVec{<:Any,<:AdjointAbsVec}) + # change from Transpose ∘ Adjoint to Adjoint ∘ Transpose + b2 = adjoint(transpose(adjoint(transpose(b)))) + a * b2 +end *(a::AdjointAbsVec{<:Any,<:AbstractZerosVector}, b::AbstractZerosMatrix) = (b' * a')' -*(a::TransposeAbsVec{<:Any,<:AbstractZerosVector}, b::AbstractMatrix) = transpose(transpose(b) * transpose(a)) +for MT in (:AbstractMatrix, :AbstractTriangular, :(Transpose{<:Any,<:AdjointAbsVec})) + @eval *(a::TransposeAbsVec{<:Any,<:AbstractZerosVector}, b::$MT) = transpose(transpose(b) * transpose(a)) +end *(a::TransposeAbsVec{<:Any,<:AbstractZerosVector}, b::AbstractZerosMatrix) = transpose(transpose(b) * transpose(a)) *(a::AbstractVector, b::AdjOrTransAbsVec{<:Any,<:AbstractZerosVector}) = a * permutedims(parent(b)) -*(a::AbstractMatrix, b::AdjOrTransAbsVec{<:Any,<:AbstractZerosVector}) = a * permutedims(parent(b)) +for MT in (:AbstractMatrix, :AbstractTriangular) + @eval *(a::$MT, b::AdjOrTransAbsVec{<:Any,<:AbstractZerosVector}) = a * permutedims(parent(b)) +end *(a::AbstractZerosVector, b::AdjOrTransAbsVec{<:Any,<:AbstractZerosVector}) = a * permutedims(parent(b)) *(a::AbstractZerosMatrix, b::AdjOrTransAbsVec{<:Any,<:AbstractZerosVector}) = a * permutedims(parent(b)) @@ -307,7 +328,8 @@ end *(a::Adjoint{T, <:AbstractMatrix{T}} where T, b::AbstractZeros{<:Any, 1}) = mult_zeros(a, b) -*(D::Diagonal, a::AdjointAbsVec{<:Any,<:AbstractZerosVector}) = (a' * D')' +*(D::Diagonal, a::Adjoint{<:Any,<:AbstractZerosVector}) = (a' * D')' +*(D::Diagonal, a::Transpose{<:Any,<:AbstractZerosVector}) = transpose(transpose(a) * transpose(D)) *(a::AdjointAbsVec{<:Any,<:AbstractZerosVector}, D::Diagonal) = (D' * a')' *(a::TransposeAbsVec{<:Any,<:AbstractZerosVector}, D::Diagonal) = transpose(D*transpose(a)) function _triple_zeromul(x, D::Diagonal, y) @@ -325,7 +347,7 @@ end *(x::TransposeAbsVec{<:Any,<:AbstractZerosVector}, D::Diagonal, y::AbstractZerosVector) = _triple_zeromul(x, D, y) -function *(a::Transpose{T, <:AbstractVector{T}}, b::AbstractZerosVector{T}) where T<:Real +function *(a::Transpose{T, <:AbstractVector}, b::AbstractZerosVector{T}) where T<:Real la, lb = length(a), length(b) if la ≠ lb throw(DimensionMismatch("dot product arguments have lengths $la and $lb")) diff --git a/test/runtests.jl b/test/runtests.jl index 7ef4b1a8..e67b62b5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1579,6 +1579,14 @@ end @test A*Zeros(nA,1) ≡ Zeros(mA,1) @test a*Zeros(na,3) ≡ Zeros(la,3) + @test transpose(A) * Zeros(mA) ≡ Zeros(nA) + @test A' * Zeros(mA) ≡ Zeros(nA) + + @test transpose(a) * Zeros(la, 3) ≡ Zeros(1,3) + @test a' * Zeros(la,3) ≡ Zeros(1,3) + + @test Zeros(la)' * Transpose(Adjoint(a)) == 0.0 + w = zeros(mA) @test mul!(w, A, Fill(2,nA), true, false) ≈ A * fill(2,nA) w .= 2 @@ -1658,6 +1666,22 @@ end @test adjoint(A)*fillvec ≈ adjoint(A)*Array(fillvec) @test adjoint(A)*fillmat ≈ adjoint(A)*Array(fillmat) end + + @testset "ambiguities" begin + UT33 = UpperTriangular(ones(3,3)) + UT11 = UpperTriangular(ones(1,1)) + @test transpose(Zeros(3)) * Transpose(Adjoint([1,2,3])) == 0 + @test Zeros(3)' * Adjoint(Transpose([1,2,3])) == 0 + @test Zeros(3)' * UT33 == Zeros(3)' + @test transpose(Zeros(3)) * UT33 == transpose(Zeros(3)) + @test UT11 * Zeros(3)' == Zeros(1,3) + @test UT11 * transpose(Zeros(3)) == Zeros(1,3) + @test Zeros(2,3) * UT33 == Zeros(2,3) + @test UT33 * Zeros(3,2) == Zeros(3,2) + @test UT33 * Zeros(3) == Zeros(3) + @test Diagonal([1]) * transpose(Zeros(3)) == Zeros(1,3) + @test Diagonal([1]) * Zeros(3)' == Zeros(1,3) + end end @testset "count" begin