diff --git a/src/GPUifyLoops.jl b/src/GPUifyLoops.jl index b8c67bc..a73aa4f 100644 --- a/src/GPUifyLoops.jl +++ b/src/GPUifyLoops.jl @@ -65,6 +65,7 @@ Launch a kernel on the GPU. `kwargs` are passed to `@cuda` normally passed to `@cuda`. """ launch(::CPU, f, args...; kwargs...) = f(args...) +signature(::CPU, f, args...) = f, map(Core.Typeof, args) """ launch_config(::F, maxthreads, args...; kwargs...) @@ -134,6 +135,13 @@ end end return nothing end + + function signature(::CUDA, f::F, args...) where F + args = (ctx, f, args...) + kernel_args = map(cudaconvert, args) + kernel_tt = Tuple{map(Core.Typeof, kernel_args)...} + return Cassette.overdub, kernel_tt + end end isdevice(::CPU) = false