Skip to content
This repository has been archived by the owner on May 17, 2020. It is now read-only.

Commit

Permalink
Merge pull request #97 from vchuravy/lcw/signature
Browse files Browse the repository at this point in the history
Add signature function
  • Loading branch information
vchuravy authored Oct 2, 2019
2 parents 19ebf17 + f9f750d commit 5ca1301
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions src/GPUifyLoops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ Launch a kernel on the GPU. `kwargs` are passed to `@cuda`
normally passed to `@cuda`.
"""
launch(::CPU, f, args...; kwargs...) = f(args...)
signature(::CPU, f, args...) = f, map(Core.Typeof, args)

"""
launch_config(::F, maxthreads, args...; kwargs...)
Expand Down Expand Up @@ -134,6 +135,13 @@ end
end
return nothing
end

function signature(::CUDA, f::F, args...) where F
args = (ctx, f, args...)
kernel_args = map(cudaconvert, args)
kernel_tt = Tuple{map(Core.Typeof, kernel_args)...}
return Cassette.overdub, kernel_tt
end
end

isdevice(::CPU) = false
Expand Down

0 comments on commit 5ca1301

Please sign in to comment.