Skip to content

Commit

Permalink
Send AdjOrTrans(BlasMatrix) oop-products to BLAS (#43435)
Browse files Browse the repository at this point in the history
  • Loading branch information
dkarrasch authored Jan 12, 2022
1 parent 29ea1fe commit 08ff456
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 9 deletions.
5 changes: 3 additions & 2 deletions stdlib/LinearAlgebra/src/adjtrans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,9 @@ const AdjOrTransAbsVec{T} = AdjOrTrans{T,<:AbstractVector}
const AdjOrTransAbsMat{T} = AdjOrTrans{T,<:AbstractMatrix}

# for internal use below
wrapperop(A::Adjoint) = adjoint
wrapperop(A::Transpose) = transpose
wrapperop(_) = identity
wrapperop(::Adjoint) = adjoint
wrapperop(::Transpose) = transpose

# AbstractArray interface, basic definitions
length(A::AdjOrTrans) = length(A.parent)
Expand Down
45 changes: 38 additions & 7 deletions stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,13 @@ end

# Matrix-matrix multiplication

AdjOrTransStridedMat{T} = Union{Adjoint{T, <:StridedMatrix}, Transpose{T, <:StridedMatrix}}
StridedMaybeAdjOrTransMat{T} = Union{StridedMatrix{T}, Adjoint{T, <:StridedMatrix}, Transpose{T, <:StridedMatrix}}

_parent(A) = A
_parent(A::Adjoint) = parent(A)
_parent(A::Transpose) = parent(A)

"""
*(A::AbstractMatrix, B::AbstractMatrix)
Expand All @@ -130,18 +137,22 @@ julia> [1 1; 0 1] * [1 0; 1 1]
"""
function (*)(A::AbstractMatrix, B::AbstractMatrix)
TS = promote_op(matprod, eltype(A), eltype(B))
mul!(similar(B, TS, (size(A,1), size(B,2))), A, B)
mul!(similar(B, TS, (size(A, 1), size(B, 2))), A, B)
end
# optimization for dispatching to BLAS, e.g. *(::Matrix{Float32}, ::Matrix{Float64})
# but avoiding the case *(::Matrix{<:BlasComplex}, ::Matrix{<:BlasReal})
# which is better handled by reinterpreting rather than promotion
function (*)(A::StridedMatrix{<:BlasReal}, B::StridedMatrix{<:BlasFloat})
function (*)(A::StridedMaybeAdjOrTransMat{<:BlasReal}, B::StridedMaybeAdjOrTransMat{<:BlasReal})
TS = promote_type(eltype(A), eltype(B))
mul!(similar(B, TS, (size(A,1), size(B,2))), convert(AbstractArray{TS}, A), convert(AbstractArray{TS}, B))
mul!(similar(B, TS, (size(A, 1), size(B, 2))),
wrapperop(A)(convert(AbstractArray{TS}, _parent(A))),
wrapperop(B)(convert(AbstractArray{TS}, _parent(B))))
end
function (*)(A::StridedMatrix{<:BlasComplex}, B::StridedMatrix{<:BlasComplex})
function (*)(A::StridedMaybeAdjOrTransMat{<:BlasComplex}, B::StridedMaybeAdjOrTransMat{<:BlasComplex})
TS = promote_type(eltype(A), eltype(B))
mul!(similar(B, TS, (size(A,1), size(B,2))), convert(AbstractArray{TS}, A), convert(AbstractArray{TS}, B))
mul!(similar(B, TS, (size(A, 1), size(B, 2))),
wrapperop(A)(convert(AbstractArray{TS}, _parent(A))),
wrapperop(B)(convert(AbstractArray{TS}, _parent(B))))
end

@inline function mul!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
Expand All @@ -150,6 +161,28 @@ end
end
# Complex Matrix times real matrix: We use that it is generally faster to reinterpret the
# first matrix as a real matrix and carry out real matrix matrix multiply
function (*)(A::StridedMatrix{<:BlasComplex}, B::StridedMaybeAdjOrTransMat{<:BlasReal})
TS = promote_type(eltype(A), eltype(B))
mul!(similar(B, TS, (size(A, 1), size(B, 2))),
convert(AbstractArray{TS}, A),
wrapperop(B)(convert(AbstractArray{real(TS)}, _parent(B))))
end
function (*)(A::AdjOrTransStridedMat{<:BlasComplex}, B::StridedMaybeAdjOrTransMat{<:BlasReal})
TS = promote_type(eltype(A), eltype(B))
mul!(similar(B, TS, (size(A, 1), size(B, 2))),
copy_oftype(A, TS), # remove AdjOrTrans to use reinterpret trick below
wrapperop(B)(convert(AbstractArray{real(TS)}, _parent(B))))
end
# the following case doesn't seem to benefit from the translation A*B = (B' * A')'
function (*)(A::StridedMatrix{<:BlasReal}, B::StridedMatrix{<:BlasComplex})
TS = promote_type(eltype(A), eltype(B))
mul!(similar(B, TS, (size(A, 1), size(B, 2))),
convert(AbstractArray{TS}, A),
convert(AbstractArray{TS}, B))
end
(*)(A::AdjOrTransStridedMat{<:BlasReal}, B::StridedMatrix{<:BlasComplex}) = copy(transpose(transpose(B) * parent(A)))
(*)(A::StridedMaybeAdjOrTransMat{<:BlasReal}, B::AdjOrTransStridedMat{<:BlasComplex}) = copy(wrapperop(B)(parent(B) * transpose(A)))

for elty in (Float32,Float64)
@eval begin
@inline function mul!(C::StridedMatrix{Complex{$elty}}, A::StridedVecOrMat{Complex{$elty}}, B::StridedVecOrMat{$elty},
Expand Down Expand Up @@ -213,8 +246,6 @@ Base.muladd(x::AdjointAbsVec, A::AbstractMatrix, z::Union{Number, AbstractVecOrM
Base.muladd(x::TransposeAbsVec, A::AbstractMatrix, z::Union{Number, AbstractVecOrMat}) =
transpose(muladd(transpose(A), transpose(x), transpose(z)))

StridedMaybeAdjOrTransMat{T} = Union{StridedMatrix{T}, Adjoint{T, <:StridedMatrix}, Transpose{T, <:StridedMatrix}}

function Base.muladd(A::StridedMaybeAdjOrTransMat{<:Number}, y::AbstractVector{<:Number}, z::Union{Number, AbstractVector})
T = promote_type(eltype(A), eltype(y), eltype(z))
C = similar(A, T, axes(A,1))
Expand Down

0 comments on commit 08ff456

Please sign in to comment.