Skip to content

Commit

Permalink
fix scal! for gpuarrays
Browse files Browse the repository at this point in the history
  • Loading branch information
kshyatt committed Jan 14, 2025
1 parent 34a913f commit a0829fa
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion lib/cublas/wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,16 @@ function scal!(n::Integer, alpha::Real, x::StridedCuVecOrDenseMat{T}) where {T<:
end
function scal!(n::Integer, alpha::Real, x::StridedCuVecOrDenseMat{ComplexF16})
wide_x = widen.(x)
scal!(n, alpha, wide_x)
gpu_α = CuRef{Float32}( convert(Float32, alpha) )
scal!(n, gpu_α, wide_x)
thin_x = convert(typeof(x), wide_x)
copyto!(x, thin_x)
return x
end
function scal!(n::Integer, alpha::Complex, x::StridedCuVecOrDenseMat{ComplexF16})
wide_x = widen.(x)
gpu_α = CuRef{ComplexF32}( convert(ComplexF32, alpha) )
scal!(n, gpu_α, wide_x)
thin_x = convert(typeof(x), wide_x)
copyto!(x, thin_x)
return x
Expand Down

0 comments on commit a0829fa

Please sign in to comment.