Skip to content

Commit

Permalink
Try #269:
Browse files Browse the repository at this point in the history
  • Loading branch information
bors[bot] authored Oct 7, 2021
2 parents 045fab2 + c555ebd commit 264a7ba
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 0 deletions.
4 changes: 4 additions & 0 deletions lib/CUDAKernels/src/CUDAKernels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ import KernelAbstractions

export CUDADevice

KernelAbstractions.get_device(::Type{<:CUDA.CuArray}) = CUDADevice()
KernelAbstractions.get_device(::Type{<:CUDA.CUSPARSE.AbstractCuSparseArray}) = CUDADevice()


const FREE_STREAMS = CUDA.CuStream[]
const STREAMS = CUDA.CuStream[]
const STREAM_GC_THRESHOLD = Ref{Int}(16)
Expand Down
5 changes: 5 additions & 0 deletions lib/CUDAKernels/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ using Test
include(joinpath(dirname(pathof(KernelAbstractions)), "..", "test", "testsuite.jl"))
include(joinpath(dirname(pathof(KernelGradients)), "..", "test", "testsuite.jl"))

@testset "get_device" begin
@test @inferred(KernelAbstractions.get_device(CUDA.CuArray{Float32,3})) == CUDADevice()
@test @inferred(KernelAbstractions.get_device(CUDA.CUSPARSE.CuSparseMatrixCSC{Float32})) == CUDADevice()
end

if parse(Bool, get(ENV, "CI", "false"))
default = "CPU"
else
Expand Down
3 changes: 3 additions & 0 deletions lib/ROCKernels/src/ROCKernels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ import KernelAbstractions

export ROCDevice

KernelAbstractions.get_device(::Type{<:AMDGPU.ROCArray}) = ROCDevice()


const FREE_QUEUES = HSAQueue[]
const QUEUES = HSAQueue[]
const QUEUE_GC_THRESHOLD = Ref{Int}(16)
Expand Down
4 changes: 4 additions & 0 deletions lib/ROCKernels/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ using Test
include(joinpath(dirname(pathof(KernelAbstractions)), "..", "test", "testsuite.jl"))
include(joinpath(dirname(pathof(KernelGradients)), "..", "test", "testsuite.jl"))

@test "get_device" begin
@test @inferred(KernelAbstractions.get_device(AMDGPU.ROCArray{Float32,3})) == ROCDevice()
end

if parse(Bool, get(ENV, "CI", "false"))
default = "CPU"
else
Expand Down
18 changes: 18 additions & 0 deletions src/KernelAbstractions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,24 @@ abstract type GPU <: Device end

struct CPU <: Device end


"""
KernelAbstractions.get_device(A::AbstractArray)::KernelAbstractions.Device
KernelAbstractions.get_device(TA::Type{<:AbstractArray})::KernelAbstractions.Device
Get a `KernelAbstractions.Device` instance suitable for array `A` resp. array
type `TA`.
"""
function get_device end

get_device(A::AbstractArray) = get_device(typeof(A))

get_device(::Type{<:AbstractArray}) = CPU()

# Would require dependency on GPUArrays:
# get_device(TA::Type{<:GPUArrays.AbstractGPUArray}) = throw(ArgumentError("NoKernelAbstractions.Device type defined for arrays of type $(TA.name.name)"))


include("nditeration.jl")
using .NDIteration
import .NDIteration: get
Expand Down
6 changes: 6 additions & 0 deletions test/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@ end
A[I] = i
end

@testset "get_device" begin
A = rand(5)
@test @inferred(KernelAbstractions.get_device(typeof(A))) == CPU()
@test @inferred(KernelAbstractions.get_device(A)) == KernelAbstractions.get_device(typeof(A))
end

@testset "indextest" begin
# TODO: add test for _group and _local_cartesian
A = ArrayT{Int}(undef, 16, 16)
Expand Down

0 comments on commit 264a7ba

Please sign in to comment.