Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disambiguate mul! for matvec and matmat through indirection #52837

Merged
merged 6 commits into from
Jan 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ Standard library changes
* There is now a specialized dispatch for `eigvals/eigen(::Hermitian{<:Tridiagonal})` which performs a similarity transformation to create a real symmetrix triagonal matrix, and solve that using the LAPACK routines ([#49546]).
* Structured matrices now retain either the axes of the parent (for `Symmetric`/`Hermitian`/`AbstractTriangular`/`UpperHessenberg`), or that of the principal diagonal (for banded matrices) ([#52480]).
* `bunchkaufman` and `bunchkaufman!` now work for any `AbstractFloat`, `Rational` and their complex variants. `bunchkaufman` now supports `Integer` types, by making an internal conversion to `Rational{BigInt}`. Added new function `inertia` that computes the inertia of the diagonal factor given by the `BunchKaufman` factorization object of a real symmetric or Hermitian matrix. For complex symmetric matrices, `inertia` only computes the number of zero eigenvalues of the diagonal factor ([#51487]).
* Packages that specialize matrix-matrix `mul!` with a method signature of the form `mul!(::AbstractMatrix, ::MyMatrix, ::AbstractMatrix, ::Number, ::Number)` no longer encounter method ambiguities when interacting with `LinearAlgebra`. Previously, ambiguities used to arise when multiplying a `MyMatrix` with a structured matrix type provided by LinearAlgebra, such as `AbstractTriangular`, which used to necessitate additional methods to resolve such ambiguities. Similar sources of ambiguities have also been removed for matrix-vector `mul!` operations ([#52837]).

#### Logging
* New `@create_log_macro` macro for creating new log macros like `@info`, `@warn` etc. For instance
Expand Down
26 changes: 13 additions & 13 deletions stdlib/LinearAlgebra/src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -429,11 +429,11 @@ const BandedMatrix = Union{Bidiagonal,Diagonal,Tridiagonal,SymTridiagonal} # or
const BiTriSym = Union{Bidiagonal,Tridiagonal,SymTridiagonal}
const TriSym = Union{Tridiagonal,SymTridiagonal}
const BiTri = Union{Bidiagonal,Tridiagonal}
@inline mul!(C::AbstractVector, A::BandedMatrix, B::AbstractVector, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractVector, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractMatrix, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::AbstractMatrix, B::BandedMatrix, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::BandedMatrix, B::BandedMatrix, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
@inline _mul!(C::AbstractVector, A::BandedMatrix, B::AbstractVector, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
@inline _mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractVector, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
@inline _mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractMatrix, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
@inline _mul!(C::AbstractMatrix, A::AbstractMatrix, B::BandedMatrix, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
@inline _mul!(C::AbstractMatrix, A::BandedMatrix, B::BandedMatrix, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))

lmul!(A::Bidiagonal, B::AbstractVecOrMat) = @inline _mul!(B, A, B, MulAddMul())
rmul!(B::AbstractMatrix, A::Bidiagonal) = @inline _mul!(B, B, A, MulAddMul())
Expand Down Expand Up @@ -465,9 +465,9 @@ function _diag(A::Bidiagonal, k)
end
end

_mul!(C::AbstractMatrix, A::BiTriSym, B::TriSym, _add::MulAddMul = MulAddMul()) =
_mul!(C::AbstractMatrix, A::BiTriSym, B::TriSym, _add::MulAddMul) =
_bibimul!(C, A, B, _add)
_mul!(C::AbstractMatrix, A::BiTriSym, B::Bidiagonal, _add::MulAddMul = MulAddMul()) =
_mul!(C::AbstractMatrix, A::BiTriSym, B::Bidiagonal, _add::MulAddMul) =
_bibimul!(C, A, B, _add)
function _bibimul!(C, A, B, _add)
check_A_mul_B!_sizes(C, A, B)
Expand Down Expand Up @@ -526,7 +526,7 @@ function _bibimul!(C, A, B, _add)
C
end

function _mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, _add::MulAddMul = MulAddMul())
function _mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, _add::MulAddMul)
require_one_based_indexing(C)
check_A_mul_B!_sizes(C, A, B)
n = size(A,1)
Expand Down Expand Up @@ -562,7 +562,7 @@ function _mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, _add::MulAddMul = Mu
C
end

function _mul!(C::AbstractVecOrMat, A::BiTriSym, B::AbstractVecOrMat, _add::MulAddMul = MulAddMul())
function _mul!(C::AbstractVecOrMat, A::BiTriSym, B::AbstractVecOrMat, _add::MulAddMul)
require_one_based_indexing(C, B)
nA = size(A,1)
nB = size(B,2)
Expand Down Expand Up @@ -592,7 +592,7 @@ function _mul!(C::AbstractVecOrMat, A::BiTriSym, B::AbstractVecOrMat, _add::MulA
C
end

function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::TriSym, _add::MulAddMul = MulAddMul())
function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::TriSym, _add::MulAddMul)
require_one_based_indexing(C, A)
check_A_mul_B!_sizes(C, A, B)
iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)
Expand Down Expand Up @@ -627,7 +627,7 @@ function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::TriSym, _add::MulAddMul
C
end

function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::Bidiagonal, _add::MulAddMul = MulAddMul())
function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::Bidiagonal, _add::MulAddMul)
require_one_based_indexing(C, A)
check_A_mul_B!_sizes(C, A, B)
iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)
Expand All @@ -653,9 +653,9 @@ function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::Bidiagonal, _add::MulAdd
C
end

_mul!(C::AbstractMatrix, A::Diagonal, B::Bidiagonal, _add::MulAddMul = MulAddMul()) =
_mul!(C::AbstractMatrix, A::Diagonal, B::Bidiagonal, _add::MulAddMul) =
_dibimul!(C, A, B, _add)
_mul!(C::AbstractMatrix, A::Diagonal, B::TriSym, _add::MulAddMul = MulAddMul()) =
_mul!(C::AbstractMatrix, A::Diagonal, B::TriSym, _add::MulAddMul) =
_dibimul!(C, A, B, _add)
function _dibimul!(C, A, B, _add)
require_one_based_indexing(C)
Expand Down
9 changes: 8 additions & 1 deletion stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,14 @@ end
(*)(a::AbstractVector, adjB::AdjointAbsMat) = reshape(a, length(a), 1) * adjB
(*)(a::AbstractVector, B::AbstractMatrix) = reshape(a, length(a), 1) * B

# Add a level of indirection and specialize _mul! to avoid ambiguities in mul!
@inline mul!(y::AbstractVector, A::AbstractVecOrMat, x::AbstractVector,
alpha::Number, beta::Number) = _mul!(y, A, x, alpha, beta)

@inline _mul!(y::AbstractVector, A::AbstractVecOrMat, x::AbstractVector,
alpha::Number, beta::Number) =
generic_matvecmul!(y, wrapper_char(A), _unwrap(A), x, MulAddMul(alpha, beta))

# BLAS cases
# equal eltypes
@inline generic_matvecmul!(y::StridedVector{T}, tA, A::StridedVecOrMat{T}, x::StridedVector{T},
Expand Down Expand Up @@ -277,7 +282,9 @@ julia> C == A * B * α + C_original * β
true
```
"""
@inline mul!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat, α::Number, β::Number) =
@inline mul!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat, α::Number, β::Number) = _mul!(C, A, B, α, β)
# Add a level of indirection and specialize _mul! to avoid ambiguities in mul!
@inline _mul!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat, α::Number, β::Number) =
generic_matmatmul!(
C,
wrapper_char(A),
Expand Down
6 changes: 2 additions & 4 deletions stdlib/LinearAlgebra/src/special.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,9 @@ for op in (:+, :-)
end

# disambiguation between triangular and banded matrices, banded ones "dominate"
mul!(C::AbstractMatrix, A::AbstractTriangular, B::BandedMatrix) = _mul!(C, A, B, MulAddMul())
mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractTriangular) = _mul!(C, A, B, MulAddMul())
mul!(C::AbstractMatrix, A::AbstractTriangular, B::BandedMatrix, alpha::Number, beta::Number) =
_mul!(C::AbstractMatrix, A::AbstractTriangular, B::BandedMatrix, alpha::Number, beta::Number) =
_mul!(C, A, B, MulAddMul(alpha, beta))
mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractTriangular, alpha::Number, beta::Number) =
_mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractTriangular, alpha::Number, beta::Number) =
_mul!(C, A, B, MulAddMul(alpha, beta))

function *(H::UpperHessenberg, B::Bidiagonal)
Expand Down
4 changes: 2 additions & 2 deletions stdlib/LinearAlgebra/src/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -803,7 +803,7 @@ rmul!(A::AbstractMatrix, B::AbstractTriangular) = @inline _trimul!(A, A, B)


for TC in (:AbstractVector, :AbstractMatrix)
@eval @inline function mul!(C::$TC, A::AbstractTriangular, B::AbstractVector, alpha::Number, beta::Number)
@eval @inline function _mul!(C::$TC, A::AbstractTriangular, B::AbstractVector, alpha::Number, beta::Number)
if isone(alpha) && iszero(beta)
return _trimul!(C, A, B)
else
Expand All @@ -815,7 +815,7 @@ for (TA, TB) in ((:AbstractTriangular, :AbstractMatrix),
(:AbstractMatrix, :AbstractTriangular),
(:AbstractTriangular, :AbstractTriangular)
)
@eval @inline function mul!(C::AbstractMatrix, A::$TA, B::$TB, alpha::Number, beta::Number)
@eval @inline function _mul!(C::AbstractMatrix, A::$TA, B::$TB, alpha::Number, beta::Number)
if isone(alpha) && iszero(beta)
return _trimul!(C, A, B)
else
Expand Down
23 changes: 23 additions & 0 deletions stdlib/LinearAlgebra/test/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ using .Main.FillArrays
isdefined(Main, :OffsetArrays) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "OffsetArrays.jl"))
using .Main.OffsetArrays

isdefined(Main, :SizedArrays) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "SizedArrays.jl"))
using .Main.SizedArrays

include("testutils.jl") # test_approx_eq_modphase

n = 10 #Size of test matrix
Expand Down Expand Up @@ -861,4 +864,24 @@ end
@test axes(B) === (ax, ax)
end

@testset "avoid matmul ambiguities with ::MyMatrix * ::AbstractMatrix" begin
A = [i+j for i in 1:2, j in 1:2]
S = SizedArrays.SizedArray{(2,2)}(A)
B = Bidiagonal([1:2;], [1;], :U)
@test S * B == A * B
@test B * S == B * A
C1, C2 = zeros(2,2), zeros(2,2)
@test mul!(C1, S, B) == mul!(C2, A, B)
@test mul!(C1, S, B, 1, 2) == mul!(C2, A, B, 1 ,2)
@test mul!(C1, B, S) == mul!(C2, B, A)
@test mul!(C1, B, S, 1, 2) == mul!(C2, B, A, 1 ,2)

v = [i for i in 1:2]
sv = SizedArrays.SizedArray{(2,)}(v)
@test B * sv == B * v
C1, C2 = zeros(2), zeros(2)
@test mul!(C1, B, sv) == mul!(C2, B, v)
@test mul!(C1, B, sv, 1, 2) == mul!(C2, B, v, 1 ,2)
end

end # module TestBidiagonal
12 changes: 12 additions & 0 deletions stdlib/LinearAlgebra/test/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1248,6 +1248,18 @@ end
D = Diagonal([1:2;])
@test S * D == A * D
@test D * S == D * A
C1, C2 = zeros(2,2), zeros(2,2)
@test mul!(C1, S, D) == mul!(C2, A, D)
@test mul!(C1, S, D, 1, 2) == mul!(C2, A, D, 1 ,2)
@test mul!(C1, D, S) == mul!(C2, D, A)
@test mul!(C1, D, S, 1, 2) == mul!(C2, D, A, 1 ,2)

v = [i for i in 1:2]
sv = SizedArrays.SizedArray{(2,)}(v)
@test D * sv == D * v
C1, C2 = zeros(2), zeros(2)
@test mul!(C1, D, sv) == mul!(C2, D, v)
@test mul!(C1, D, sv, 1, 2) == mul!(C2, D, v, 1 ,2)
end

@testset "copy" begin
Expand Down
12 changes: 12 additions & 0 deletions stdlib/LinearAlgebra/test/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -902,6 +902,18 @@ end
U = UpperTriangular(ones(2,2))
@test S * U == A * U
@test U * S == U * A
C1, C2 = zeros(2,2), zeros(2,2)
@test mul!(C1, S, U) == mul!(C2, A, U)
@test mul!(C1, S, U, 1, 2) == mul!(C2, A, U, 1 ,2)
@test mul!(C1, U, S) == mul!(C2, U, A)
@test mul!(C1, U, S, 1, 2) == mul!(C2, U, A, 1 ,2)

v = [i for i in 1:2]
sv = SizedArrays.SizedArray{(2,)}(v)
@test U * sv == U * v
C1, C2 = zeros(2), zeros(2)
@test mul!(C1, U, sv) == mul!(C2, U, v)
@test mul!(C1, U, sv, 1, 2) == mul!(C2, U, v, 1 ,2)
end

@testset "custom axes" begin
Expand Down
20 changes: 15 additions & 5 deletions test/testhelpers/SizedArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ module SizedArrays
import Base: +, *, ==

using LinearAlgebra
import LinearAlgebra: mul!

export SizedArray

Expand All @@ -33,6 +34,8 @@ struct SizedArray{SZ,T,N,A<:AbstractArray} <: AbstractArray{T,N}
new{SZ,T,N,A}(A(data))
end
end
SizedMatrix{SZ,T,A<:AbstractArray} = SizedArray{SZ,T,2,A}
SizedVector{SZ,T,A<:AbstractArray} = SizedArray{SZ,T,1,A}
Base.convert(::Type{SizedArray{SZ,T,N,A}}, data::AbstractArray) where {SZ,T,N,A} = SizedArray{SZ,T,N,A}(data)

# Minimal AbstractArray interface
Expand All @@ -44,21 +47,28 @@ Base.zero(::Type{T}) where T <: SizedArray = SizedArray{size(T)}(zeros(eltype(T)
+(S1::SizedArray{SZ}, S2::SizedArray{SZ}) where {SZ} = SizedArray{SZ}(S1.data + S2.data)
==(S1::SizedArray{SZ}, S2::SizedArray{SZ}) where {SZ} = S1.data == S2.data

const SizedArrayLike = Union{SizedArray, Transpose{<:Any, <:SizedArray}, Adjoint{<:Any, <:SizedArray}}
const SizedMatrixLike = Union{SizedMatrix, Transpose{<:Any, <:SizedMatrix}, Adjoint{<:Any, <:SizedMatrix}}

_data(S::SizedArray) = S.data
_data(T::Transpose{<:Any, <:SizedArray}) = transpose(_data(parent(T)))
_data(T::Adjoint{<:Any, <:SizedArray}) = adjoint(_data(parent(T)))

function *(S1::SizedArrayLike, S2::SizedArrayLike)
function *(S1::SizedMatrixLike, S2::SizedMatrixLike)
0 < ndims(S1) < 3 && 0 < ndims(S2) < 3 && size(S1, 2) == size(S2, 1) || throw(ArgumentError("size mismatch!"))
data = _data(S1) * _data(S2)
SZ = ndims(data) == 1 ? (size(S1, 1), ) : (size(S1, 1), size(S2, 2))
SizedArray{SZ}(data)
end

# deliberately wide method definition to ensure that this doesn't lead to ambiguities with
# structured matrices
*(S1::SizedArrayLike, M::AbstractMatrix) = _data(S1) * M
# deliberately wide method definitions to test for method ambiguties in LinearAlgebra
*(S1::SizedMatrixLike, M::AbstractMatrix) = _data(S1) * M
mul!(dest::AbstractMatrix, S1::SizedMatrix, M::AbstractMatrix, α::Number, β::Number) =
mul!(dest, _data(S1), M, α, β)
mul!(dest::AbstractMatrix, M::AbstractMatrix, S2::SizedMatrix, α::Number, β::Number) =
mul!(dest, M, _data(S2), α, β)
mul!(dest::AbstractMatrix, S1::SizedMatrix, S2::SizedMatrix, α::Number, β::Number) =
mul!(dest, _data(S1), _data(S2), α, β)
mul!(dest::AbstractVector, M::AbstractMatrix, v::SizedVector, α::Number, β::Number) =
mul!(dest, M, _data(v), α, β)

end