From 51b32274dfed800470583b6a2441c0766b036910 Mon Sep 17 00:00:00 2001 From: Takafumi Arakaki Date: Fri, 13 Sep 2019 02:24:05 -0700 Subject: [PATCH] Dispatch more cases to BLAS.gemm! (#33229) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Dispatch more cases to BLAS.gemm! * Use α and β instead of alpha′ and beta′ --- stdlib/LinearAlgebra/src/matmul.jl | 16 ++++++++++++---- stdlib/LinearAlgebra/test/matmul.jl | 2 +- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index 30b410bbff6ef..a453078711d1a 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -152,15 +152,23 @@ function (*)(A::AbstractMatrix, B::AbstractMatrix) TS = promote_op(matprod, eltype(A), eltype(B)) mul!(similar(B, TS, (size(A,1), size(B,2))), A, B) end -@inline mul!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, - alpha::Union{T, Bool}, beta::Union{T, Bool}) where {T<:BlasFloat} = - gemm_wrapper!(C, 'N', 'N', A, B, MulAddMul(alpha, beta)) + +@inline function mul!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, + α::Number, β::Number) where {T<:BlasFloat} + alpha, beta = promote(α, β, zero(T)) + if alpha isa T && beta isa T + return gemm_wrapper!(C, 'N', 'N', A, B, MulAddMul(alpha, beta)) + else + return generic_matmatmul!(C, 'N', 'N', A, B, MulAddMul(α, β)) + end +end + # Complex Matrix times real matrix: We use that it is generally faster to reinterpret the # first matrix as a real matrix and carry out real matrix matrix multiply for elty in (Float32,Float64) @eval begin @inline function mul!(C::StridedMatrix{Complex{$elty}}, A::StridedVecOrMat{Complex{$elty}}, B::StridedVecOrMat{$elty}, - alpha::Union{$elty, Bool}, beta::Union{$elty, Bool}) + alpha::Real, beta::Real) Afl = reinterpret($elty, A) Cfl = reinterpret($elty, C) mul!(Cfl, Afl, B, alpha, beta) diff --git a/stdlib/LinearAlgebra/test/matmul.jl b/stdlib/LinearAlgebra/test/matmul.jl index 794ec672297a4..46dc2aea3a903 100644 --- a/stdlib/LinearAlgebra/test/matmul.jl +++ b/stdlib/LinearAlgebra/test/matmul.jl @@ -586,7 +586,7 @@ end A = rand(n, n) B = rand(n, n) C = zeros(n, n) - mul!(C, A, B, -1, 0) + mul!(C, A, B, -1 + 0im, 0) D = -A * B @test D ≈ C