Skip to content

Commit

Permalink
Merge pull request #353 from JuliaGPU/jps/unsafe_free-for-real
Browse files Browse the repository at this point in the history
  • Loading branch information
pxl-th authored Jan 21, 2023
2 parents 191afc4 + 3f58e5e commit 09cf06c
Show file tree
Hide file tree
Showing 7 changed files with 230 additions and 85 deletions.
66 changes: 27 additions & 39 deletions src/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,35 +58,27 @@ end

mutable struct ROCArray{T,N} <: AbstractGPUArray{T,N}
buf::Mem.Buffer
own::Bool

dims::Dims{N}
offset::Int

syncstate::Runtime.SyncState

function ROCArray{T,N}(buf::Mem.Buffer, dims::Dims{N}; offset::Integer=0, own::Bool=true) where {T,N}
function ROCArray{T,N}(buf::Mem.Buffer, dims::Dims{N}; offset::Integer=0) where {T,N}
@assert isbitstype(T) "ROCArray only supports bits types"
xs = new{T,N}(buf, own, dims, offset, Runtime.SyncState())
if own
hsaref!()
end
xs = new{T,N}(buf, dims, offset, Runtime.SyncState())
Mem.retain(buf)
finalizer(unsafe_free!, xs)
finalizer(_safe_free!, xs)
return xs
end
end

function unsafe_free!(xs::ROCArray)
if Mem.release(xs.buf)
Mem.free(xs.buf)
if xs.own
hsaunref!()
end
end
# Finalizer entry point for a `ROCArray`: hands the array's backing buffer to
# the buffer method below.
function _safe_free!(xs::ROCArray)
    return _safe_free!(xs.buf)
end

# Drop one reference to `buf` through `Mem.release`. The value returned by
# `Mem.release` is discarded and `nothing` is returned instead.
_safe_free!(buf::Mem.Buffer) = (Mem.release(buf); nothing)

# Eagerly free the memory behind `xs` by delegating to `Mem.free_if_live` on
# the array's buffer. The buffer's reference count is not changed by this call.
function unsafe_free!(xs::ROCArray)
    return Mem.free_if_live(xs.buf)
end

