From d72ce7d9943dba39d088e010f88d2657eb4d6fd1 Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Fri, 27 Sep 2024 15:53:15 -0300 Subject: [PATCH] Use CPU copy with SharedStorage [skip tests] --- src/array.jl | 18 ++++++++++++++++++ test/array.jl | 23 +++++++++++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/src/array.jl b/src/array.jl index 3054084f..8dc200e6 100644 --- a/src/array.jl +++ b/src/array.jl @@ -401,6 +401,12 @@ function Base.unsafe_copyto!(dev::MTLDevice, dest::MtlArray{T}, doffs, src::Arra end return dest end +function Base.unsafe_copyto!(::MTLDevice, dest::MtlArray{T,<:Any,Metal.SharedStorage}, doffs, src::Array{T}, soffs, n) where T + # dest is SharedStorage: wrap it as an Array and fill it with a pure memcpy; that is not an API call, so it isn't ordered and we synchronize() first. + synchronize() + GC.@preserve src dest unsafe_copyto!(pointer(unsafe_wrap(Array,dest), doffs), pointer(src, soffs), n) + return dest +end # GPU -> CPU function Base.unsafe_copyto!(dev::MTLDevice, dest::Array{T}, doffs, src::MtlArray{T}, soffs, n) where T @@ -414,6 +420,12 @@ function Base.unsafe_copyto!(dev::MTLDevice, dest::Array{T}, doffs, src::MtlArra end return dest end +function Base.unsafe_copyto!(::MTLDevice, dest::Array{T}, doffs, src::MtlArray{T,<:Any,Metal.SharedStorage}, soffs, n) where T + # src is SharedStorage: wrap it as an Array and read it with a pure memcpy; that is not an API call, so it isn't ordered and we synchronize() first.
+ synchronize() + GC.@preserve src dest unsafe_copyto!(pointer(dest, doffs), pointer(unsafe_wrap(Array,src), soffs), n) + return dest +end # GPU -> GPU function Base.unsafe_copyto!(dev::MTLDevice, dest::MtlArray{T}, doffs, src::MtlArray{T}, soffs, n) where T @@ -427,6 +439,12 @@ function Base.unsafe_copyto!(dev::MTLDevice, dest::MtlArray{T}, doffs, src::MtlA end return dest end +function Base.unsafe_copyto!(::MTLDevice, dest::MtlArray{T,<:Any,Metal.SharedStorage}, doffs, src::MtlArray{T,<:Any,Metal.SharedStorage}, soffs, n) where T + # both arrays are SharedStorage: wrap each as an Array and copy with a pure memcpy; that is not an API call, so it isn't ordered and we synchronize() first. + synchronize() + GC.@preserve src dest unsafe_copyto!(pointer(unsafe_wrap(Array,dest), doffs), pointer(unsafe_wrap(Array,src), soffs), n) + return dest +end ## regular gpu array adaptor diff --git a/test/array.jl b/test/array.jl index 3333f5da..05dbfb5e 100644 --- a/test/array.jl +++ b/test/array.jl @@ -69,6 +69,29 @@ end @test collect(Metal.fill(1, 2, 2)) == ones(Float32, 2, 2) end +@testset "copyto!: $T, $S" for S in [Metal.PrivateStorage, Metal.SharedStorage], T in [Float16, Float32, Bool, Int16, Int32, Int64, Int8, UInt16, UInt32, UInt64, UInt8] + function testcopyto!(out, in) + copyto!(out,in) + return Array(in) == Array(out) + end + + dim = (1000,17,10) + A = rand(T,dim) + mtlA = mtl(A;storage=S) + + # CPU -> GPU: copy the host Array A into a zeroed Metal array with storage mode S + res = Metal.zeros(T,dim;storage=S) + @test testcopyto!(res,A) + + # GPU -> CPU: copy the Metal array mtlA into a zeroed host Array + res = zeros(T,dim) + @test testcopyto!(res,mtlA) + + # GPU -> GPU: copy mtlA into a zeroed Metal array with the same storage mode S + res = Metal.zeros(T,dim;storage=S) + @test testcopyto!(res,mtlA) +end + check_storagemode(arr, smode) = Metal.storagemode(arr) == smode # There is some repetition to the GPUArrays tests to test for different storagemodes