Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Enzyme Forward mode custom rule #1869

Merged
merged 3 commits into from
May 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 73 additions & 1 deletion ext/EnzymeCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,84 @@ module EnzymeCoreExt
using CUDA
import CUDA: GPUCompiler, CUDABackend

isdefined(Base, :get_extension) ? (import EnzymeCore) : (import ..EnzymeCore)
if isdefined(Base, :get_extension)
using EnzymeCore
using EnzymeCore.EnzymeRules
else
using ..EnzymeCore
using ..EnzymeCore.EnzymeRules
end

# Build the GPUCompiler job Enzyme asks for when the backend is CUDA: look up the
# method instance for the function type `F` with argument tuple type `TT`, and pair
# it with the compiler configuration of the currently active CUDA device.
function EnzymeCore.compiler_job_from_backend(::CUDABackend, @nospecialize(F::Type),
                                              @nospecialize(TT::Type))
    instance = GPUCompiler.methodinstance(F, TT)
    config = CUDA.compiler_config(CUDA.device())
    return GPUCompiler.CompilerJob(instance, config)
end

# Kernel entry point handed to `cufunction` by the `HostKernel` forward rule below:
# runs deferred forward-mode autodiff of `fn` over `args` with a `Const` return
# annotation and discards whatever the call produces.
function metaf(fn, args::Vararg{Any, N}) where N
    EnzymeCore.autodiff_deferred(Forward, fn, Const, args...)
    return nothing
end

# Forward rule for `cufunction` when a `Duplicated` result is requested.
# The function and the argument tuple type are both `Const`, so the call is made
# once and its result is used for the primal and the shadow alike.
function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(cufunction)},
                                        ::Type{<:Duplicated}, f::Const{F},
                                        tt::Const{TT}; kwargs...) where {F,TT}
    compile = ofn.val
    compiled = compile(f.val, tt.val; kwargs...)
    return Duplicated(compiled, compiled)
end

# Forward rule for `cufunction` when a `BatchDuplicated` result of width `N` is
# requested. The call is made once; the same result fills the primal slot and each
# of the `N` shadow slots.
function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(cufunction)},
                                        ::Type{BatchDuplicated{T,N}}, f::Const{F},
                                        tt::Const{TT}; kwargs...) where {F,TT,T,N}
    compile = ofn.val
    compiled = compile(f.val, tt.val; kwargs...)
    shadows = ntuple(_ -> compiled, Val(N))
    return BatchDuplicated(compiled, shadows)
end

# Forward rule for `cudaconvert`.
#
# The conversion is applied independently to the primal (`x.val`) and to every
# shadow (`x.dval`, or each `x.dval[i]` in batch mode). What is returned depends on
# the requested return annotation `RT`:
#   * `Duplicated`            -> `Duplicated(converted primal, converted shadow)`
#   * `Const`                 -> converted primal
#   * `DuplicatedNoNeed`      -> converted shadow only
#   * `BatchDuplicated`       -> `BatchDuplicated(converted primal, converted shadows)`
#   * otherwise (batch, no primal needed) -> tuple of converted shadows
function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(cudaconvert)},
                                        ::Type{RT}, x::IT) where {RT, IT}
    if RT <: Duplicated
        return Duplicated(ofn.val(x.val), ofn.val(x.dval))
    elseif RT <: Const
        return ofn.val(x.val)
    elseif RT <: DuplicatedNoNeed
        # A forward rule asked for `DuplicatedNoNeed` must hand back the shadow,
        # not the primal.
        return ofn.val(x.dval)
    else
        # Batch mode: convert each shadow separately. `batch_size` recovers the
        # batch width `N` from the `BatchDuplicated{T,N}` /
        # `BatchDuplicatedNoNeed{T,N}` return annotation type.
        tup = ntuple(Val(EnzymeCore.batch_size(RT))) do i
            Base.@_inline_meta
            ofn.val(x.dval[i])
        end
        if RT <: BatchDuplicated
            # Primal and shadow tuple are two separate constructor arguments.
            return BatchDuplicated(ofn.val(x.val), tup)
        else
            return tup
        end
    end
end

# NOTE: the `BatchDuplicated` forward rule for `cufunction` is defined once, earlier
# in this module. A second definition with the identical signature and body stood
# here; it only overwrote the earlier method and has been removed.

# Forward rule for calling a `CUDA.HostKernel` when the return is `Const{Nothing}`.
#
# Instead of launching the original kernel, this compiles `metaf` for the signature
# `(F, converted argument types...)` and launches that, passing the original kernel
# function `ofn.val.f` as its first argument. Launch keyword arguments are forwarded
# unchanged.
function EnzymeCore.EnzymeRules.forward(ofn::EnzymeCore.Annotation{CUDA.HostKernel{F,TT}},
                                        ::Type{Const{Nothing}}, args...;
                                        kwargs...) where {F,TT}
    # Keep the host-side arguments rooted while their converted counterparts are
    # in use.
    GC.@preserve args begin
        converted = map(cudaconvert, args)
        signature = Tuple{F, map(typeof, converted)...}
        launcher = cufunction(metaf, signature)
        launcher(ofn.val.f, converted...; kwargs...)
    end

    return nothing
end

end # module

1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
wsmoses marked this conversation as resolved.
Show resolved Hide resolved
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55"
Expand Down
31 changes: 31 additions & 0 deletions test/libraries/enzyme.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,37 @@
using EnzymeCore
using GPUCompiler
using Enzyme

@testset "compiler_job_from_backend" begin
    # A zero-argument closure with an empty argument tuple is enough to check that
    # the CUDA backend hands back a GPUCompiler job.
    noop = () -> nothing
    @test EnzymeCore.compiler_job_from_backend(
        CUDABackend(), typeof(noop), Tuple{}
    ) isa GPUCompiler.CompilerJob
end

# Kernel: each thread squares, in place, the element of `x` at its own thread index.
function square_kernel!(x)
    tid = threadIdx().x
    x[tid] = x[tid] * x[tid]
    sync_threads()
    return nothing
end

# basic squaring on GPU
# Host wrapper: square `x` in place on the GPU using a single block with one thread
# per element.
function square!(x)
    nthreads = length(x)
    @cuda blocks = 1 threads = nthreads square_kernel!(x)
    return nothing
end

# Single tangent: with A = 1:64 and a seed of 1, the tangent of x^2 is 2x.
A = CUDA.rand(64)
dA = CUDA.ones(64)
A .= (1:1:64)
fill!(dA, 1)
Enzyme.autodiff(Forward, square!, Duplicated(A, dA))
@test all(dA .≈ (2:2:128))

# Batch of two tangents seeded with 1 and 3: expect 2x and 3 * 2x respectively.
A = CUDA.rand(32)
dA = CUDA.ones(32)
dA2 = CUDA.ones(32)
A .= (1:1:32)
fill!(dA, 1)
fill!(dA2, 3)
Enzyme.autodiff(Forward, square!, BatchDuplicated(A, (dA, dA2)))
@test all(dA .≈ (2:2:64))
@test all(dA2 .≈ 3*(2:2:64))