Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor improvements for library wrappers #1207

Merged
merged 4 commits into from
Oct 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions docs/src/lib/driver.md
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ CUDA.total_memory

```@docs
CuStream
CUDA.query(::CuStream)
CUDA.isdone(::CuStream)
priority_range
priority
synchronize(::CuStream)
Expand All @@ -153,9 +153,9 @@ CUDA.@sync
For specific use cases, special streams are available:

```@docs
CuDefaultStream
CuStreamLegacy
CuStreamPerThread
default_stream
legacy_stream
per_thread_stream
```

## Event Management
Expand All @@ -164,7 +164,7 @@ CuStreamPerThread
CuEvent
record
synchronize(::CuEvent)
CUDA.query(::CuEvent)
CUDA.isdone(::CuEvent)
CUDA.wait(::CuEvent)
elapsed
CUDA.@elapsed
Expand Down
3 changes: 0 additions & 3 deletions lib/cudadrv/devices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,6 @@ const DEVICE_INVALID = _CuDevice(CUdevice(-2))

Base.convert(::Type{CUdevice}, dev::CuDevice) = dev.handle

Base.:(==)(a::CuDevice, b::CuDevice) = a.handle == b.handle
Base.hash(dev::CuDevice, h::UInt) = hash(dev.handle, h)

function Base.show(io::IO, ::MIME"text/plain", dev::CuDevice)
print(io, "CuDevice($(dev.handle)): ")
if dev == DEVICE_CPU
Expand Down
4 changes: 2 additions & 2 deletions lib/cudadrv/events.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,12 @@ Waits for an event to complete.
synchronize(e::CuEvent) = cuEventSynchronize(e)

"""
query(e::CuEvent)
isdone(e::CuEvent)

Return `false` if there is outstanding work preceding the most recent
call to `record(e)` and `true` if all captured work has been completed.
"""
function query(e::CuEvent)
function isdone(e::CuEvent)
res = unsafe_cuEventQuery(e)
if res == ERROR_NOT_READY
return false
Expand Down
30 changes: 15 additions & 15 deletions lib/cudadrv/stream.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Stream management

export
CuStream, CuDefaultStream, CuStreamLegacy, CuStreamPerThread,
CuStream, default_stream, legacy_stream, per_thread_stream,
priority, priority_range, synchronize, device_synchronize

"""
Expand Down Expand Up @@ -29,15 +29,15 @@ mutable struct CuStream
return obj
end

global CuDefaultStream() = new(convert(CUstream, C_NULL), nothing)
global default_stream() = new(convert(CUstream, C_NULL), nothing)

global CuStreamLegacy() = new(convert(CUstream, 1), nothing)
global legacy_stream() = new(convert(CUstream, 1), nothing)

global CuStreamPerThread() = new(convert(CUstream, 2), nothing)
global per_thread_stream() = new(convert(CUstream, 2), nothing)
end

"""
CuDefaultStream()
default_stream()

Return the default stream.

Expand All @@ -46,21 +46,21 @@ Return the default stream.
It is generally better to use `stream()` to get a stream object that's local to the
current task. That way, operations scheduled in other tasks can overlap.
"""
CuDefaultStream()
default_stream()

"""
CuStreamLegacy()
legacy_stream()

Return a special object to use an implicit stream with legacy synchronization behavior.

You can use this stream to perform operations that should block on all streams (with the
exception of streams created with `STREAM_NON_BLOCKING`). This matches the old pre-CUDA 7
global stream behavior.
"""
CuStreamLegacy()
legacy_stream()

"""
CuStreamPerThread()
per_thread_stream()

Return a special object to use an implicit stream with per-thread synchronization behavior.
This stream object is normally meant to be used with APIs that do not have per-thread
Expand All @@ -72,7 +72,7 @@ versions of their APIs (i.e. without a `ptsz` or `ptds` suffix).
gets its own non-blocking stream, and multithreading in Julia is typically
accomplished using tasks.
"""
CuStreamPerThread()
per_thread_stream()

Base.unsafe_convert(::Type{CUstream}, s::CuStream) = s.handle

Expand All @@ -92,12 +92,12 @@ function Base.show(io::IO, stream::CuStream)
end

"""
query(s::CuStream)
isdone(s::CuStream)

Return `false` if a stream is busy (has tasks running or queued)
and `true` if that stream is free.
"""
function query(s::CuStream)
function isdone(s::CuStream)
res = unsafe_cuStreamQuery(s)
if res == ERROR_NOT_READY
return false
Expand All @@ -119,7 +119,7 @@ See also: [`device_synchronize`](@ref)
"""
function synchronize(stream::CuStream=stream(); blocking::Bool=true)
# fast path
query(stream) && @goto(exit)
isdone(stream) && @goto(exit)

# minimize latency of short operations by busy-waiting,
# initially without even yielding to other tasks
Expand All @@ -132,7 +132,7 @@ function synchronize(stream::CuStream=stream(); blocking::Bool=true)
else
yield()
end
query(stream) && @goto(exit)
isdone(stream) && @goto(exit)
spins += 1
end

Expand All @@ -155,7 +155,7 @@ Block for the current device's tasks to complete. This is a heavyweight operatio
you only need to call [`synchronize`](@ref) which only synchronizes the stream associated
with the current task.
"""
device_synchronize() = synchronize(CuStreamLegacy())
device_synchronize() = synchronize(legacy_stream())

"""
priority_range()
Expand Down
2 changes: 1 addition & 1 deletion lib/cufft/fft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ function CUDA.unsafe_free!(plan::CuFFTPlan, stream::CuStream=stream())
unsafe_free!(plan.workarea, stream)
end

unsafe_finalize!(plan::CuFFTPlan) = unsafe_free!(plan, CuDefaultStream())
unsafe_finalize!(plan::CuFFTPlan) = unsafe_free!(plan, default_stream())

mutable struct cCuFFTPlan{T<:cufftNumber,K,inplace,N} <: CuFFTPlan{T,K,inplace}
handle::cufftHandle
Expand Down
11 changes: 11 additions & 0 deletions lib/nvml/device.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ end

Base.unsafe_convert(::Type{nvmlDevice_t}, dev::Device) = dev.handle

# Rich-text (`text/plain`) display of an NVML device: writes the device's NVML index
# followed by its name, e.g. a string starting with "NVML.Device(0): ".
function Base.show(io::IO, ::MIME"text/plain", dev::Device)
    # prefix carrying the index reported by `index(dev)`
    print(io, "NVML.Device($(index(dev))): ")
    # the device name is written verbatim after the prefix
    # NOTE(review): assumes `name(dev)` yields a String, so printing it directly matches
    # printing its interpolation — confirm against `name`'s definition
    print(io, name(dev))
end



# iteration
Expand Down Expand Up @@ -75,6 +80,12 @@ function serial(dev::Device)
return unsafe_string(pointer(buf))
end

# NVML index of the device, queried through `nvmlDeviceGetIndex` and widened to `Int`
function index(dev::Device)
    # out-parameter filled in by the NVML call; named `ref` (as in `power_usage`) so the
    # local does not shadow this function's own name
    ref = Ref{Cuint}()
    nvmlDeviceGetIndex(dev, ref)
    return Int(ref[])
end

# watt
function power_usage(dev::Device)
ref = Ref{Cuint}()
Expand Down
4 changes: 2 additions & 2 deletions src/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ earlier to reduce pressure on the memory allocator.

By default, the operation is performed on the task-local stream. During task or process
finalization however, that stream may be destroyed already, so be sure to specify a safe
stream (i.e. `CuDefaultStream()`, which will ensure the operation will block on other
stream (i.e. `default_stream()`, which will ensure the operation will block on other
streams) when calling this function from a finalizer. For simplicity, the `unsafe_finalize!`
function does exactly that.
"""
Expand Down Expand Up @@ -98,7 +98,7 @@ function unsafe_finalize!(xs::CuArray)
# stream, it synchronizes "too much". we could do better, e.g., by keeping track of all
# streams involved, or by refcounting uses and decrementing that refcount after the
# operation using `cuLaunchHostFunc`. See CUDA.jl#778 and CUDA.jl#780 for details.
unsafe_free!(xs, CuDefaultStream())
unsafe_free!(xs, default_stream())
end


Expand Down
6 changes: 6 additions & 0 deletions src/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,9 @@
@deprecate CuCurrentContext() current_context()
@deprecate CuContext(ptr::Union{Ptr,CuPtr}) context(ptr)
@deprecate CuDevice(ptr::Union{Ptr,CuPtr}) device(ptr)

@deprecate CuDefaultStream() default_stream()
@deprecate CuStreamLegacy() legacy_stream()
@deprecate CuStreamPerThread() per_thread_stream()
@deprecate query(s::CuStream) isdone(s)
@deprecate query(e::CuEvent) isdone(e)
6 changes: 3 additions & 3 deletions test/cudadrv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,8 @@ CuEvent(CUDA.EVENT_BLOCKING_SYNC | CUDA.EVENT_DISABLE_TIMING)
end

@testset "event query" begin
event = CuEvent()
@test CUDA.query(event) == true
event = CuEvent()
@test CUDA.isdone(event)
end

end
Expand Down Expand Up @@ -835,7 +835,7 @@ end

s = CuStream()
synchronize(s)
@test CUDA.query(s) == true
@test CUDA.isdone(s)

let s2 = CuStream()
@test s != s2
Expand Down
4 changes: 4 additions & 0 deletions test/nvml.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ end
@testset "devices" begin
let dev = NVML.Device(0)
@test dev == first(NVML.devices())
@test NVML.index(dev) == 0

str = sprint(io->show(io, "text/plain", dev))
@test occursin("NVML.Device(0)", str)
end

cuda_dev = CuDevice(0)
Expand Down