Skip to content

Commit

Permalink
Remove Diagonal BLAS.
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt committed Jun 29, 2023
1 parent 3008b5b commit 21529a2
Showing 1 changed file with 1 addition and 5 deletions.
6 changes: 1 addition & 5 deletions src/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,6 @@ using CUDA
using GemmKernels
using LinearAlgebra

# Convert matrix to type compatible with kernel
convert_matrix(mat) = mat
convert_matrix(mat::Diagonal{T, A}) where {T, A} = mat.diag

# Select the best kernel
kernel(layout_a, layout_b) = Kernel.matmul_singlestage
kernel(::Type{Layout.AlignedColMajor{T}}, ::Type{Layout.AlignedColMajor{T}}) where {T} = Kernel.matmul_pipelined
Expand Down Expand Up @@ -66,7 +62,7 @@ function gemmEx!(transA::Char, transB::Char, alpha::Number, A, B, beta::Number,
is_b_col_major = !transpose_b
)

GemmKernels.matmul(convert_matrix(A), convert_matrix(B), convert_matrix(C), convert_matrix(C), conf;
GemmKernels.matmul(A, B, C, C, conf;
transform_shared_to_regs_a = Transform.Elementwise(x -> x * alpha),
transform_shared_to_regs_c = Transform.Elementwise(x -> x * beta),
kernel = kernel(global_a_layout, global_b_layout)
Expand Down

0 comments on commit 21529a2

Please sign in to comment.