diff --git a/lib/cublas/wrappers.jl b/lib/cublas/wrappers.jl index 7fa912c3ff..0c1ccf6b24 100644 --- a/lib/cublas/wrappers.jl +++ b/lib/cublas/wrappers.jl @@ -168,7 +168,16 @@ function scal!(n::Integer, alpha::Real, x::StridedCuVecOrDenseMat{T}) where {T<: end function scal!(n::Integer, alpha::Real, x::StridedCuVecOrDenseMat{ComplexF16}) wide_x = widen.(x) - scal!(n, alpha, wide_x) + gpu_α = CuRef{Float32}( convert(Float32, alpha) ) + scal!(n, gpu_α, wide_x) + thin_x = convert(typeof(x), wide_x) + copyto!(x, thin_x) + return x +end +function scal!(n::Integer, alpha::Complex, x::StridedCuVecOrDenseMat{ComplexF16}) + wide_x = widen.(x) + gpu_α = CuRef{ComplexF32}( convert(ComplexF32, alpha) ) + scal!(n, gpu_α, wide_x) thin_x = convert(typeof(x), wide_x) copyto!(x, thin_x) return x