Handle cache improvements #2352

Merged · 7 commits · Apr 27, 2024
8 changes: 6 additions & 2 deletions .buildkite/pipeline.yml
@@ -140,8 +140,12 @@ steps:
 try
     Pkg.instantiate()
 catch
-    # if we fail to instantiate, assume that we need a newer CUDA.jl
-    Pkg.develop(path=".")
+    # if we fail to instantiate, assume that we need newer dependencies
+    deps = [PackageSpec(path=".")]
+    if "{{matrix.package}}" == "cuTensorNet"
+        push!(deps, PackageSpec(path="lib/cutensor"))
+    end
+    Pkg.develop(deps)
 end

 Pkg.add("CUDA_Runtime_jll")
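
For reference, `Pkg.develop` accepts a vector of `PackageSpec`s, which is what lets the fallback above dev both the checked-out package and the vendored cuTensor in a single call. A minimal sketch, assuming it runs from the root of a CUDA.jl checkout:

    using Pkg

    # Develop the top-level package plus a subpackage from the same checkout;
    # both paths are resolved relative to the current working directory.
    deps = [PackageSpec(path="."), PackageSpec(path="lib/cutensor")]
    Pkg.develop(deps)

Using `PackageSpec(path=...)` keeps CI testing the in-repo versions of both packages rather than their registered releases.
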
79 changes: 46 additions & 33 deletions lib/cublas/CUBLAS.jl
@@ -71,9 +71,20 @@ function math_mode!(handle, mode)
     return
 end

-# cache for created, but unused handles
-const idle_handles = HandleCache{CuContext,cublasHandle_t}()
-const idle_xt_handles = HandleCache{Any,cublasXtHandle_t}()
+
+## handles
+
+function handle_ctor(ctx)
+    context!(ctx) do
+        cublasCreate()
+    end
+end
+function handle_dtor(ctx, handle)
+    context!(ctx; skip_destroyed=true) do
+        cublasDestroy_v2(handle)
+    end
+end
+const idle_handles = HandleCache{CuContext,cublasHandle_t}(handle_ctor, handle_dtor)

 function handle()
     cuda = CUDA.active_state()

@@ -86,20 +97,12 @@ function handle()

     # get library state
     @noinline function new_state(cuda)
-        new_handle = pop!(idle_handles, cuda.context) do
-            cublasCreate()
-        end
-
+        new_handle = pop!(idle_handles, cuda.context)
         finalizer(current_task()) do task
-            push!(idle_handles, cuda.context, new_handle) do
-                context!(cuda.context; skip_destroyed=true) do
-                    cublasDestroy_v2(new_handle)
-                end
-            end
+            push!(idle_handles, cuda.context, new_handle)
         end

         cublasSetStream_v2(new_handle, cuda.stream)
-
         math_mode!(new_handle, cuda.math_mode)

         (; handle=new_handle, cuda.stream, cuda.math_mode)

@@ -129,6 +132,34 @@ function handle()
     return state.handle
 end

+
+## xt handles
+
+function xt_handle_ctor(ctx)
+    context!(ctx) do
+        cublasXtCreate()
+    end
+end
+function xt_handle_dtor(ctx, handle)
+    context!(ctx; skip_destroyed=true) do
+        cublasXtDestroy(handle)
+    end
+end
+const idle_xt_handles =
+    HandleCache{CuContext,cublasXtHandle_t}(xt_handle_ctor, xt_handle_dtor)
+
+function devices!(devs::Vector{CuDevice})
+    task_local_storage(:CUBLASxt_devices, sort(devs; by=deviceid))
+    return
+end
+
+devices() = get!(task_local_storage(), :CUBLASxt_devices) do
+    # by default, select all devices
+    sort(collect(CUDA.devices()); by=deviceid)
+end::Vector{CuDevice}
+
+ndevices() = length(devices())
+
 function xt_handle()
     cuda = CUDA.active_state()

@@ -147,15 +178,9 @@ function xt_handle()

     # get library state
     @noinline function new_state(cuda)
-        new_handle = pop!(idle_xt_handles, cuda.context) do
-            cublasXtCreate()
-        end
-
+        new_handle = pop!(idle_xt_handles, cuda.context)
         finalizer(current_task()) do task
-            push!(idle_xt_handles, cuda.context, new_handle) do
-                # TODO: which context do we need to destroy this on?
-                cublasXtDestroy(new_handle)
-            end
+            push!(idle_xt_handles, cuda.context, new_handle)
         end

         devs = convert.(Cint, devices())

@@ -170,18 +195,6 @@ function xt_handle()
     return state.handle
 end

-function devices!(devs::Vector{CuDevice})
-    task_local_storage(:CUBLASxt_devices, sort(devs; by=deviceid))
-    return
-end
-
-devices() = get!(task_local_storage(), :CUBLASxt_devices) do
-    # by default, select all devices
-    sort(collect(CUDA.devices()); by=deviceid)
-end::Vector{CuDevice}
-
-ndevices() = length(devices())
-

 ## logging
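
The same refactoring recurs in each wrapper below: the constructor and destructor are now baked into the `HandleCache` at construction time, so call sites just `pop!` and `push!` without passing `do`-blocks. A minimal sketch of what such a cache could look like (hypothetical; the field and parameter names are invented here, and the real `HandleCache` in CUDA.jl additionally deals with locking and with tracking handles that are in use):

    struct HandleCache{K,V}
        ctor::Function            # key -> new handle
        dtor::Function            # (key, handle) -> nothing
        idle::Dict{K,Vector{V}}   # unused handles, grouped per key
        max_idle::Int             # how many idle handles to keep around
    end
    HandleCache{K,V}(ctor, dtor; max_idle=32) where {K,V} =
        HandleCache{K,V}(ctor, dtor, Dict{K,Vector{V}}(), max_idle)

    # reuse an idle handle if one is available, otherwise create a new one
    function Base.pop!(cache::HandleCache{K,V}, key::K) where {K,V}
        handles = get(cache.idle, key, V[])
        return isempty(handles) ? cache.ctor(key) : pop!(handles)
    end

    # return a handle to the cache, destroying it if the cache is full
    function Base.push!(cache::HandleCache{K,V}, key::K, handle::V) where {K,V}
        handles = get!(() -> V[], cache.idle, key)
        if length(handles) < cache.max_idle
            push!(handles, handle)
        else
            cache.dtor(key, handle)
        end
        return cache
    end

Centralizing construction and destruction like this makes the cache self-contained: any code holding the cache can shrink or clear it, instead of only the call site that happened to know how to destroy a handle.
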
2 changes: 1 addition & 1 deletion lib/cudnn/Project.toml
@@ -11,7 +11,7 @@ CUDNN_jll = "62b44479-cb7b-5706-934f-f13b2eb2e645"

 [compat]
 CEnum = "0.2, 0.3, 0.4, 0.5"
-CUDA = "~5.3, ~5.4"
+CUDA = "~5.4"
 CUDA_Runtime_Discovery = "0.2"
 CUDNN_jll = "~9.0"
 julia = "1.8"
27 changes: 16 additions & 11 deletions lib/cudnn/src/cuDNN.jl
@@ -62,8 +62,20 @@ function math_mode(mode=CUDA.math_mode())
     end
 end

-# cache for created, but unused handles
-const idle_handles = HandleCache{CuContext,cudnnHandle_t}()
+
+## handles
+
+function handle_ctor(ctx)
+    context!(ctx) do
+        cudnnCreate()
+    end
+end
+function handle_dtor(ctx, handle)
+    context!(ctx; skip_destroyed=true) do
+        cudnnDestroy(handle)
+    end
+end
+const idle_handles = HandleCache{CuContext,cudnnHandle_t}(handle_ctor, handle_dtor)

 function handle()
     cuda = CUDA.active_state()

@@ -76,16 +88,9 @@ function handle()

     # get library state
     @noinline function new_state(cuda)
-        new_handle = pop!(idle_handles, cuda.context) do
-            cudnnCreate()
-        end
-
+        new_handle = pop!(idle_handles, cuda.context)
         finalizer(current_task()) do task
-            push!(idle_handles, cuda.context, new_handle) do
-                context!(cuda.context; skip_destroyed=true) do
-                    cudnnDestroy(new_handle)
-                end
-            end
+            push!(idle_handles, cuda.context, new_handle)
         end

         cudnnSetStream(new_handle, cuda.stream)
30 changes: 18 additions & 12 deletions lib/cufft/wrappers.jl
@@ -43,7 +43,7 @@ function cufftMakePlan(xtype::cufftType_t, xdims::Dims, region)
     end
     if ((region...,) == ((1:nrank)...,))
         # handle simple case, transforming the first nrank dimensions, ... simply! (for robustness)
-        # arguments are: plan, rank, transform-sizes, inembed, istride, idist, onembed, ostride, odist, type batch
+        # arguments are: plan, rank, transform-sizes, inembed, istride, idist, onembed, ostride, odist, type batch
         cufftMakePlanMany(handle, nrank, Cint[rsz...], C_NULL, 1, 1, C_NULL, 1, 1,
                           xtype, batch, worksize_ref)
     else

@@ -151,29 +151,35 @@ function cufftMakePlan(xtype::cufftType_t, xdims::Dims, region)
     handle, worksize_ref[]
 end

-# plan cache
-const cufftHandleCacheKey = Tuple{CuContext, cufftType_t, Dims, Any}
-const idle_handles = HandleCache{cufftHandleCacheKey, cufftHandle}()
-function cufftGetPlan(args...)
-
-    ctx = context()
-    handle = pop!(idle_handles, (ctx, args...)) do
+
+## plan cache
+
+const cufftHandleCacheKey = Tuple{CuContext, cufftType_t, Dims, Any}
+function handle_ctor((ctx, args...))
+    context!(ctx) do
         # make the plan
         handle, worksize = cufftMakePlan(args...)

         # NOTE: we currently do not use the worksize to allocate our own workarea,
         # instead relying on the automatic allocation strategy.
         handle
     end
+end
+function handle_dtor((ctx, args...), handle)
+    context!(ctx; skip_destroyed=true) do
+        cufftDestroy(handle)
+    end
+end
+const idle_handles = HandleCache{cufftHandleCacheKey, cufftHandle}(handle_ctor, handle_dtor)
+
+function cufftGetPlan(args...)
+    ctx = context()
+    handle = pop!(idle_handles, (ctx, args...))

     # assign to the current stream
     cufftSetStream(handle, stream())

     return handle
 end
 function cufftReleasePlan(plan)
-    push!(idle_handles, plan) do
-        cufftDestroy(plan)
-    end
+    push!(idle_handles, plan)
 end
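
With the plan cache factored this way, consuming code wraps each transform in a get/release pair. A hedged usage sketch (the plan arguments are assumptions for illustration: a `CUFFT_C2C` transform over a 256×256 array along both dimensions):

    plan = cufftGetPlan(CUFFT_C2C, (256, 256), 1:2)
    try
        # ... execute transforms with `plan` on the current stream ...
    finally
        # return the plan to the idle cache, so an identically-shaped
        # transform on this context can reuse it instead of re-planning
        cufftReleasePlan(plan)
    end

Note that the cache key includes the `CuContext`, so plans are never reused across contexts.
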
27 changes: 18 additions & 9 deletions lib/curand/CURAND.jl
@@ -21,8 +21,21 @@ include("wrappers.jl")
 # high-level integrations
 include("random.jl")

-# cache for created, but unused handles
-const idle_curand_rngs = HandleCache{CuContext,RNG}()
+
+## handles
+
+function handle_ctor(ctx)
+    context!(ctx) do
+        RNG()
+    end
+end
+function handle_dtor(ctx, handle)
+    context!(ctx; skip_destroyed=true) do
+        # no need to do anything, as the RNG is collected by its finalizer
+        # TODO: early free?
+    end
+end
+const idle_curand_rngs = HandleCache{CuContext,RNG}(handle_ctor, handle_dtor)

 function default_rng()
     cuda = CUDA.active_state()

@@ -35,17 +48,13 @@ function default_rng()

     # get library state
     @noinline function new_state(cuda)
-        new_rng = pop!(idle_curand_rngs, cuda.context) do
-            RNG()
-        end
-
+        new_rng = pop!(idle_curand_rngs, cuda.context)
         finalizer(current_task()) do task
-            push!(idle_curand_rngs, cuda.context, new_rng) do
-                # no need to do anything, as the RNG is collected by its finalizer
-            end
+            push!(idle_curand_rngs, cuda.context, new_rng)
         end

         Random.seed!(new_rng)

         (; rng=new_rng)
     end
     state = get!(states, cuda.context) do
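
All of these wrappers rely on the same task-bound lifetime trick: the handle lives in task-local state, and a finalizer registered on the current `Task` object pushes it back into the idle cache once the task becomes unreachable. A stripped-down sketch of the pattern, with every name hypothetical and a plain `Vector` plus counter standing in for real handle creation (not thread-safe as written):

    const idle = Int[]         # stand-in for the idle-handle cache
    const counter = Ref(0)     # stand-in for expensive handle creation

    function task_handle()
        get!(task_local_storage(), :handle) do
            h = isempty(idle) ? (counter[] += 1) : pop!(idle)
            # recycle the handle once this task is garbage-collected
            finalizer(current_task()) do task
                push!(idle, h)
            end
            h
        end::Int
    end

The real wrappers layer library-specific setup on top of this, as the diffs above show: stream assignment, math mode, and RNG seeding.
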