Skip to content

Commit

Permalink
Try #269:
Browse files Browse the repository at this point in the history
  • Loading branch information
bors[bot] authored Oct 8, 2021
2 parents 045fab2 + 4e49edf commit 12fd5ad
Show file tree
Hide file tree
Showing 8 changed files with 46 additions and 10 deletions.
8 changes: 3 additions & 5 deletions examples/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,9 @@ function matmul!(a, b, c)
println("Matrix size mismatch!")
return nothing
end
if isa(a, Array)
kernel! = matmul_kernel!(CPU(),4)
else
kernel! = matmul_kernel!(CUDADevice(),256)
end
device = KernelAbstractions.get_device(a)
n = device isa GPU ? 256 : 4
kernel! = matmul_kernel!(device, n)
kernel!(a, b, c, ndrange=size(c))
end

Expand Down
8 changes: 3 additions & 5 deletions examples/naive_transpose.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,9 @@ function naive_transpose!(a, b)
println("Matrix size mismatch!")
return nothing
end
if isa(a, Array)
kernel! = naive_transpose_kernel!(CPU(),4)
else
kernel! = naive_transpose_kernel!(CUDADevice(),256)
end
device = KernelAbstractions.get_device(a)
n = device isa GPU ? 256 : 4
kernel! = naive_transpose_kernel!(device, n)
kernel!(a, b, ndrange=size(a))
end

Expand Down
4 changes: 4 additions & 0 deletions lib/CUDAKernels/src/CUDAKernels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ import KernelAbstractions

export CUDADevice

# `CuArray` types map to the CUDA device.
function KernelAbstractions.get_device(::Type{<:CUDA.CuArray})
    return CUDADevice()
end

# CUSPARSE array types map to the CUDA device as well.
function KernelAbstractions.get_device(::Type{<:CUDA.CUSPARSE.AbstractCuSparseArray})
    return CUDADevice()
end


# Module-level stream bookkeeping; both vectors start out empty.
# NOTE(review): judging by the names, FREE_STREAMS holds streams available for
# reuse and STREAMS holds every stream created so far — confirm against the
# code that uses them.
const FREE_STREAMS = CUDA.CuStream[]
const STREAMS = CUDA.CuStream[]
# Mutable integer setting, initialised to 16.
# NOTE(review): presumably the stream count that triggers cleanup — confirm.
const STREAM_GC_THRESHOLD = Ref{Int}(16)
Expand Down
5 changes: 5 additions & 0 deletions lib/CUDAKernels/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ using Test
include(joinpath(dirname(pathof(KernelAbstractions)), "..", "test", "testsuite.jl"))
include(joinpath(dirname(pathof(KernelGradients)), "..", "test", "testsuite.jl"))

# CUDA array types must resolve to `CUDADevice()` with an inferable return type.
@testset "get_device" begin
    cuarray_type = CUDA.CuArray{Float32,3}
    cusparse_type = CUDA.CUSPARSE.CuSparseMatrixCSC{Float32}
    @test @inferred(KernelAbstractions.get_device(cuarray_type)) == CUDADevice()
    @test @inferred(KernelAbstractions.get_device(cusparse_type)) == CUDADevice()
end

if parse(Bool, get(ENV, "CI", "false"))
default = "CPU"
else
Expand Down
3 changes: 3 additions & 0 deletions lib/ROCKernels/src/ROCKernels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ import KernelAbstractions

export ROCDevice

# `ROCArray` types map to the ROC device.
function KernelAbstractions.get_device(::Type{<:AMDGPU.ROCArray})
    return ROCDevice()
end


# Module-level queue bookkeeping; both vectors start out empty.
# NOTE(review): judging by the names, FREE_QUEUES holds queues available for
# reuse and QUEUES holds every queue created so far — confirm against the
# code that uses them.
const FREE_QUEUES = HSAQueue[]
const QUEUES = HSAQueue[]
# Mutable integer setting, initialised to 16.
# NOTE(review): presumably the queue count that triggers cleanup — confirm.
const QUEUE_GC_THRESHOLD = Ref{Int}(16)
Expand Down
4 changes: 4 additions & 0 deletions lib/ROCKernels/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ using Test
include(joinpath(dirname(pathof(KernelAbstractions)), "..", "test", "testsuite.jl"))
include(joinpath(dirname(pathof(KernelGradients)), "..", "test", "testsuite.jl"))

# `ROCArray` types must resolve to `ROCDevice()` with an inferable return type.
# This has to be a `@testset`: `@test` does not accept a name string followed by
# a `begin ... end` block, so the previous form failed before running anything.
@testset "get_device" begin
    @test @inferred(KernelAbstractions.get_device(AMDGPU.ROCArray{Float32,3})) == ROCDevice()
end

if parse(Bool, get(ENV, "CI", "false"))
default = "CPU"
else
Expand Down
18 changes: 18 additions & 0 deletions src/KernelAbstractions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,24 @@ abstract type GPU <: Device end

# Field-less device type for the CPU; a subtype of `Device`.
struct CPU <: Device end


"""
    KernelAbstractions.get_device(A::AbstractArray)::KernelAbstractions.Device
    KernelAbstractions.get_device(TA::Type{<:AbstractArray})::KernelAbstractions.Device

Get a `KernelAbstractions.Device` instance suitable for array `A` or for array
type `TA`.

Calling `get_device` on an array instance forwards to the method for
`typeof(A)`. Array types without a more specific method fall back to `CPU()`.
"""
function get_device end

# Array instances defer to the method defined for their type.
function get_device(A::AbstractArray)
    return get_device(typeof(A))
end

# Fallback for array types without a more specific method: the CPU device.
function get_device(::Type{<:AbstractArray})
    return CPU()
end

# Would require dependency on GPUArrays:
# get_device(TA::Type{<:GPUArrays.AbstractGPUArray}) = throw(ArgumentError("No KernelAbstractions.Device type defined for arrays of type $(TA.name.name)"))


include("nditeration.jl")
using .NDIteration
import .NDIteration: get
Expand Down
6 changes: 6 additions & 0 deletions test/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@ end
A[I] = i
end

# A plain host array must resolve to `CPU()`, and the instance method must agree
# with the type method; both calls must have an inferable return type.
@testset "get_device" begin
    host_array = rand(5)
    host_type = typeof(host_array)
    @test @inferred(KernelAbstractions.get_device(host_type)) == CPU()
    @test @inferred(KernelAbstractions.get_device(host_array)) == KernelAbstractions.get_device(host_type)
end

@testset "indextest" begin
# TODO: add test for _group and _local_cartesian
A = ArrayT{Int}(undef, 16, 16)
Expand Down

0 comments on commit 12fd5ad

Please sign in to comment.