Skip to content

Commit

Permalink
Add memory record
Browse files Browse the repository at this point in the history
  • Loading branch information
pxl-th committed Nov 19, 2024
1 parent 98e7084 commit a8bb216
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 0 deletions.
58 changes: 58 additions & 0 deletions src/array.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,49 @@
const MemoryRecords = LockedObject(Dict{UInt64, Any}())

# TODO make TLS
const RECORD_MEMORY::Ref{Bool} = Ref(false)

function record_memory!(rec::Bool; free::Bool = true, sync::Bool = false)
RECORD_MEMORY[] = rec
if !rec
free && free_records!(; sync)
end
return
end

record_memory() = RECORD_MEMORY[]

function record!(x)
# @info "Recording $(typeof(x)) $(size(x))"
Base.lock(records -> records[_hash(x)] = x, MemoryRecords)
return
end

function free_records!(; sync::Bool = false)
Base.lock(MemoryRecords) do records
# @info "Freeing `$(length(records))` records"
for (k, x) in records
unsafe_free!(x)
end
empty!(records)
end
sync && AMDGPU.synchronize()
return
end

function remove_record!(x)
record_memory() || return

k = _hash(x)
Base.lock(MemoryRecords) do records
if k in records.keys
# @info "Removing record"
pop!(records, k)
end
end
return
end

mutable struct ROCArray{T, N, B} <: AbstractGPUArray{T, N}
buf::DataRef{Managed{B}}
dims::Dims{N}
Expand All @@ -23,6 +69,18 @@ mutable struct ROCArray{T, N, B} <: AbstractGPUArray{T, N}
end
end

function _hash(x::ROCArray)
# @info "_hash"
# @show x.buf.rc.obj.mem.ptr
# @show x.offset
# @show x.dims
r = Base.hash(x.buf.rc.obj.mem.ptr,
Base.hash(x.offset,
Base.hash(x.dims)))
# @show r
return r
end

GPUArrays.storage(a::ROCArray) = a.buf

function GPUArrays.derive(
Expand Down
4 changes: 4 additions & 0 deletions src/memory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,10 @@ end

function pool_alloc(::Type{B}, bytesize) where B
s = AMDGPU.stream()
# @info "[pool_alloc] $(Base.format_bytes(bytesize))"
# display(stacktrace()); println()
# println()
# println()
Managed(B(bytesize; stream=s); stream=s)
end

Expand Down
25 changes: 25 additions & 0 deletions t.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
using AMDGPU

function main()
@show Base.format_bytes(AMDGPU.memory_stats().live)
AMDGPU.record_memory!(true)

res = nothing

for i in 1:2
x = AMDGPU.rand(Float32, 1024)
y = sum(x; dims=1)
if i == 1
res = y
AMDGPU.remove_record!(res)
end
end

AMDGPU.record_memory!(false)
@show Base.format_bytes(AMDGPU.memory_stats().live)
println("Done")

@show res
return
end
main()

0 comments on commit a8bb216

Please sign in to comment.