Skip to content

Commit

Permalink
Resolve some adj/trans and triangular matrix multiplication ambiguiti…
Browse files Browse the repository at this point in the history
…es (#325)

* Remove ambiguity in transpose matrix * zeros vector

* Resolve transpose vec * Zeros Matrix

* disambiguate with transpose-adjoint wrapper

* disambiguate against AbstractTriangular
  • Loading branch information
jishnub authored Dec 6, 2023
1 parent cf8c78d commit 5840fc6
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 9 deletions.
2 changes: 1 addition & 1 deletion src/FillArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import Base: size, getindex, setindex!, IndexStyle, checkbounds, convert,

import LinearAlgebra: rank, svdvals!, tril, triu, tril!, triu!, diag, transpose, adjoint, fill!,
dot, norm2, norm1, normInf, normMinusInf, normp, lmul!, rmul!, diagzero, AdjointAbsVec, TransposeAbsVec,
issymmetric, ishermitian, AdjOrTransAbsVec, checksquare, mul!, kron
issymmetric, ishermitian, AdjOrTransAbsVec, checksquare, mul!, kron, AbstractTriangular


import Base.Broadcast: broadcasted, DefaultArrayStyle, broadcast_shape, BroadcastStyle, Broadcasted
Expand Down
38 changes: 30 additions & 8 deletions src/fillalgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,18 @@ mult_ones(a, b) = mult_ones(a, b, mult_axes(a, b))
*(a::AbstractFillMatrix, b::AbstractZerosMatrix) = mult_zeros(a, b)
*(a::AbstractFillMatrix, b::AbstractZerosVector) = mult_zeros(a, b)

*(a::AbstractZerosMatrix, b::AbstractMatrix) = mult_zeros(a, b)
*(a::AbstractMatrix, b::AbstractZerosVector) = mult_zeros(a, b)
*(a::AbstractMatrix, b::AbstractZerosMatrix) = mult_zeros(a, b)
for MT in (:AbstractMatrix, :AbstractTriangular)
@eval *(a::AbstractZerosMatrix, b::$MT) = mult_zeros(a, b)
@eval *(a::$MT, b::AbstractZerosMatrix) = mult_zeros(a, b)
end
# Odd way to deal with the type-parameters to avoid ambiguities
for MT in (:(AbstractMatrix{T}), :(Transpose{<:Any, <:AbstractMatrix{T}}), :(Adjoint{<:Any, <:AbstractMatrix{T}}),
:(AbstractTriangular{T}))
@eval *(a::$MT, b::AbstractZerosVector) where {T} = mult_zeros(a, b)
end
for MT in (:(Transpose{<:Any, <:AbstractVector}), :(Adjoint{<:Any, <:AbstractVector}))
@eval *(a::$MT, b::AbstractZerosMatrix) = mult_zeros(a, b)
end
*(a::AbstractZerosMatrix, b::AbstractVector) = mult_zeros(a, b)

function lmul_diag(a::Diagonal, b)
Expand Down Expand Up @@ -290,13 +299,25 @@ function _adjvec_mul_zeros(a, b)
return a1 * b[1]
end

*(a::AdjointAbsVec{<:Any,<:AbstractZerosVector}, b::AbstractMatrix) = (b' * a')'
for MT in (:AbstractMatrix, :AbstractTriangular, :(Adjoint{<:Any,<:TransposeAbsVec}))
@eval *(a::AdjointAbsVec{<:Any,<:AbstractZerosVector}, b::$MT) = (b' * a')'
end
# ambiguity
function *(a::AdjointAbsVec{<:Any,<:AbstractZerosVector}, b::TransposeAbsVec{<:Any,<:AdjointAbsVec})
# change from Transpose ∘ Adjoint to Adjoint ∘ Transpose
b2 = adjoint(transpose(adjoint(transpose(b))))
a * b2
end
*(a::AdjointAbsVec{<:Any,<:AbstractZerosVector}, b::AbstractZerosMatrix) = (b' * a')'
*(a::TransposeAbsVec{<:Any,<:AbstractZerosVector}, b::AbstractMatrix) = transpose(transpose(b) * transpose(a))
for MT in (:AbstractMatrix, :AbstractTriangular, :(Transpose{<:Any,<:AdjointAbsVec}))
@eval *(a::TransposeAbsVec{<:Any,<:AbstractZerosVector}, b::$MT) = transpose(transpose(b) * transpose(a))
end
*(a::TransposeAbsVec{<:Any,<:AbstractZerosVector}, b::AbstractZerosMatrix) = transpose(transpose(b) * transpose(a))

*(a::AbstractVector, b::AdjOrTransAbsVec{<:Any,<:AbstractZerosVector}) = a * permutedims(parent(b))
*(a::AbstractMatrix, b::AdjOrTransAbsVec{<:Any,<:AbstractZerosVector}) = a * permutedims(parent(b))
for MT in (:AbstractMatrix, :AbstractTriangular)
@eval *(a::$MT, b::AdjOrTransAbsVec{<:Any,<:AbstractZerosVector}) = a * permutedims(parent(b))
end
*(a::AbstractZerosVector, b::AdjOrTransAbsVec{<:Any,<:AbstractZerosVector}) = a * permutedims(parent(b))
*(a::AbstractZerosMatrix, b::AdjOrTransAbsVec{<:Any,<:AbstractZerosVector}) = a * permutedims(parent(b))

Expand All @@ -307,7 +328,8 @@ end

*(a::Adjoint{T, <:AbstractMatrix{T}} where T, b::AbstractZeros{<:Any, 1}) = mult_zeros(a, b)

*(D::Diagonal, a::AdjointAbsVec{<:Any,<:AbstractZerosVector}) = (a' * D')'
*(D::Diagonal, a::Adjoint{<:Any,<:AbstractZerosVector}) = (a' * D')'
*(D::Diagonal, a::Transpose{<:Any,<:AbstractZerosVector}) = transpose(transpose(a) * transpose(D))
*(a::AdjointAbsVec{<:Any,<:AbstractZerosVector}, D::Diagonal) = (D' * a')'
*(a::TransposeAbsVec{<:Any,<:AbstractZerosVector}, D::Diagonal) = transpose(D*transpose(a))
function _triple_zeromul(x, D::Diagonal, y)
Expand All @@ -325,7 +347,7 @@ end
*(x::TransposeAbsVec{<:Any,<:AbstractZerosVector}, D::Diagonal, y::AbstractZerosVector) = _triple_zeromul(x, D, y)


function *(a::Transpose{T, <:AbstractVector{T}}, b::AbstractZerosVector{T}) where T<:Real
function *(a::Transpose{T, <:AbstractVector}, b::AbstractZerosVector{T}) where T<:Real
la, lb = length(a), length(b)
if la lb
throw(DimensionMismatch("dot product arguments have lengths $la and $lb"))
Expand Down
24 changes: 24 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1579,6 +1579,14 @@ end
@test A*Zeros(nA,1) Zeros(mA,1)
@test a*Zeros(na,3) Zeros(la,3)

@test transpose(A) * Zeros(mA) Zeros(nA)
@test A' * Zeros(mA) Zeros(nA)

@test transpose(a) * Zeros(la, 3) Zeros(1,3)
@test a' * Zeros(la,3) Zeros(1,3)

@test Zeros(la)' * Transpose(Adjoint(a)) == 0.0

w = zeros(mA)
@test mul!(w, A, Fill(2,nA), true, false) A * fill(2,nA)
w .= 2
Expand Down Expand Up @@ -1658,6 +1666,22 @@ end
@test adjoint(A)*fillvec adjoint(A)*Array(fillvec)
@test adjoint(A)*fillmat adjoint(A)*Array(fillmat)
end

@testset "ambiguities" begin
UT33 = UpperTriangular(ones(3,3))
UT11 = UpperTriangular(ones(1,1))
@test transpose(Zeros(3)) * Transpose(Adjoint([1,2,3])) == 0
@test Zeros(3)' * Adjoint(Transpose([1,2,3])) == 0
@test Zeros(3)' * UT33 == Zeros(3)'
@test transpose(Zeros(3)) * UT33 == transpose(Zeros(3))
@test UT11 * Zeros(3)' == Zeros(1,3)
@test UT11 * transpose(Zeros(3)) == Zeros(1,3)
@test Zeros(2,3) * UT33 == Zeros(2,3)
@test UT33 * Zeros(3,2) == Zeros(3,2)
@test UT33 * Zeros(3) == Zeros(3)
@test Diagonal([1]) * transpose(Zeros(3)) == Zeros(1,3)
@test Diagonal([1]) * Zeros(3)' == Zeros(1,3)
end
end

@testset "count" begin
Expand Down

0 comments on commit 5840fc6

Please sign in to comment.