# Synchronization helpers: a `ROCArray` forwards `wait!` and `mark!` to its
# `syncstate` field.
wait!(x::ROCArray) = wait!(x.syncstate)
mark!(x::ROCArray, s) = mark!(x.syncstate, s)
# Apply `wait!` to every array in the vector, in order.
wait!(xs::Vector{<:ROCArray}) = foreach(wait!, xs)
Expand Down Expand Up @@ -231,8 +223,8 @@ function Base.unsafe_wrap(::Type{<:ROCArray}, ptr::Ptr{T}, dims::NTuple{N,<:Inte
@assert isbitstype(T) "Cannot wrap a non-bitstype pointer as a ROCArray"
sz = prod(dims) * sizeof(T)
device_ptr = lock ? Mem.lock(ptr, sz, device) : ptr
buf = Mem.Buffer(device_ptr, ptr, device_ptr, sz, device, false, false)
return ROCArray{T, N}(buf, dims; own=false)
buf = Mem.Buffer(device_ptr, Ptr{Cvoid}(ptr), device_ptr, sz, device, false, false)
return ROCArray{T, N}(buf, dims)
end
Base.unsafe_wrap(::Type{ROCArray{T}}, ptr::Ptr, dims; kwargs...) where T =
unsafe_wrap(ROCArray, Base.unsafe_convert(Ptr{T}, ptr), dims; kwargs...)
Expand Down Expand Up @@ -280,11 +272,7 @@ end
end
@inline function unsafe_contiguous_view(a::ROCArray{T}, I::NTuple{N,Base.ViewIndex}, dims::NTuple{M,Integer}) where {T,N,M}
offset = Base.compute_offset1(a, 1, I) * sizeof(T)

Mem.retain(a.buf)
b = ROCArray{T,M}(a.buf, dims, offset=a.offset+offset, own=false)
finalizer(unsafe_free!, b)
return b
ROCArray{T,M}(a.buf, dims, offset=a.offset+offset)
end

@inline function unsafe_view(A, I, ::NonContiguous)
Expand All @@ -306,10 +294,7 @@ function Base.reshape(a::ROCArray{T,M}, dims::NTuple{N,Int}) where {T,N,M}
return a
end

Mem.retain(a.buf)
b = ROCArray{T,N}(a.buf, dims, offset=a.offset, own=false)
finalizer(unsafe_free!, b)
return b
ROCArray{T,N}(a.buf, dims, offset=a.offset)
end


Expand Down Expand Up @@ -401,10 +386,7 @@ zeros(T::Type, dims...) = fill!(ROCArray{T}(undef, dims...), zero(T))
# create a derived array (reinterpreted or reshaped) that's still a ROCArray
# TODO: Move this to GPUArrays?
@inline function _derived_array(::Type{T}, N::Int, a::ROCArray, osize::Dims) where {T}
Mem.retain(a.buf)
b = ROCArray{T,N}(a.buf, osize, offset=a.offset, own=false)
finalizer(unsafe_free!, b)
return b
return ROCArray{T,N}(a.buf, osize, offset=a.offset)
end

## reinterpret
Expand Down Expand Up @@ -523,14 +505,18 @@ end
"""
resize!(a::ROCVector, n::Integer)
Resize `a` to contain `n` elements. If `n` is smaller than the current collection length,
the first `n` elements will be retained. If `n` is larger, the new elements are not
guaranteed to be initialized.
Resize `a` to contain `n` elements. If `n` is smaller than the current
collection length, the first `n` elements will be retained. If `n` is larger,
the new elements are not guaranteed to be initialized.
Note that this operation is only supported on managed buffers, i.e., not on arrays that are
created by `unsafe_wrap` with `own=false`.
Note that this operation is only supported on managed buffers, i.e., not on
arrays that are created by `unsafe_wrap`.
"""
function Base.resize!(A::ROCVector{T}, n::Integer) where T
if A.buf.host_ptr != C_NULL
throw(ArgumentError("Cannot resize an unowned `ROCVector`"))
end

# TODO: add additional space to allow for quicker resizing
if n == length(A)
return A
Expand All @@ -547,10 +533,12 @@ function Base.resize!(A::ROCVector{T}, n::Integer) where T
if copy_size > 0
Mem.transfer!(new_buf, A.buf, copy_size)
end

# Release old buffer
_safe_free!(A.buf)
# N.B. Manually retain new buffer (this is normally done in ROCArray ctor)
Mem.retain(new_buf)
if Mem.release(A.buf)
Mem.free(A.buf)
end

A.buf = new_buf
A.dims = (n,)
A.offset = 0
Expand Down
5 changes: 3 additions & 2 deletions src/runtime/error.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@ end
Gets the string description of an error code.
"""
function description(err::HSAError)
description(err::HSAError) = description(err.code)
function description(status::HSA.Status)
str_ref = Ref{Ptr{Int8}}()
HSA.status_string(err.code, str_ref)
HSA.status_string(status, str_ref)
unsafe_string(reinterpret(Cstring, str_ref[]))
end

Expand Down
2 changes: 1 addition & 1 deletion src/runtime/kernel-signal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ end

function Base.show(io::IO, kersig::ROCKernelSignal)
ex = kersig.exception
print(io, "ROCKernelSignal(signal=$(kersig.signal), $(ex !== nothing ? ", exception=$ex" : ""))")
print(io, "ROCKernelSignal(signal=$(kersig.signal)$(ex !== nothing ? ", exception=$ex" : ""))")
end

Base.notify(kersig::ROCKernelSignal) = notify(kersig.signal)
125 changes: 94 additions & 31 deletions src/runtime/memory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import AMDGPU
import AMDGPU: HSA
if AMDGPU.hip_configured
import AMDGPU: HIP
import AMDGPU: HIP
end
import AMDGPU: Runtime
import .Runtime: ROCDevice, ROCSignal, ROCMemoryRegion, ROCMemoryPool, ROCDim, ROCDim3
Expand All @@ -21,6 +21,16 @@ struct Buffer
device::ROCDevice
coherent::Bool
pool_alloc::Bool
# Unique ID used for refcounting.
_id::UInt64

# Inner constructor: stores the given fields unchanged and stamps the buffer
# with a fresh id from `_buffer_id!()`, which is stored in `_id`.
function Buffer(
    ptr::Ptr{Cvoid}, host_ptr::Ptr{Cvoid}, base_ptr::Ptr{Cvoid},
    bytesize::Int, device::ROCDevice, coherent::Bool, pool_alloc::Bool,
)
    return new(
        ptr, host_ptr, base_ptr, bytesize, device, coherent, pool_alloc,
        _buffer_id!(),
    )
end
end

Base.unsafe_convert(::Type{Ptr{T}}, buf::Buffer) where {T} = convert(Ptr{T}, buf.ptr)
Expand All @@ -35,13 +45,19 @@ end

## refcounting

const refcounts = Dict{Ptr{Cvoid}, Threads.Atomic{Int}}()
const _ID_COUNTER = Threads.Atomic{UInt64}(0)
const refcounts = Dict{UInt64, Int}()
const liveness = Dict{UInt64, Bool}()
const refcounts_lock = Threads.ReentrantLock()

# Hand out the next buffer id. `Threads.atomic_add!` returns the value the
# counter held *before* the increment, so each call yields a distinct id.
_buffer_id!()::UInt64 = Threads.atomic_add!(_ID_COUNTER, one(UInt64))

function refcount(buf::Buffer)
Base.lock(refcounts_lock) do
get(refcounts, buf.base_ptr, Threads.Atomic{Int}(0))
end[]
get(refcounts, buf._id, 0)
end
end

"""
Expand All @@ -50,10 +66,12 @@ end
Increase the refcount of a buffer.
"""
function retain(buf::Buffer)
count = Base.lock(refcounts_lock) do
get!(refcounts, buf.base_ptr, Threads.Atomic{Int}(0))
Base.lock(refcounts_lock) do
live = get!(liveness, buf._id, true)
@assert live "Trying to retain dead buffer!"
count = get!(refcounts, buf._id, 0)
refcounts[buf._id] = count + 1
end
Threads.atomic_add!(count, 1)
return
end

Expand All @@ -65,24 +83,51 @@ to 0, and some action needs to be taken.
"""
function release(buf::Buffer)
while !Base.trylock(refcounts_lock) end
count = try
get(refcounts, buf.base_ptr, Threads.Atomic{Int}(0))
try
count = refcounts[buf._id]
@assert count >= 1 "Buffer refcount dropping below 0!"
refcounts[buf._id] = count - 1
done = count == 1

live = liveness[buf._id]

if done
if live
free(buf)
end
untrack(buf)
end
return done
finally
Base.unlock(refcounts_lock)
end
old = Threads.atomic_sub!(count, 1)
return (old == 1)
end

"""
    free_if_live(buf::Buffer)

Free the base pointer of `buf` unless it has already been freed, and mark the
buffer as no longer live. Reference counts are left untouched.

The liveness lookup indexes the `liveness` table directly, so a buffer whose
id is not tracked there raises a `KeyError`.
"""
function free_if_live(buf::Buffer)
    Base.lock(refcounts_lock)
    try
        # Already freed earlier: nothing left to do.
        liveness[buf._id] || return nothing
        # Clear the flag before freeing so a later call sees the buffer as dead.
        liveness[buf._id] = false
        return free(buf)
    finally
        Base.unlock(refcounts_lock)
    end
end

"""
untrack(buf::Buffer)
Removing refcount tracking information for a buffer.
Removes refcount tracking information for a buffer.
"""
function untrack(buf::Buffer)
while !Base.trylock(refcounts_lock) end
try
delete!(refcounts, buf.base_ptr)
delete!(liveness, buf._id)
delete!(refcounts, buf._id)
finally
Base.unlock(refcounts_lock)
end
Expand Down Expand Up @@ -336,14 +381,18 @@ function alloc(device::ROCDevice, pool::ROCMemoryPool, bytesize::Integer)
alloc_or_retry!() do
HSA.amd_memory_pool_allocate(pool.pool, bytesize, 0, ptr_ref)
end
return Buffer(ptr_ref[], C_NULL, ptr_ref[], bytesize, device, Runtime.pool_accessible_by_all(pool), true)
AMDGPU.hsaref!()
ptr = ptr_ref[]
return Buffer(ptr, C_NULL, ptr, bytesize, device, Runtime.pool_accessible_by_all(pool), true)
end
function alloc(device::ROCDevice, region::ROCMemoryRegion, bytesize::Integer)
ptr_ref = Ref{Ptr{Cvoid}}()
alloc_or_retry!() do
HSA.memory_allocate(region.region, bytesize, ptr_ref)
end
return Buffer(ptr_ref[], C_NULL, ptr_ref[], bytesize, device, Runtime.region_host_accessible(region), false)
AMDGPU.hsaref!()
ptr = ptr_ref[]
return Buffer(ptr, C_NULL, ptr, bytesize, device, Runtime.region_host_accessible(region), false)
end
alloc(bytesize; kwargs...) =
alloc(Runtime.get_default_device(), bytesize; kwargs...)
Expand All @@ -360,34 +409,48 @@ function alloc_hip(bytesize::Integer)
HSA.STATUS_ERROR_OUT_OF_RESOURCES
end
end
return Buffer(ptr_ref[], C_NULL, ptr_ref[], bytesize, Runtime.get_default_device(), false, true)
AMDGPU.hsaref!()
ptr = ptr_ref[]
return Buffer(ptr, C_NULL, ptr, bytesize, Runtime.get_default_device(), false, true)
end
end # if AMDGPU.hip_configured

function free(buf::Buffer)
if buf.ptr != C_NULL
if buf.host_ptr == C_NULL
# HSA-backed
if buf.pool_alloc
if USE_HIP_MALLOC_OVERRIDE
@static if AMDGPU.hip_configured
# Actually HIP-backed
HIP.@check HIP.hipFree(buf.ptr)
end
else
check(HSA.amd_memory_pool_free(buf.ptr))
buf.ptr == C_NULL && return

if buf.host_ptr == C_NULL
# HSA-backed
if buf.pool_alloc
if USE_HIP_MALLOC_OVERRIDE
@static if AMDGPU.hip_configured
# Actually HIP-backed
HIP.@check HIP.hipFree(buf.base_ptr)
end
else
check(HSA.memory_free(buf.ptr))
memory_check(HSA.amd_memory_pool_free(buf.base_ptr), buf.base_ptr)
end
else
# Wrapped
unlock(buf.ptr)
memory_check(HSA.memory_free(buf.base_ptr), buf.base_ptr)
end
untrack(buf)
AMDGPU.hsaunref!()
else
# Wrapped
unlock(buf.ptr)
end
return
end
# Report (rather than throw on) a failed HSA free.
#
# Returns `true` when `status` is `HSA.STATUS_SUCCESS`. Otherwise prints the
# status description and the pointer info for `ptr` via `Core.println`, and
# returns `false`.
#
# N.B. We try to keep this from yielding or throwing, since this usually runs
# in a finalizer
function memory_check(status::HSA.Status, ptr::Ptr{Cvoid})
    status == HSA.STATUS_SUCCESS && return true

    msg = Runtime.description(status)
    Core.println("Error when attempting to free an HSA buffer:\n $msg")
    info = pointerinfo(ptr)
    Core.println(sprint(Base.show, info))
    return false
end

struct PoolAllocation
addr::Ptr{Cvoid}
Expand Down
1 change: 1 addition & 0 deletions src/runtime/signal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ function Base.wait(
finished = false

while !finished
@assert AMDGPU.HSA_REFCOUNT[] > 0
v = HSA.signal_wait_scacquire(
signal.signal[], HSA.SIGNAL_CONDITION_LT, 1,
min_latency, HSA.WAIT_STATE_BLOCKED)
Expand Down
Loading

0 comments on commit 09cf06c

Please sign in to comment.