Skip to content

Commit

Permalink
Add Enzyme Forward mode custom rule.
Browse files Browse the repository at this point in the history
Co-authored-by: Seth Axen <seth@sethaxen.com>
Co-authored-by: "William S. Moses" <gh@wsmoses.com>
  • Loading branch information
3 people committed Mar 23, 2024
1 parent f5100a1 commit b2835c3
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 0 deletions.
5 changes: 5 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ CUDA_Runtime_Discovery = "1af6417a-86b4-443c-805f-a4643ffb695f"
CUDA_Runtime_jll = "76a88914-d11a-5bdc-97e0-2f5a05c973a2"
Crayons = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55"
Expand All @@ -36,11 +37,13 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[weakdeps]
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[extensions]
ChainRulesCoreExt = "ChainRulesCore"
CUDAEnzymeCoreExt = "EnzymeCore"
SpecialFunctionsExt = "SpecialFunctions"

[compat]
Expand All @@ -54,6 +57,7 @@ CUDA_Runtime_jll = "0.12"
ChainRulesCore = "1"
Crayons = "4"
DataFrames = "1"
EnzymeCore = "0.5, 0.6"
ExprTools = "0.1"
GPUArrays = "10.0.1"
GPUCompiler = "0.24, 0.25, 0.26"
Expand All @@ -80,5 +84,6 @@ Statistics = "1"
julia = "1.8"

[extras]
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
49 changes: 49 additions & 0 deletions ext/CUDAEnzymeCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
module CUDAEnzymeCoreExt

using CUDA

if isdefined(Base, :get_extension)
using EnzymeCore
using EnzymeCore.EnzymeRules
else
using ..EnzymeCore
using ..EnzymeCore.EnzymeRules
end

# Wrapper that is compiled as the kernel in place of the user's function:
# it runs Enzyme's deferred forward-mode autodiff on `fn`, with a `Const`
# return activity, forwarding `args` untouched. Always returns `nothing`.
function metaf(fn, args::Vararg{Any,N}) where {N}
    EnzymeCore.autodiff_deferred(Forward, fn, Const, args...)
    return nothing
end

# Forward-mode rule for `cufunction` when a `Duplicated` result is requested.
# The call is performed once with the unwrapped function and tuple type, and
# the same result object is used for both the primal and the shadow slot.
function EnzymeCore.EnzymeRules.forward(
    ofn::Const{typeof(cufunction)},
    ::Type{<:Duplicated},
    f::Const{F},
    tt::Const{TT};
    kwargs...
) where {F,TT}
    compiled = ofn.val(f.val, tt.val; kwargs...)
    return Duplicated(compiled, compiled)
end

# Forward-mode rule for `cufunction` when a width-`N` `BatchDuplicated`
# result is requested. The call is performed once; that one result object
# fills the primal slot and every one of the `N` shadow slots.
function EnzymeCore.EnzymeRules.forward(
    ofn::Const{typeof(cufunction)},
    ::Type{BatchDuplicated{T,N}},
    f::Const{F},
    tt::Const{TT};
    kwargs...
) where {F,TT,T,N}
    compiled = ofn.val(f.val, tt.val; kwargs...)
    shadows = ntuple(_ -> compiled, Val(N))
    return BatchDuplicated(compiled, shadows)
end

# Forward-mode rule for launching a `CUDA.HostKernel` whose result is
# `Const{Nothing}`.
#
# Instead of launching the original kernel, this compiles `metaf` for the
# kernel's function type `F` plus the types of the converted arguments, and
# launches that wrapper with the original kernel function as its first
# argument. `kwargs` are forwarded to the launch unchanged. Returns `nothing`.
function EnzymeCore.EnzymeRules.forward(
    ofn::EnzymeCore.Annotation{CUDA.HostKernel{F,TT}},
    ::Type{Const{Nothing}},
    args...;
    kwargs...
) where {F,TT}
    # Keep the caller's argument objects rooted for the whole launch. The
    # converted values get their own name so `args` is never rebound and it
    # stays obvious which objects the preserve refers to.
    GC.@preserve args begin
        device_args = map(cudaconvert, args)
        # Signature of the wrapper: the kernel function type followed by the
        # types of the converted arguments.
        wrapper_tt = Tuple{F,map(typeof, device_args)...}
        wrapper = cufunction(metaf, wrapper_tt)
        wrapper(ofn.val.f, device_args...; kwargs...)
    end

    return nothing
end

end
3 changes: 3 additions & 0 deletions src/initialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,9 @@ function __init__()
@require SpecialFunctions="276daf66-3868-5448-9aa4-cd146d93841b" begin
include("../ext/SpecialFunctionsExt.jl")
end
@require EnzymeCore="f151be2c-9106-41f4-ab19-57ee4f262869" begin
include("../ext/CUDAEnzymeCoreExt.jl")
end
end

# ensure that operations executed by the REPL back-end finish before returning,
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Expand Down
31 changes: 31 additions & 0 deletions test/libraries/enzyme.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
using Enzyme

# Kernel: each thread squares, in place, the element of `x` selected by its
# x thread index, then all threads of the block synchronize.
function square_kernel!(x)
    tid = threadIdx().x
    x[tid] = x[tid] * x[tid]
    sync_threads()
    return nothing
end

# Basic squaring on the GPU: launch `square_kernel!` on `x` with a single
# block and one thread per element of `x`. Returns `nothing`.
function square!(x)
    nthreads = length(x)
    @cuda threads = nthreads blocks = 1 square_kernel!(x)
    return nothing
end

# Forward mode with a single tangent: the primal is 1..64 and the seed
# tangent is all ones, so the tangent of x^2 must come out as 2x.
A = CUDA.rand(64)
A .= 1:64
dA = CUDA.ones(64)
fill!(dA, 1)
Enzyme.autodiff(Forward, square!, Duplicated(A, dA))
@test all(isapprox.(dA, 2:2:128))

# Batched forward mode with two tangents seeded with 1 and 3: the results
# must be 2x and 3 * 2x respectively for the primal 1..32.
A = CUDA.rand(32)
A .= 1:32
dA = CUDA.ones(32)
dA2 = CUDA.ones(32)
fill!(dA, 1)
fill!(dA2, 3)
Enzyme.autodiff(Forward, square!, BatchDuplicated(A, (dA, dA2)))
@test all(isapprox.(dA, 2:2:64))
@test all(isapprox.(dA2, 3 * (2:2:64)))

0 comments on commit b2835c3

Please sign in to comment.