diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl
index 30b410bbff6ef..a453078711d1a 100644
--- a/stdlib/LinearAlgebra/src/matmul.jl
+++ b/stdlib/LinearAlgebra/src/matmul.jl
@@ -152,15 +152,23 @@ function (*)(A::AbstractMatrix, B::AbstractMatrix)
     TS = promote_op(matprod, eltype(A), eltype(B))
     mul!(similar(B, TS, (size(A,1), size(B,2))), A, B)
 end
-@inline mul!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
-             alpha::Union{T, Bool}, beta::Union{T, Bool}) where {T<:BlasFloat} =
-    gemm_wrapper!(C, 'N', 'N', A, B, MulAddMul(alpha, beta))
+
+@inline function mul!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
+                      α::Number, β::Number) where {T<:BlasFloat}
+    alpha, beta = promote(α, β, zero(T))
+    if alpha isa T && beta isa T
+        return gemm_wrapper!(C, 'N', 'N', A, B, MulAddMul(alpha, beta))
+    else
+        return generic_matmatmul!(C, 'N', 'N', A, B, MulAddMul(α, β))
+    end
+end
+
 # Complex Matrix times real matrix: We use that it is generally faster to reinterpret the
 # first matrix as a real matrix and carry out real matrix matrix multiply
 for elty in (Float32,Float64)
     @eval begin
         @inline function mul!(C::StridedMatrix{Complex{$elty}}, A::StridedVecOrMat{Complex{$elty}}, B::StridedVecOrMat{$elty},
-                         alpha::Union{$elty, Bool}, beta::Union{$elty, Bool})
+                         alpha::Real, beta::Real)
             Afl = reinterpret($elty, A)
             Cfl = reinterpret($elty, C)
             mul!(Cfl, Afl, B, alpha, beta)
diff --git a/stdlib/LinearAlgebra/test/matmul.jl b/stdlib/LinearAlgebra/test/matmul.jl
index 794ec672297a4..46dc2aea3a903 100644
--- a/stdlib/LinearAlgebra/test/matmul.jl
+++ b/stdlib/LinearAlgebra/test/matmul.jl
@@ -586,7 +586,7 @@ end
     A = rand(n, n)
     B = rand(n, n)
     C = zeros(n, n)
-    mul!(C, A, B, -1, 0)
+    mul!(C, A, B, -1 + 0im, 0)
     D = -A * B
     @test D ≈ C