diff --git a/src/array.jl b/src/array.jl
index d28a4dbc3..25e57f4ed 100644
--- a/src/array.jl
+++ b/src/array.jl
@@ -58,35 +58,27 @@ end
 mutable struct ROCArray{T,N} <: AbstractGPUArray{T,N}
     buf::Mem.Buffer
-    own::Bool
     dims::Dims{N}
     offset::Int
     syncstate::Runtime.SyncState

-    function ROCArray{T,N}(buf::Mem.Buffer, dims::Dims{N}; offset::Integer=0, own::Bool=true) where {T,N}
+    function ROCArray{T,N}(buf::Mem.Buffer, dims::Dims{N}; offset::Integer=0) where {T,N}
         @assert isbitstype(T) "ROCArray only supports bits types"
-        xs = new{T,N}(buf, own, dims, offset, Runtime.SyncState())
-        if own
-            hsaref!()
-        end
+        xs = new{T,N}(buf, dims, offset, Runtime.SyncState())
         Mem.retain(buf)
-        finalizer(unsafe_free!, xs)
+        finalizer(_safe_free!, xs)
         return xs
     end
 end

-function unsafe_free!(xs::ROCArray)
-    if Mem.release(xs.buf)
-        Mem.free(xs.buf)
-        if xs.own
-            hsaunref!()
-        end
-    end
+_safe_free!(xs::ROCArray) = _safe_free!(xs.buf)
+function _safe_free!(buf::Mem.Buffer)
+    Mem.release(buf)
     return
 end
+
+unsafe_free!(xs::ROCArray) = Mem.free_if_live(xs.buf)
+
 wait!(x::ROCArray) = wait!(x.syncstate)
 mark!(x::ROCArray, s) = mark!(x.syncstate, s)
 wait!(xs::Vector{<:ROCArray}) = foreach(wait!, xs)
@@ -231,8 +223,8 @@ function Base.unsafe_wrap(::Type{<:ROCArray}, ptr::Ptr{T}, dims::NTuple{N,<:Inte
     @assert isbitstype(T) "Cannot wrap a non-bitstype pointer as a ROCArray"
     sz = prod(dims) * sizeof(T)
     device_ptr = lock ? Mem.lock(ptr, sz, device) : ptr
-    buf = Mem.Buffer(device_ptr, ptr, device_ptr, sz, device, false, false)
-    return ROCArray{T, N}(buf, dims; own=false)
+    buf = Mem.Buffer(device_ptr, Ptr{Cvoid}(ptr), device_ptr, sz, device, false, false)
+    return ROCArray{T, N}(buf, dims)
 end
 Base.unsafe_wrap(::Type{ROCArray{T}}, ptr::Ptr, dims; kwargs...) where T =
     unsafe_wrap(ROCArray, Base.unsafe_convert(Ptr{T}, ptr), dims; kwargs...)
@@ -280,11 +272,7 @@ end
 end
 @inline function unsafe_contiguous_view(a::ROCArray{T}, I::NTuple{N,Base.ViewIndex}, dims::NTuple{M,Integer}) where {T,N,M}
     offset = Base.compute_offset1(a, 1, I) * sizeof(T)
-
-    Mem.retain(a.buf)
-    b = ROCArray{T,M}(a.buf, dims, offset=a.offset+offset, own=false)
-    finalizer(unsafe_free!, b)
-    return b
+    ROCArray{T,M}(a.buf, dims, offset=a.offset+offset)
 end

 @inline function unsafe_view(A, I, ::NonContiguous)
@@ -306,10 +294,7 @@ function Base.reshape(a::ROCArray{T,M}, dims::NTuple{N,Int}) where {T,N,M}
         return a
     end

-    Mem.retain(a.buf)
-    b = ROCArray{T,N}(a.buf, dims, offset=a.offset, own=false)
-    finalizer(unsafe_free!, b)
-    return b
+    ROCArray{T,N}(a.buf, dims, offset=a.offset)
 end

@@ -401,10 +386,7 @@ zeros(T::Type, dims...) = fill!(ROCArray{T}(undef, dims...), zero(T))
 # create a derived array (reinterpreted or reshaped) that's still a ROCArray
 # TODO: Move this to GPUArrays?
 @inline function _derived_array(::Type{T}, N::Int, a::ROCArray, osize::Dims) where {T}
-    Mem.retain(a.buf)
-    b = ROCArray{T,N}(a.buf, osize, offset=a.offset, own=false)
-    finalizer(unsafe_free!, b)
-    return b
+    return ROCArray{T,N}(a.buf, osize, offset=a.offset)
 end

 ## reinterpret
@@ -523,14 +505,18 @@ end
 """
     resize!(a::ROCVector, n::Integer)

-Resize `a` to contain `n` elements. If `n` is smaller than the current collection length,
-the first `n` elements will be retained. If `n` is larger, the new elements are not
-guaranteed to be initialized.
+Resize `a` to contain `n` elements. If `n` is smaller than the current
+collection length, the first `n` elements will be retained. If `n` is larger,
+the new elements are not guaranteed to be initialized.

-Note that this operation is only supported on managed buffers, i.e., not on arrays that are
-created by `unsafe_wrap` with `own=false`.
+Note that this operation is only supported on managed buffers, i.e., not on
+arrays that are created by `unsafe_wrap`.
 """
 function Base.resize!(A::ROCVector{T}, n::Integer) where T
+    if A.buf.host_ptr != C_NULL
+        throw(ArgumentError("Cannot resize an unowned `ROCVector`"))
+    end
+
     # TODO: add additional space to allow for quicker resizing
     if n == length(A)
         return A
@@ -547,10 +533,12 @@ function Base.resize!(A::ROCVector{T}, n::Integer) where T
     if copy_size > 0
         Mem.transfer!(new_buf, A.buf, copy_size)
     end
+
+    # Release old buffer
+    _safe_free!(A.buf)
+
+    # N.B. Manually retain new buffer (this is normally done in ROCArray ctor)
     Mem.retain(new_buf)
-    if Mem.release(A.buf)
-        Mem.free(A.buf)
-    end
+
     A.buf = new_buf
     A.dims = (n,)
     A.offset = 0
diff --git a/src/runtime/error.jl b/src/runtime/error.jl
index d72d3b951..e97d371f2 100644
--- a/src/runtime/error.jl
+++ b/src/runtime/error.jl
@@ -11,9 +11,10 @@ end

 Gets the string description of an error code.
 """
-function description(err::HSAError)
+description(err::HSAError) = description(err.code)
+function description(status::HSA.Status)
     str_ref = Ref{Ptr{Int8}}()
-    HSA.status_string(err.code, str_ref)
+    HSA.status_string(status, str_ref)
     unsafe_string(reinterpret(Cstring, str_ref[]))
 end
diff --git a/src/runtime/kernel-signal.jl b/src/runtime/kernel-signal.jl
index cb0a14e5d..957d3f012 100644
--- a/src/runtime/kernel-signal.jl
+++ b/src/runtime/kernel-signal.jl
@@ -83,7 +83,7 @@ end

 function Base.show(io::IO, kersig::ROCKernelSignal)
     ex = kersig.exception
-    print(io, "ROCKernelSignal(signal=$(kersig.signal), $(ex !== nothing ? ", exception=$ex" : ""))")
+    print(io, "ROCKernelSignal(signal=$(kersig.signal)$(ex !== nothing ? ", exception=$ex" : ""))")
 end

 Base.notify(kersig::ROCKernelSignal) = notify(kersig.signal)
diff --git a/src/runtime/memory.jl b/src/runtime/memory.jl
index 9576e3063..26dd62b88 100644
--- a/src/runtime/memory.jl
+++ b/src/runtime/memory.jl
@@ -4,7 +4,7 @@ import AMDGPU
 import AMDGPU: HSA

 if AMDGPU.hip_configured
-import AMDGPU: HIP
+    import AMDGPU: HIP
 end
 import AMDGPU: Runtime
 import .Runtime: ROCDevice, ROCSignal, ROCMemoryRegion, ROCMemoryPool, ROCDim, ROCDim3
@@ -21,6 +21,16 @@ struct Buffer
     device::ROCDevice
     coherent::Bool
     pool_alloc::Bool
+    # Unique ID used for refcounting.
+    _id::UInt64
+
+    function Buffer(
+        ptr::Ptr{Cvoid}, host_ptr::Ptr{Cvoid}, base_ptr::Ptr{Cvoid},
+        bytesize::Int, device::ROCDevice, coherent::Bool, pool_alloc::Bool,
+    )
+        _id = _buffer_id!()
+        new(ptr, host_ptr, base_ptr, bytesize, device, coherent, pool_alloc, _id)
+    end
 end

 Base.unsafe_convert(::Type{Ptr{T}}, buf::Buffer) where {T} = convert(Ptr{T}, buf.ptr)
@@ -35,13 +45,19 @@ end

 ## refcounting

-const refcounts = Dict{Ptr{Cvoid}, Threads.Atomic{Int}}()
+const _ID_COUNTER = Threads.Atomic{UInt64}(0)
+const refcounts = Dict{UInt64, Int}()
+const liveness = Dict{UInt64, Bool}()
 const refcounts_lock = Threads.ReentrantLock()

+function _buffer_id!()::UInt64
+    return Threads.atomic_add!(_ID_COUNTER, UInt64(1))
+end
+
 function refcount(buf::Buffer)
     Base.lock(refcounts_lock) do
-        get(refcounts, buf.base_ptr, Threads.Atomic{Int}(0))
-    end[]
+        get(refcounts, buf._id, 0)
+    end
 end

 """
     retain(buf::Buffer)

 Increase the refcount of a buffer.
 """
 function retain(buf::Buffer)
-    count = Base.lock(refcounts_lock) do
-        get!(refcounts, buf.base_ptr, Threads.Atomic{Int}(0))
+    Base.lock(refcounts_lock) do
+        live = get!(liveness, buf._id, true)
+        @assert live "Trying to retain dead buffer!"
+        count = get!(refcounts, buf._id, 0)
+        refcounts[buf._id] = count + 1
     end
-    Threads.atomic_add!(count, 1)
     return
 end

@@ -65,24 +83,51 @@ to 0, and some action needs to be taken.
 """
 function release(buf::Buffer)
     while !Base.trylock(refcounts_lock) end
-    count = try
-        get(refcounts, buf.base_ptr, Threads.Atomic{Int}(0))
+    try
+        count = refcounts[buf._id]
+        @assert count >= 1 "Buffer refcount dropping below 0!"
+        refcounts[buf._id] = count - 1
+        done = count == 1
+
+        live = liveness[buf._id]
+
+        if done
+            if live
+                free(buf)
+            end
+            untrack(buf)
+        end
+        return done
     finally
         Base.unlock(refcounts_lock)
     end
-    old = Threads.atomic_sub!(count, 1)
-    return (old == 1)
+end
+
+"""
+    free_if_live(buf::Buffer)
+
+Frees the base pointer for `buf` if it is still live (not yet freed). Does not
+update refcounts.
+"""
+function free_if_live(buf::Buffer)
+    Base.lock(refcounts_lock) do
+        if liveness[buf._id]
+            liveness[buf._id] = false
+            free(buf)
+        end
+    end
 end

 """
     untrack(buf::Buffer)

-Removing refcount tracking information for a buffer.
+Removes refcount tracking information for a buffer.
 """
 function untrack(buf::Buffer)
     while !Base.trylock(refcounts_lock) end
     try
-        delete!(refcounts, buf.base_ptr)
+        delete!(liveness, buf._id)
+        delete!(refcounts, buf._id)
     finally
         Base.unlock(refcounts_lock)
     end
@@ -336,14 +381,18 @@ function alloc(device::ROCDevice, pool::ROCMemoryPool, bytesize::Integer)
     alloc_or_retry!() do
         HSA.amd_memory_pool_allocate(pool.pool, bytesize, 0, ptr_ref)
     end
-    return Buffer(ptr_ref[], C_NULL, ptr_ref[], bytesize, device, Runtime.pool_accessible_by_all(pool), true)
+    AMDGPU.hsaref!()
+    ptr = ptr_ref[]
+    return Buffer(ptr, C_NULL, ptr, bytesize, device, Runtime.pool_accessible_by_all(pool), true)
 end
 function alloc(device::ROCDevice, region::ROCMemoryRegion, bytesize::Integer)
     ptr_ref = Ref{Ptr{Cvoid}}()
     alloc_or_retry!() do
         HSA.memory_allocate(region.region, bytesize, ptr_ref)
     end
-    return Buffer(ptr_ref[], C_NULL, ptr_ref[], bytesize, device, Runtime.region_host_accessible(region), false)
+    AMDGPU.hsaref!()
+    ptr = ptr_ref[]
+    return Buffer(ptr, C_NULL, ptr, bytesize, device, Runtime.region_host_accessible(region), false)
 end
 alloc(bytesize; kwargs...) = alloc(Runtime.get_default_device(), bytesize; kwargs...)
@@ -360,34 +409,48 @@ function alloc_hip(bytesize::Integer)
             HSA.STATUS_ERROR_OUT_OF_RESOURCES
         end
     end
-    return Buffer(ptr_ref[], C_NULL, ptr_ref[], bytesize, Runtime.get_default_device(), false, true)
+    AMDGPU.hsaref!()
+    ptr = ptr_ref[]
+    return Buffer(ptr, C_NULL, ptr, bytesize, Runtime.get_default_device(), false, true)
 end
 end # if AMDGPU.hip_configured

 function free(buf::Buffer)
-    if buf.ptr != C_NULL
-        if buf.host_ptr == C_NULL
-            # HSA-backed
-            if buf.pool_alloc
-                if USE_HIP_MALLOC_OVERRIDE
-                    @static if AMDGPU.hip_configured
-                        # Actually HIP-backed
-                        HIP.@check HIP.hipFree(buf.ptr)
-                    end
-                else
-                    check(HSA.amd_memory_pool_free(buf.ptr))
+    buf.ptr == C_NULL && return
+
+    if buf.host_ptr == C_NULL
+        # HSA-backed
+        if buf.pool_alloc
+            if USE_HIP_MALLOC_OVERRIDE
+                @static if AMDGPU.hip_configured
+                    # Actually HIP-backed
+                    HIP.@check HIP.hipFree(buf.base_ptr)
                 end
             else
-                check(HSA.memory_free(buf.ptr))
+                memory_check(HSA.amd_memory_pool_free(buf.base_ptr), buf.base_ptr)
             end
         else
-            # Wrapped
-            unlock(buf.ptr)
+            memory_check(HSA.memory_free(buf.base_ptr), buf.base_ptr)
         end
-        untrack(buf)
+        AMDGPU.hsaunref!()
+    else
+        # Wrapped
+        unlock(buf.ptr)
     end
     return
 end
+
+# N.B. We try to keep this from yielding or throwing, since this usually runs
+# in a finalizer
+function memory_check(status::HSA.Status, ptr::Ptr{Cvoid})
+    if status != HSA.STATUS_SUCCESS
+        err_str = Runtime.description(status)
+        Core.println("Error when attempting to free an HSA buffer:\n $err_str")
+        pinfo = pointerinfo(ptr)
+        Core.println(sprint(io->Base.show(io, pinfo)))
+        return false
+    end
+    return true
+end

 struct PoolAllocation
     addr::Ptr{Cvoid}
diff --git a/src/runtime/signal.jl b/src/runtime/signal.jl
index 2330e732c..558c240f8 100644
--- a/src/runtime/signal.jl
+++ b/src/runtime/signal.jl
@@ -64,6 +64,7 @@ function Base.wait(
     finished = false
     while !finished
+        @assert AMDGPU.HSA_REFCOUNT[] > 0
         v = HSA.signal_wait_scacquire(
             signal.signal[], HSA.SIGNAL_CONDITION_LT, 1,
             min_latency, HSA.WAIT_STATE_BLOCKED)
diff --git a/src/runtime/sync.jl b/src/runtime/sync.jl
index 6ca8fa763..b97f21dad 100644
--- a/src/runtime/sync.jl
+++ b/src/runtime/sync.jl
@@ -4,8 +4,12 @@ import ..AMDGPU: hip_configured
 struct SyncState
     signals::Vector{ROCKernelSignal}
     streams::Vector{Ptr{Cvoid}}
+    lock::Threads.ReentrantLock
 end
-SyncState() = SyncState(ROCKernelSignal[], Ptr{Cvoid}[])
+SyncState() =
+    SyncState(ROCKernelSignal[],
+              Ptr{Cvoid}[],
+              Threads.ReentrantLock())

 struct WaitAdaptor end
 struct MarkAdaptor{S}
     s::S
 end
@@ -13,20 +17,24 @@ function wait!(ss::SyncState)
-    # FIXME: Use barrier_and on dedicated queue
-    foreach(wait, ss.signals)
-    empty!(ss.signals)
-    @static if hip_configured
-        for s in ss.streams
-            AMDGPU.HIP.@check AMDGPU.HIP.hipStreamSynchronize(s)
+    lock(ss.lock) do
+        # FIXME: Use barrier_and on dedicated queue
+        foreach(wait, ss.signals)
+        empty!(ss.signals)
+        @static if hip_configured
+            for s in ss.streams
+                AMDGPU.HIP.@check AMDGPU.HIP.hipStreamSynchronize(s)
+            end
+            empty!(ss.streams)
+            AMDGPU.HIP.@check AMDGPU.HIP.hipStreamSynchronize(C_NULL) # FIXME: This shouldn't be necessary
         end
-        empty!(ss.streams)
-        AMDGPU.HIP.@check AMDGPU.HIP.hipStreamSynchronize(C_NULL) # FIXME: This shouldn't be necessary
     end
-    nothing
+    return
 end

-mark!(ss::SyncState, signal::ROCKernelSignal) = push!(ss.signals, signal)
-mark!(ss::SyncState, stream::Ptr{Cvoid}) = push!(ss.streams, stream)
+mark!(ss::SyncState, signal::ROCKernelSignal) =
+    lock(()->push!(ss.signals, signal), ss.lock)
+mark!(ss::SyncState, stream::Ptr{Cvoid}) =
+    lock(()->push!(ss.streams, stream), ss.lock)

 wait!(x) = Adapt.adapt(WaitAdaptor(), x)
 mark!(x, s) = Adapt.adapt(MarkAdaptor(s), x)
diff --git a/test/rocarray/base.jl b/test/rocarray/base.jl
index c9cc74471..7a91f88f7 100644
--- a/test/rocarray/base.jl
+++ b/test/rocarray/base.jl
@@ -96,4 +96,88 @@ end
     finalize(A)
 end

+@testset "Refcounting" begin
+    refcount_live(A) = (get(AMDGPU.Mem.refcounts, A.buf._id, 0),
+                        get(AMDGPU.Mem.liveness, A.buf._id, false))
+
+    for (f, switch) in [(A->view(A, 2:4), false),
+                        (A->resize!(A, 8), true),
+                        (A->reinterpret(UInt8, A), false),
+                        (A->reshape(A, 4, 4), false)]
+
+        # Safe free
+        A = AMDGPU.ones(16)
+        @test refcount_live(A) == (1, true)
+        B = f(A)
+        @test A.buf.base_ptr == B.buf.base_ptr
+        @test refcount_live(A) == refcount_live(B)
+        @test refcount_live(B) == (2-switch, true)
+        finalize(B)
+        @test refcount_live(B) == (1-switch, !switch)
+        finalize(A)
+        @test refcount_live(B) == (0, false)
+
+        # Unsafe free original
+        A = AMDGPU.ones(16)
+        B = f(A)
+        AMDGPU.unsafe_free!(A)
+        @test refcount_live(B) == (2-switch, false)
+        finalize(B)
+        @test refcount_live(B) == (1-switch, false)
+        finalize(A)
+        @test refcount_live(B) == (0, false)
+
+        # Unsafe free derived
+        A = AMDGPU.ones(16)
+        B = f(A)
+        AMDGPU.unsafe_free!(B)
+        @test refcount_live(B) == (2-switch, false)
+        finalize(A)
+        @test refcount_live(B) == (1-switch, false)
+        finalize(B)
+        @test refcount_live(B) == (0, false)
+
+        # Unsafe free original and derived
+        A = AMDGPU.ones(16)
+        B = f(A)
+        AMDGPU.unsafe_free!(A)
+        AMDGPU.unsafe_free!(B)
+        @test refcount_live(B) == (2-switch, false)
+        finalize(A)
+        @test refcount_live(B) == (1-switch, false)
+        finalize(B)
+        @test refcount_live(B) == (0, false)
+    end
+
+    # Chained Safe free
+    A = AMDGPU.ones(16)
+    @test refcount_live(A) == (1, true)
+    B = reshape(A, 4, 4)
+    @test refcount_live(A) == (2, true)
+    C = reshape(B, 2, 8)
+    @test refcount_live(A) == (3, true)
+    finalize(B)
+    @test refcount_live(A) == (2, true)
+    finalize(A)
+    @test refcount_live(A) == (1, true)
+    finalize(C)
+    @test refcount_live(A) == (0, false)
+
+    # Chained Unsafe free
+    A = AMDGPU.ones(16)
+    @test refcount_live(A) == (1, true)
+    B = reshape(A, 4, 4)
+    @test refcount_live(A) == (2, true)
+    C = reshape(B, 2, 8)
+    @test refcount_live(A) == (3, true)
+    AMDGPU.unsafe_free!(A)
+    @test refcount_live(A) == (3, false)
+    finalize(B)
+    @test refcount_live(A) == (2, false)
+    finalize(A)
+    @test refcount_live(A) == (1, false)
+    finalize(C)
+    @test refcount_live(A) == (0, false)
+end
+
 end
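
Usage sketch (illustrative, not part of the patch): the snippet below walks through the ID-based refcounting introduced above, mirroring the new "Refcounting" testset. It assumes an AMDGPU.jl build with this patch applied and a working ROCm device; every name used (AMDGPU.ones, Mem.refcount, Buffer._id, unsafe_free!) comes from code touched in this diff.

using AMDGPU

# Parent allocation: the ROCArray constructor calls Mem.retain, so the buffer
# starts with refcount 1 and is marked live.
A = AMDGPU.ones(16)
@assert AMDGPU.Mem.refcount(A.buf) == 1

# Derived arrays (views, reshapes, reinterprets) share the same Buffer, and
# therefore the same _id; constructing one bumps the shared refcount.
B = reshape(A, 4, 4)
@assert B.buf._id == A.buf._id
@assert AMDGPU.Mem.refcount(A.buf) == 2

# unsafe_free! eagerly frees the device memory via Mem.free_if_live, but does
# not touch the refcount: the buffer is only marked dead.
AMDGPU.unsafe_free!(A)
@assert AMDGPU.Mem.refcount(A.buf) == 2

# Finalizers run _safe_free! -> Mem.release; once the count reaches zero the
# (already dead) buffer is untracked rather than freed a second time.
finalize(B); finalize(A)
@assert AMDGPU.Mem.refcount(A.buf) == 0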