Skip to content

Commit

Permalink
Disambiguate mul! for matvec and matmat through indirection (#52837)
Browse files Browse the repository at this point in the history
This forwards `mul!` for matrix-matrix and matrix-vector multiplications
to the internal `_mul!`, which is then specialized for the various array
types. With this, packages would not encounter any method ambiguity when
they define
```julia
mul!(::AbstractMatrix, ::MyMatrix, ::AbstractMatrix, alpha::Number, beta::Number)
```
This should reduce the number of methods that packages need to define to
work around ambiguities with `LinearAlgebra`.

There was already an existing internal function named `_mul!`, but the
new methods don't clash with the existing ones, and since both sets of
methods usually forward structured multiplications to specialized
functions, it's fitting for them to have the same name.
  • Loading branch information
jishnub authored Jan 10, 2024
1 parent 124ce94 commit 2afc20c
Show file tree
Hide file tree
Showing 9 changed files with 88 additions and 25 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ Standard library changes
* There is now a specialized dispatch for `eigvals/eigen(::Hermitian{<:Tridiagonal})` which performs a similarity transformation to create a real symmetrix triagonal matrix, and solve that using the LAPACK routines ([#49546]).
* Structured matrices now retain either the axes of the parent (for `Symmetric`/`Hermitian`/`AbstractTriangular`/`UpperHessenberg`), or that of the principal diagonal (for banded matrices) ([#52480]).
* `bunchkaufman` and `bunchkaufman!` now work for any `AbstractFloat`, `Rational` and their complex variants. `bunchkaufman` now supports `Integer` types, by making an internal conversion to `Rational{BigInt}`. Added new function `inertia` that computes the inertia of the diagonal factor given by the `BunchKaufman` factorization object of a real symmetric or Hermitian matrix. For complex symmetric matrices, `inertia` only computes the number of zero eigenvalues of the diagonal factor ([#51487]).
* Packages that specialize matrix-matrix `mul!` with a method signature of the form `mul!(::AbstractMatrix, ::MyMatrix, ::AbstractMatrix, ::Number, ::Number)` no longer encounter method ambiguities when interacting with `LinearAlgebra`. Previously, ambiguities used to arise when multiplying a `MyMatrix` with a structured matrix type provided by LinearAlgebra, such as `AbstractTriangular`, which used to necessitate additional methods to resolve such ambiguities. Similar sources of ambiguities have also been removed for matrix-vector `mul!` operations ([#52837]).

#### Logging
* New `@create_log_macro` macro for creating new log macros like `@info`, `@warn` etc. For instance
Expand Down
26 changes: 13 additions & 13 deletions stdlib/LinearAlgebra/src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -429,11 +429,11 @@ const BandedMatrix = Union{Bidiagonal,Diagonal,Tridiagonal,SymTridiagonal} # or
const BiTriSym = Union{Bidiagonal,Tridiagonal,SymTridiagonal}
const TriSym = Union{Tridiagonal,SymTridiagonal}
const BiTri = Union{Bidiagonal,Tridiagonal}
@inline mul!(C::AbstractVector, A::BandedMatrix, B::AbstractVector, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractVector, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractMatrix, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::AbstractMatrix, B::BandedMatrix, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
@inline mul!(C::AbstractMatrix, A::BandedMatrix, B::BandedMatrix, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
@inline _mul!(C::AbstractVector, A::BandedMatrix, B::AbstractVector, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
@inline _mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractVector, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
@inline _mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractMatrix, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
@inline _mul!(C::AbstractMatrix, A::AbstractMatrix, B::BandedMatrix, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))
@inline _mul!(C::AbstractMatrix, A::BandedMatrix, B::BandedMatrix, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta))

lmul!(A::Bidiagonal, B::AbstractVecOrMat) = @inline _mul!(B, A, B, MulAddMul())
rmul!(B::AbstractMatrix, A::Bidiagonal) = @inline _mul!(B, B, A, MulAddMul())
Expand Down Expand Up @@ -465,9 +465,9 @@ function _diag(A::Bidiagonal, k)
end
end

_mul!(C::AbstractMatrix, A::BiTriSym, B::TriSym, _add::MulAddMul = MulAddMul()) =
_mul!(C::AbstractMatrix, A::BiTriSym, B::TriSym, _add::MulAddMul) =
_bibimul!(C, A, B, _add)
_mul!(C::AbstractMatrix, A::BiTriSym, B::Bidiagonal, _add::MulAddMul = MulAddMul()) =
_mul!(C::AbstractMatrix, A::BiTriSym, B::Bidiagonal, _add::MulAddMul) =
_bibimul!(C, A, B, _add)
function _bibimul!(C, A, B, _add)
check_A_mul_B!_sizes(C, A, B)
Expand Down Expand Up @@ -526,7 +526,7 @@ function _bibimul!(C, A, B, _add)
C
end

function _mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, _add::MulAddMul = MulAddMul())
function _mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, _add::MulAddMul)
require_one_based_indexing(C)
check_A_mul_B!_sizes(C, A, B)
n = size(A,1)
Expand Down Expand Up @@ -562,7 +562,7 @@ function _mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, _add::MulAddMul = Mu
C
end

function _mul!(C::AbstractVecOrMat, A::BiTriSym, B::AbstractVecOrMat, _add::MulAddMul = MulAddMul())
function _mul!(C::AbstractVecOrMat, A::BiTriSym, B::AbstractVecOrMat, _add::MulAddMul)
require_one_based_indexing(C, B)
nA = size(A,1)
nB = size(B,2)
Expand Down Expand Up @@ -592,7 +592,7 @@ function _mul!(C::AbstractVecOrMat, A::BiTriSym, B::AbstractVecOrMat, _add::MulA
C
end

function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::TriSym, _add::MulAddMul = MulAddMul())
function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::TriSym, _add::MulAddMul)
require_one_based_indexing(C, A)
check_A_mul_B!_sizes(C, A, B)
iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)
Expand Down Expand Up @@ -627,7 +627,7 @@ function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::TriSym, _add::MulAddMul
C
end

function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::Bidiagonal, _add::MulAddMul = MulAddMul())
function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::Bidiagonal, _add::MulAddMul)
require_one_based_indexing(C, A)
check_A_mul_B!_sizes(C, A, B)
iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)
Expand All @@ -653,9 +653,9 @@ function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::Bidiagonal, _add::MulAdd
C
end

_mul!(C::AbstractMatrix, A::Diagonal, B::Bidiagonal, _add::MulAddMul = MulAddMul()) =
_mul!(C::AbstractMatrix, A::Diagonal, B::Bidiagonal, _add::MulAddMul) =
_dibimul!(C, A, B, _add)
_mul!(C::AbstractMatrix, A::Diagonal, B::TriSym, _add::MulAddMul = MulAddMul()) =
_mul!(C::AbstractMatrix, A::Diagonal, B::TriSym, _add::MulAddMul) =
_dibimul!(C, A, B, _add)
function _dibimul!(C, A, B, _add)
require_one_based_indexing(C)
Expand Down
9 changes: 8 additions & 1 deletion stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,14 @@ end
(*)(a::AbstractVector, adjB::AdjointAbsMat) = reshape(a, length(a), 1) * adjB
(*)(a::AbstractVector, B::AbstractMatrix) = reshape(a, length(a), 1) * B

# Add a level of indirection and specialize _mul! to avoid ambiguities in mul!
@inline mul!(y::AbstractVector, A::AbstractVecOrMat, x::AbstractVector,
alpha::Number, beta::Number) = _mul!(y, A, x, alpha, beta)

@inline _mul!(y::AbstractVector, A::AbstractVecOrMat, x::AbstractVector,
alpha::Number, beta::Number) =
generic_matvecmul!(y, wrapper_char(A), _unwrap(A), x, MulAddMul(alpha, beta))

# BLAS cases
# equal eltypes
@inline generic_matvecmul!(y::StridedVector{T}, tA, A::StridedVecOrMat{T}, x::StridedVector{T},
Expand Down Expand Up @@ -277,7 +282,9 @@ julia> C == A * B * α + C_original * β
true
```
"""
@inline mul!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat, α::Number, β::Number) =
@inline mul!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat, α::Number, β::Number) = _mul!(C, A, B, α, β)
# Add a level of indirection and specialize _mul! to avoid ambiguities in mul!
@inline _mul!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat, α::Number, β::Number) =
generic_matmatmul!(
C,
wrapper_char(A),
Expand Down
6 changes: 2 additions & 4 deletions stdlib/LinearAlgebra/src/special.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,9 @@ for op in (:+, :-)
end

# disambiguation between triangular and banded matrices, banded ones "dominate"
mul!(C::AbstractMatrix, A::AbstractTriangular, B::BandedMatrix) = _mul!(C, A, B, MulAddMul())
mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractTriangular) = _mul!(C, A, B, MulAddMul())
mul!(C::AbstractMatrix, A::AbstractTriangular, B::BandedMatrix, alpha::Number, beta::Number) =
_mul!(C::AbstractMatrix, A::AbstractTriangular, B::BandedMatrix, alpha::Number, beta::Number) =
_mul!(C, A, B, MulAddMul(alpha, beta))
mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractTriangular, alpha::Number, beta::Number) =
_mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractTriangular, alpha::Number, beta::Number) =
_mul!(C, A, B, MulAddMul(alpha, beta))

function *(H::UpperHessenberg, B::Bidiagonal)
Expand Down
4 changes: 2 additions & 2 deletions stdlib/LinearAlgebra/src/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -801,7 +801,7 @@ rmul!(A::AbstractMatrix, B::AbstractTriangular) = @inline _trimul!(A, A, B)


for TC in (:AbstractVector, :AbstractMatrix)
@eval @inline function mul!(C::$TC, A::AbstractTriangular, B::AbstractVector, alpha::Number, beta::Number)
@eval @inline function _mul!(C::$TC, A::AbstractTriangular, B::AbstractVector, alpha::Number, beta::Number)
if isone(alpha) && iszero(beta)
return _trimul!(C, A, B)
else
Expand All @@ -813,7 +813,7 @@ for (TA, TB) in ((:AbstractTriangular, :AbstractMatrix),
(:AbstractMatrix, :AbstractTriangular),
(:AbstractTriangular, :AbstractTriangular)
)
@eval @inline function mul!(C::AbstractMatrix, A::$TA, B::$TB, alpha::Number, beta::Number)
@eval @inline function _mul!(C::AbstractMatrix, A::$TA, B::$TB, alpha::Number, beta::Number)
if isone(alpha) && iszero(beta)
return _trimul!(C, A, B)
else
Expand Down
23 changes: 23 additions & 0 deletions stdlib/LinearAlgebra/test/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ using .Main.FillArrays
isdefined(Main, :OffsetArrays) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "OffsetArrays.jl"))
using .Main.OffsetArrays

isdefined(Main, :SizedArrays) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "SizedArrays.jl"))
using .Main.SizedArrays

include("testutils.jl") # test_approx_eq_modphase

n = 10 #Size of test matrix
Expand Down Expand Up @@ -861,4 +864,24 @@ end
@test axes(B) === (ax, ax)
end

@testset "avoid matmul ambiguities with ::MyMatrix * ::AbstractMatrix" begin
A = [i+j for i in 1:2, j in 1:2]
S = SizedArrays.SizedArray{(2,2)}(A)
B = Bidiagonal([1:2;], [1;], :U)
@test S * B == A * B
@test B * S == B * A
C1, C2 = zeros(2,2), zeros(2,2)
@test mul!(C1, S, B) == mul!(C2, A, B)
@test mul!(C1, S, B, 1, 2) == mul!(C2, A, B, 1 ,2)
@test mul!(C1, B, S) == mul!(C2, B, A)
@test mul!(C1, B, S, 1, 2) == mul!(C2, B, A, 1 ,2)

v = [i for i in 1:2]
sv = SizedArrays.SizedArray{(2,)}(v)
@test B * sv == B * v
C1, C2 = zeros(2), zeros(2)
@test mul!(C1, B, sv) == mul!(C2, B, v)
@test mul!(C1, B, sv, 1, 2) == mul!(C2, B, v, 1 ,2)
end

end # module TestBidiagonal
12 changes: 12 additions & 0 deletions stdlib/LinearAlgebra/test/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1248,6 +1248,18 @@ end
D = Diagonal([1:2;])
@test S * D == A * D
@test D * S == D * A
C1, C2 = zeros(2,2), zeros(2,2)
@test mul!(C1, S, D) == mul!(C2, A, D)
@test mul!(C1, S, D, 1, 2) == mul!(C2, A, D, 1 ,2)
@test mul!(C1, D, S) == mul!(C2, D, A)
@test mul!(C1, D, S, 1, 2) == mul!(C2, D, A, 1 ,2)

v = [i for i in 1:2]
sv = SizedArrays.SizedArray{(2,)}(v)
@test D * sv == D * v
C1, C2 = zeros(2), zeros(2)
@test mul!(C1, D, sv) == mul!(C2, D, v)
@test mul!(C1, D, sv, 1, 2) == mul!(C2, D, v, 1 ,2)
end

@testset "copy" begin
Expand Down
12 changes: 12 additions & 0 deletions stdlib/LinearAlgebra/test/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -902,6 +902,18 @@ end
U = UpperTriangular(ones(2,2))
@test S * U == A * U
@test U * S == U * A
C1, C2 = zeros(2,2), zeros(2,2)
@test mul!(C1, S, U) == mul!(C2, A, U)
@test mul!(C1, S, U, 1, 2) == mul!(C2, A, U, 1 ,2)
@test mul!(C1, U, S) == mul!(C2, U, A)
@test mul!(C1, U, S, 1, 2) == mul!(C2, U, A, 1 ,2)

v = [i for i in 1:2]
sv = SizedArrays.SizedArray{(2,)}(v)
@test U * sv == U * v
C1, C2 = zeros(2), zeros(2)
@test mul!(C1, U, sv) == mul!(C2, U, v)
@test mul!(C1, U, sv, 1, 2) == mul!(C2, U, v, 1 ,2)
end

@testset "custom axes" begin
Expand Down
20 changes: 15 additions & 5 deletions test/testhelpers/SizedArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ module SizedArrays
import Base: +, *, ==

using LinearAlgebra
import LinearAlgebra: mul!

export SizedArray

Expand All @@ -33,6 +34,8 @@ struct SizedArray{SZ,T,N,A<:AbstractArray} <: AbstractArray{T,N}
new{SZ,T,N,A}(A(data))
end
end
SizedMatrix{SZ,T,A<:AbstractArray} = SizedArray{SZ,T,2,A}
SizedVector{SZ,T,A<:AbstractArray} = SizedArray{SZ,T,1,A}
Base.convert(::Type{SizedArray{SZ,T,N,A}}, data::AbstractArray) where {SZ,T,N,A} = SizedArray{SZ,T,N,A}(data)

# Minimal AbstractArray interface
Expand All @@ -44,21 +47,28 @@ Base.zero(::Type{T}) where T <: SizedArray = SizedArray{size(T)}(zeros(eltype(T)
+(S1::SizedArray{SZ}, S2::SizedArray{SZ}) where {SZ} = SizedArray{SZ}(S1.data + S2.data)
==(S1::SizedArray{SZ}, S2::SizedArray{SZ}) where {SZ} = S1.data == S2.data

const SizedArrayLike = Union{SizedArray, Transpose{<:Any, <:SizedArray}, Adjoint{<:Any, <:SizedArray}}
const SizedMatrixLike = Union{SizedMatrix, Transpose{<:Any, <:SizedMatrix}, Adjoint{<:Any, <:SizedMatrix}}

_data(S::SizedArray) = S.data
_data(T::Transpose{<:Any, <:SizedArray}) = transpose(_data(parent(T)))
_data(T::Adjoint{<:Any, <:SizedArray}) = adjoint(_data(parent(T)))

function *(S1::SizedArrayLike, S2::SizedArrayLike)
function *(S1::SizedMatrixLike, S2::SizedMatrixLike)
0 < ndims(S1) < 3 && 0 < ndims(S2) < 3 && size(S1, 2) == size(S2, 1) || throw(ArgumentError("size mismatch!"))
data = _data(S1) * _data(S2)
SZ = ndims(data) == 1 ? (size(S1, 1), ) : (size(S1, 1), size(S2, 2))
SizedArray{SZ}(data)
end

# deliberately wide method definition to ensure that this doesn't lead to ambiguities with
# structured matrices
*(S1::SizedArrayLike, M::AbstractMatrix) = _data(S1) * M
# deliberately wide method definitions to test for method ambiguties in LinearAlgebra
*(S1::SizedMatrixLike, M::AbstractMatrix) = _data(S1) * M
mul!(dest::AbstractMatrix, S1::SizedMatrix, M::AbstractMatrix, α::Number, β::Number) =
mul!(dest, _data(S1), M, α, β)
mul!(dest::AbstractMatrix, M::AbstractMatrix, S2::SizedMatrix, α::Number, β::Number) =
mul!(dest, M, _data(S2), α, β)
mul!(dest::AbstractMatrix, S1::SizedMatrix, S2::SizedMatrix, α::Number, β::Number) =
mul!(dest, _data(S1), _data(S2), α, β)
mul!(dest::AbstractVector, M::AbstractMatrix, v::SizedVector, α::Number, β::Number) =
mul!(dest, M, _data(v), α, β)

end

0 comments on commit 2afc20c

Please sign in to comment.