From 21529a2a200af7b68c7093d571d5dd957d74b8e3 Mon Sep 17 00:00:00 2001
From: Tim Besard
Date: Thu, 29 Jun 2023 20:32:52 +0200
Subject: [PATCH] Remove Diagonal BLAS.

---
 src/blas.jl | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/src/blas.jl b/src/blas.jl
index 8a83b7dd..ee9c2b57 100644
--- a/src/blas.jl
+++ b/src/blas.jl
@@ -4,10 +4,6 @@ using CUDA
 using GemmKernels
 using LinearAlgebra
 
-# Convert matrix to type compatible with kernel
-convert_matrix(mat) = mat
-convert_matrix(mat::Diagonal{T, A}) where {T, A} = mat.diag
-
 # Select the best kernel
 kernel(layout_a, layout_b) = Kernel.matmul_singlestage
 kernel(::Type{Layout.AlignedColMajor{T}}, ::Type{Layout.AlignedColMajor{T}}) where {T} = Kernel.matmul_pipelined
@@ -66,7 +62,7 @@ function gemmEx!(transA::Char, transB::Char, alpha::Number, A, B, beta::Number,
         is_b_col_major = !transpose_b
     )
 
-    GemmKernels.matmul(convert_matrix(A), convert_matrix(B), convert_matrix(C), convert_matrix(C), conf;
+    GemmKernels.matmul(A, B, C, C, conf;
                        transform_shared_to_regs_a = Transform.Elementwise(x -> x * alpha),
                        transform_shared_to_regs_c = Transform.Elementwise(x -> x * beta),
                        kernel = kernel(global_a_layout, global_b_layout)
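
Note: a minimal sketch of the behavior this patch removes, reusing the convert_matrix definitions deleted in the first hunk (a plain Vector stands in for a CuArray here, purely for illustration): before the change, a Diagonal input was silently unwrapped to its diag vector before reaching GemmKernels.matmul; afterwards, arguments are passed through as-is.

    using LinearAlgebra

    # The helper removed above: unwrap a Diagonal into its diagonal
    # vector, pass everything else through unchanged.
    convert_matrix(mat) = mat
    convert_matrix(mat::Diagonal{T, A}) where {T, A} = mat.diag

    d = Diagonal([1.0f0, 2.0f0, 3.0f0])
    convert_matrix(d) === d.diag          # true: the kernel only saw the diag field
    convert_matrix(ones(Float32, 3, 3))   # dense matrices passed through untouched

    # After this patch, GemmKernels.matmul(A, B, C, C, conf; ...) receives
    # a Diagonal argument unmodified; a caller wanting the old behavior
    # would unwrap it explicitly, e.g. via d.diag.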