diff --git a/src/blas.jl b/src/blas.jl index 93a37113..594249d1 100644 --- a/src/blas.jl +++ b/src/blas.jl @@ -67,7 +67,7 @@ function gemmEx!(transA::Char, transB::Char, alpha::Number, A, B, beta::Number, ) GemmKernels.matmul(convert_matrix(A), convert_matrix(B), convert_matrix(C), convert_matrix(C), conf; - transform_shared_to_regs_a = Transform.Elementwise(x -> multiply_fp16(x, alpha)), + transform_shared_to_regs_a = Transform.Elementwise(x -> multiply_fp16(x, convert(Float16, alpha))), transform_shared_to_regs_c = Transform.Elementwise(x -> x * beta)) end