Support more kwarg syntax with kernel launches (#380)
Also clean up the KA.jl back-end.
maleadt authored Dec 6, 2023
1 parent 923a372 commit 3b97437
Showing 5 changed files with 100 additions and 94 deletions.
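
For context, a quick sketch of the syntax this commit enables: a bare variable name passed as a launch keyword now expands to name=name, so the two launches below are equivalent. The no-op dummy kernel mirrors the one used in test/execution.jl; the sizes are purely illustrative.

using oneAPI

function dummy()
    return
end

items = 64
groups = 16

# explicit keyword syntax, as before
@oneapi items=items groups=groups dummy()

# new shorthand: a bare name expands to `name=name`
@oneapi items groups dummy()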
4 changes: 2 additions & 2 deletions Project.toml
@@ -19,7 +19,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SPIRV_LLVM_Translator_unified_jll = "85f0d8ed-5b39-5caa-b1ae-7472de402361"
SPIRV_Tools_jll = "6ac6d60f-d740-5983-97d7-a4482c0689f4"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
UnsafeAtomicsLLVM = "d80eeb9a-aca5-4d75-85e5-170c8b632249"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
oneAPI_Level_Zero_Headers_jll = "f4bc562b-d309-54f8-9efb-476e56f0410d"
oneAPI_Level_Zero_Loader_jll = "13eca655-d68d-5b81-8367-6d99d727ab01"
oneAPI_Support_jll = "b049733a-a71d-5ed3-8eba-7d323ac00b36"
@@ -36,7 +36,7 @@ NEO_jll = "=23.17.26241"
Preferences = "1"
SPIRV_LLVM_Translator_unified_jll = "0.3"
SpecialFunctions = "1.3, 2"
UnsafeAtomicsLLVM = "0.1"
StaticArrays = "1"
julia = "1.8"
oneAPI_Level_Zero_Loader_jll = "1.9"
oneAPI_Support_jll = "~0.2.2"
10 changes: 9 additions & 1 deletion src/compiler/execution.jl
@@ -9,7 +9,15 @@ const LAUNCH_KWARGS = [:groups, :items, :queue]

macro oneapi(ex...)
call = ex[end]
kwargs = ex[1:end-1]
kwargs = map(ex[1:end-1]) do kwarg
if kwarg isa Symbol
:($kwarg = $kwarg)
elseif Meta.isexpr(kwarg, :(=))
kwarg
else
throw(ArgumentError("Invalid keyword argument '$kwarg'"))
end
end

# destructure the kernel call
Meta.isexpr(call, :call) || throw(ArgumentError("second argument to @oneapi should be a function call"))
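To illustrate the normalization performed by the macro change above, here is a standalone sketch; normalize_kwarg is a hypothetical helper name for illustration only and is not part of this commit.

# hypothetical helper mirroring the macro's kwarg handling above
function normalize_kwarg(kwarg)
    if kwarg isa Symbol
        # a bare name like `items` becomes `items = items`
        return :($kwarg = $kwarg)
    elseif Meta.isexpr(kwarg, :(=))
        # an explicit `items = 1` passes through unchanged
        return kwarg
    else
        throw(ArgumentError("Invalid keyword argument '$kwarg'"))
    end
end

normalize_kwarg(:items)          # returns :(items = items)
normalize_kwarg(:(groups = 16))  # returns :(groups = 16)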
4 changes: 2 additions & 2 deletions src/mapreduce.jl
@@ -163,7 +163,7 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::oneWrappedArray{T},
# perform the actual reduction
if reduce_groups == 1
# we can cover the dimensions to reduce using a single group
@oneapi items=items groups=groups partial_mapreduce_device(
@oneapi items groups partial_mapreduce_device(
f, op, init, Val(items), Rreduce, Rother, R′, A)
else
# we need multiple steps to cover all values to reduce
@@ -172,7 +172,7 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::oneWrappedArray{T},
# without an explicit initializer we need to copy from the output container
partial .= R
end
@oneapi items=items groups=groups partial_mapreduce_device(
@oneapi items groups partial_mapreduce_device(
f, op, init, Val(items), Rreduce, Rother, partial, A)

GPUArrays.mapreducedim!(identity, op, R′, partial; init=init)
172 changes: 83 additions & 89 deletions src/oneAPIKernels.jl
@@ -1,49 +1,55 @@
module oneAPIKernels

import KernelAbstractions
import oneAPI
import oneAPI: oneL0, @device_override
import GPUCompiler
using ..oneAPI
using ..oneAPI: @device_override

import UnsafeAtomicsLLVM
import KernelAbstractions as KA

import StaticArrays

import Adapt


## Back-end Definition

struct oneAPIBackend <: KernelAbstractions.GPU
end
export oneAPIBackend

KernelAbstractions.allocate(::oneAPIBackend, ::Type{T}, dims::Tuple) where T = oneAPI.oneArray{T}(undef, dims)
KernelAbstractions.zeros(::oneAPIBackend, ::Type{T}, dims::Tuple) where T = oneAPI.zeros(T, dims)
KernelAbstractions.ones(::oneAPIBackend, ::Type{T}, dims::Tuple) where T = oneAPI.ones(T, dims)
struct oneAPIBackend <: KA.GPU
end

# Import through parent
import KernelAbstractions: StaticArrays, Adapt
import .StaticArrays: MArray
KA.allocate(::oneAPIBackend, ::Type{T}, dims::Tuple) where T = oneArray{T}(undef, dims)
KA.zeros(::oneAPIBackend, ::Type{T}, dims::Tuple) where T = oneAPI.zeros(T, dims)
KA.ones(::oneAPIBackend, ::Type{T}, dims::Tuple) where T = oneAPI.ones(T, dims)

KernelAbstractions.get_backend(::oneAPI.oneArray) = oneAPIBackend()
KA.get_backend(::oneArray) = oneAPIBackend()
# TODO should be non-blocking
KernelAbstractions.synchronize(::oneAPIBackend) = oneL0.synchronize()
KernelAbstractions.supports_float64(::oneAPIBackend) = false # TODO is this device dependent?
KA.synchronize(::oneAPIBackend) = oneL0.synchronize()
KA.supports_float64(::oneAPIBackend) = false # TODO: Check if this is device dependent

Adapt.adapt_storage(::oneAPIBackend, a::Array) = Adapt.adapt(oneAPI.oneArray, a)
Adapt.adapt_storage(::oneAPIBackend, a::oneAPI.oneArray) = a
Adapt.adapt_storage(::KernelAbstractions.CPU, a::oneAPI.oneArray) = convert(Array, a)
Adapt.adapt_storage(::oneAPIBackend, a::Array) = Adapt.adapt(oneArray, a)
Adapt.adapt_storage(::oneAPIBackend, a::oneArray) = a
Adapt.adapt_storage(::KA.CPU, a::oneArray) = convert(Array, a)

##
# copyto!
##

## Memory Operations

function KernelAbstractions.copyto!(::oneAPIBackend, A, B)
function KA.copyto!(::oneAPIBackend, A, B)
copyto!(A, B)
# TODO device to host copies in oneAPI.jl are synchronizing.
# TODO: Address device to host copies in oneAPI.jl being synchronizing
end

import KernelAbstractions: Kernel, StaticSize, DynamicSize, partition, blocks, workitems, launch_config

###
# Kernel launch
###
function launch_config(kernel::Kernel{oneAPIBackend}, ndrange, workgroupsize)
## Kernel Launch

function KA.mkcontext(kernel::KA.Kernel{oneAPIBackend}, _ndrange, iterspace)
KA.CompilerMetadata{KA.ndrange(kernel), KA.DynamicCheck}(_ndrange, iterspace)
end
function KA.mkcontext(kernel::KA.Kernel{oneAPIBackend}, I, _ndrange, iterspace,
::Dynamic) where Dynamic
KA.CompilerMetadata{KA.ndrange(kernel), Dynamic}(I, _ndrange, iterspace)
end

function KA.launch_config(kernel::KA.Kernel{oneAPIBackend}, ndrange, workgroupsize)
if ndrange isa Integer
ndrange = (ndrange,)
end
@@ -52,16 +58,16 @@ function launch_config(kernel::Kernel{oneAPIBackend}, ndrange, workgroupsize)
end

# partition checked that the ndrange's agreed
if KernelAbstractions.ndrange(kernel) <: StaticSize
if KA.ndrange(kernel) <: KA.StaticSize
ndrange = nothing
end

iterspace, dynamic = if KernelAbstractions.workgroupsize(kernel) <: DynamicSize &&
iterspace, dynamic = if KA.workgroupsize(kernel) <: KA.DynamicSize &&
workgroupsize === nothing
# use ndrange as preliminary workgroupsize for autotuning
partition(kernel, ndrange, ndrange)
KA.partition(kernel, ndrange, ndrange)
else
partition(kernel, ndrange, workgroupsize)
KA.partition(kernel, ndrange, workgroupsize)
end

return ndrange, workgroupsize, iterspace, dynamic
@@ -76,108 +82,96 @@ function threads_to_workgroupsize(threads, ndrange)
end
end

function (obj::Kernel{oneAPIBackend})(args...; ndrange=nothing, workgroupsize=nothing)
ndrange, workgroupsize, iterspace, dynamic = launch_config(obj, ndrange, workgroupsize)
function (obj::KA.Kernel{oneAPIBackend})(args...; ndrange=nothing, workgroupsize=nothing)
ndrange, workgroupsize, iterspace, dynamic = KA.launch_config(obj, ndrange, workgroupsize)
# this might not be the final context, since we may tune the workgroupsize
ctx = mkcontext(obj, ndrange, iterspace)
kernel = oneAPI.@oneapi launch=false obj.f(ctx, args...)
ctx = KA.mkcontext(obj, ndrange, iterspace)
kernel = @oneapi launch=false obj.f(ctx, args...)

# figure out the optimal workgroupsize automatically
if KernelAbstractions.workgroupsize(obj) <: DynamicSize && workgroupsize === nothing
if KA.workgroupsize(obj) <: KA.DynamicSize && workgroupsize === nothing
items = oneAPI.suggest_groupsize(kernel.fun, prod(ndrange)).x
# XXX: the z dimension of the suggested group size is often non-zero. use this?
workgroupsize = threads_to_workgroupsize(items, ndrange)
iterspace, dynamic = partition(obj, ndrange, workgroupsize)
ctx = mkcontext(obj, ndrange, iterspace)
iterspace, dynamic = KA.partition(obj, ndrange, workgroupsize)
ctx = KA.mkcontext(obj, ndrange, iterspace)
end

nblocks = length(blocks(iterspace))
threads = length(workitems(iterspace))
groups = length(KA.blocks(iterspace))
items = length(KA.workitems(iterspace))

if nblocks == 0
if groups == 0
return nothing
end

# Launch kernel
kernel(ctx, args...; items=threads, groups=nblocks)
kernel(ctx, args...; items, groups)

return nothing
end

import KernelAbstractions: CompilerMetadata, DynamicCheck, LinearIndices
import KernelAbstractions: __index_Local_Linear, __index_Group_Linear, __index_Global_Linear, __index_Local_Cartesian, __index_Group_Cartesian, __index_Global_Cartesian, __validindex, __print
import KernelAbstractions: mkcontext, expand, __iterspace, __ndrange, __dynamic_checkbounds

function mkcontext(kernel::Kernel{oneAPIBackend}, _ndrange, iterspace)
metadata = CompilerMetadata{KernelAbstractions.ndrange(kernel), DynamicCheck}(_ndrange, iterspace)
end
function mkcontext(kernel::Kernel{oneAPIBackend}, I, _ndrange, iterspace, ::Dynamic) where Dynamic
metadata = CompilerMetadata{KernelAbstractions.ndrange(kernel), Dynamic}(I, _ndrange, iterspace)
end
## Indexing Functions

@device_override @inline function __index_Local_Linear(ctx)
return oneAPI.get_local_id(0)
@device_override @inline function KA.__index_Local_Linear(ctx)
return get_local_id(0)
end

@device_override @inline function __index_Group_Linear(ctx)
return oneAPI.get_group_id(0)
@device_override @inline function KA.__index_Group_Linear(ctx)
return get_group_id(0)
end

@device_override @inline function __index_Global_Linear(ctx)
I = @inbounds expand(__iterspace(ctx), oneAPI.get_group_id(0), oneAPI.get_local_id(0))
# TODO: This is unfortunate, can we get the linear index cheaper
@inbounds LinearIndices(__ndrange(ctx))[I]
@device_override @inline function KA.__index_Global_Linear(ctx)
return get_global_id(0)
end

@device_override @inline function __index_Local_Cartesian(ctx)
@inbounds workitems(__iterspace(ctx))[oneAPI.get_local_id(0)]
@device_override @inline function KA.__index_Local_Cartesian(ctx)
@inbounds KA.workitems(KA.__iterspace(ctx))[get_local_id(0)]
end

@device_override @inline function __index_Group_Cartesian(ctx)
@inbounds blocks(__iterspace(ctx))[oneAPI.get_group_id(0)]
@device_override @inline function KA.__index_Group_Cartesian(ctx)
@inbounds KA.blocks(KA.__iterspace(ctx))[get_group_id(0)]
end

@device_override @inline function __index_Global_Cartesian(ctx)
return @inbounds expand(__iterspace(ctx), oneAPI.get_group_id(0), oneAPI.get_local_id(0))
@device_override @inline function KA.__index_Global_Cartesian(ctx)
return @inbounds KA.expand(KA.__iterspace(ctx), get_group_id(0), get_local_id(0))
end

@device_override @inline function __validindex(ctx)
if __dynamic_checkbounds(ctx)
I = @inbounds expand(__iterspace(ctx), oneAPI.get_group_id(0), oneAPI.get_local_id(0))
return I in __ndrange(ctx)
@device_override @inline function KA.__validindex(ctx)
if KA.__dynamic_checkbounds(ctx)
I = @inbounds KA.expand(KA.__iterspace(ctx), get_group_id(0), get_local_id(0))
return I in KA.__ndrange(ctx)
else
return true
end
end

import KernelAbstractions: groupsize, __groupsize, __workitems_iterspace
import KernelAbstractions: SharedMemory, Scratchpad, __synchronize, __size

###
# GPU implementation of shared memory
###
@device_override @inline function SharedMemory(::Type{T}, ::Val{Dims}, ::Val{Id}) where {T, Dims, Id}
## Shared and Scratch Memory

@device_override @inline function KA.SharedMemory(::Type{T}, ::Val{Dims}, ::Val{Id}) where {T, Dims, Id}
ptr = oneAPI.emit_localmemory(T, Val(prod(Dims)))
oneAPI.oneDeviceArray(Dims, ptr)
oneDeviceArray(Dims, ptr)
end

###
# GPU implementation of scratch memory
# - private memory for each workitem
###

@device_override @inline function Scratchpad(ctx, ::Type{T}, ::Val{Dims}) where {T, Dims}
StaticArrays.MArray{__size(Dims), T}(undef)
@device_override @inline function KA.Scratchpad(ctx, ::Type{T}, ::Val{Dims}) where {T, Dims}
StaticArrays.MArray{KA.__size(Dims), T}(undef)
end

@device_override @inline function __synchronize()
oneAPI.barrier()

## Synchronization and Printing

@device_override @inline function KA.__synchronize()
barrier()
end

@device_override @inline function __print(args...)
@device_override @inline function KA.__print(args...)
oneAPI._print(args...)
end

KernelAbstractions.argconvert(::Kernel{oneAPIBackend}, arg) = oneAPI.kernel_convert(arg)

## Other

KA.argconvert(::KA.Kernel{oneAPIBackend}, arg) = kernel_convert(arg)

end
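
For reference, a minimal end-to-end sketch of driving the cleaned-up back-end through KernelAbstractions; the vadd! kernel and array sizes here are illustrative and not part of this commit.

using oneAPI, KernelAbstractions

# illustrative element-wise kernel
@kernel function vadd!(c, @Const(a), @Const(b))
    i = @index(Global, Linear)
    @inbounds c[i] = a[i] + b[i]
end

a = oneArray(rand(Float32, 1024))
b = oneArray(rand(Float32, 1024))
c = similar(a)

backend = KernelAbstractions.get_backend(a)   # oneAPIBackend()
vadd!(backend)(c, a, b; ndrange = length(c))  # workgroupsize is chosen automatically
KernelAbstractions.synchronize(backend)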
4 changes: 4 additions & 0 deletions test/execution.jl
@@ -20,10 +20,14 @@ end
@testset "launch configuration" begin
@oneapi dummy()

items = 1
@oneapi items dummy()
@oneapi items=1 dummy()
@oneapi items=(1,1) dummy()
@oneapi items=(1,1,1) dummy()

groups = 1
@oneapi groups dummy()
@oneapi groups=1 dummy()
@oneapi groups=(1,1) dummy()
@oneapi groups=(1,1,1) dummy()
