diff --git a/lib/cublas/linalg.jl b/lib/cublas/linalg.jl
index 10a8a2ee09..ead34d4262 100644
--- a/lib/cublas/linalg.jl
+++ b/lib/cublas/linalg.jl
@@ -591,3 +591,15 @@ end
         error("only supports BLAS type, got $T")
     end
 end
+# One entry per operand form: (type-expr wrapper, eltype -> char passed to CUBLAS.geam, expr unwrapper).
+op_wrappers = ((identity, T -> 'N', identity),  # plain matrix: type and argument used as-is
+               (T -> :(Transpose{T, <:$T}), T -> 'T', A -> :(parent($A))),  # Transpose: pass parent(A)
+               (T -> :(Adjoint{T, <:$T}), T -> T <: Real ? 'T' : 'C', A -> :(parent($A))))  # 'T' if real, else 'C'
+# Generate Base.:+ and Base.:- methods for all 3 x 3 operand-form pairs; each forwards to CUBLAS.geam.
+for op in (:(+), :(-))  # op is also applied to one(T) to form the scalar passed before B
+    for (wrapa, transa, unwrapa) in op_wrappers, (wrapb, transb, unwrapb) in op_wrappers
+        TypeA = wrapa(:(CuMatrix{T}))  # type expression for argument A, e.g. Transpose{T, <:CuMatrix{T}}
+        TypeB = wrapb(:(CuMatrix{T}))  # type expression for argument B, built the same way
+        @eval Base.$op(A::$TypeA, B::$TypeB) where {T <: CublasFloat} = CUBLAS.geam($transa(T), $transb(T), one(T), $(unwrapa(:A)), $(op)(one(T)), $(unwrapb(:B)))
+    end
+end
diff --git a/test/cublas.jl b/test/cublas.jl
index 0d7c01ba22..ea0019a67c 100644
--- a/test/cublas.jl
+++ b/test/cublas.jl
@@ -1198,6 +1198,27 @@ end
         h_C = Array(d_C)
         @test D ≈ h_C
     end
+    @testset "CuMatrix -- A ± B -- $elty" begin
+        for opa in (identity, transpose, adjoint)  # form applied to the left operand
+            for opb in (identity, transpose, adjoint)  # form applied to the right operand
+                n = 10
+                m = 20
+                geam_A = opa == identity ? rand(elty, n, m) : rand(elty, m, n)  # opa(geam_A) is always n x m
+                geam_B = opb == identity ? rand(elty, n, m) : rand(elty, m, n)  # opb(geam_B) is always n x m
+                # CuMatrix counterparts of the two Array inputs
+                geam_dA = CuMatrix{elty}(geam_A)
+                geam_dB = CuMatrix{elty}(geam_dB = geam_B)
+                # addition: only values are compared (via collect); the type of geam_dC is not asserted
+                geam_C = opa(geam_A) + opb(geam_B)
+                geam_dC = opa(geam_dA) + opb(geam_dB)
+                @test geam_C ≈ collect(geam_dC)
+                # subtraction: same comparison against the Array result
+                geam_C = opa(geam_A) - opb(geam_B)
+                geam_dC = opa(geam_dA) - opb(geam_dB)
+                @test geam_C ≈ collect(geam_dC)
+            end
+        end
+    end
     A = rand(elty,m,k)
     d_A = CuArray(A)
     @testset "syrkx!" begin