Backport Enzyme extension #2375

Merged, 8 commits, May 15, 2024
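This backport adds an EnzymeCore package extension so Enzyme's forward-mode AD can differentiate through CUDA kernel launches, cudaconvert, synchronize, and fill! on CuArrays. A minimal usage sketch, mirroring the new test in test/libraries/enzyme.jl (the kernel and variable names here are illustrative, not part of this diff):

using CUDA, Enzyme

# elementwise in-place squaring on the GPU
function square_kernel!(x)
    i = threadIdx().x
    x[i] *= x[i]
    return nothing
end

square!(x) = (@cuda threads=length(x) square_kernel!(x); nothing)

A  = CuArray(Float32.(1:64))
dA = CUDA.ones(Float32, 64)              # tangent seed
Enzyme.autodiff(Forward, square!, Duplicated(A, dA))
# dA now holds the elementwise derivative 2x, i.e. 2:2:128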
5 changes: 5 additions & 0 deletions Project.toml
@@ -12,6 +12,7 @@ CUDA_Runtime_Discovery = "1af6417a-86b4-443c-805f-a4643ffb695f"
CUDA_Runtime_jll = "76a88914-d11a-5bdc-97e0-2f5a05c973a2"
Crayons = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55"
@@ -37,10 +38,12 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[extensions]
ChainRulesCoreExt = "ChainRulesCore"
EnzymeCoreExt = "EnzymeCore"
SpecialFunctionsExt = "SpecialFunctions"

[compat]
@@ -55,6 +58,7 @@ ChainRulesCore = "1"
Crayons = "4"
DataFrames = "1"
ExprTools = "0.1"
EnzymeCore = "0.7.1"
GPUArrays = "10.0.1"
GPUCompiler = "0.24, 0.25, 0.26"
KernelAbstractions = "0.9.2"
@@ -81,4 +85,5 @@ julia = "1.8"

[extras]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
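EnzymeCore is registered both as a weak dependency with a package extension (Julia 1.9 and newer) and, via the Requires hook added to src/initialization.jl below, on older Julia versions, matching how the ChainRulesCore and SpecialFunctions extensions are wired up. On Julia 1.9+ the extension module can be inspected once both packages are loaded; a small illustrative check (not part of this diff):

using CUDA, EnzymeCore
Base.get_extension(CUDA, :EnzymeCoreExt)   # returns the EnzymeCoreExt module once it has loaded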
196 changes: 196 additions & 0 deletions ext/EnzymeCoreExt.jl
@@ -0,0 +1,196 @@
# compatibility with EnzymeCore

module EnzymeCoreExt

using CUDA
import CUDA: GPUCompiler, CUDABackend

if isdefined(Base, :get_extension)
using EnzymeCore
using EnzymeCore.EnzymeRules
else
using ..EnzymeCore
using ..EnzymeCore.EnzymeRules
end

function EnzymeCore.compiler_job_from_backend(::CUDABackend, @nospecialize(F::Type), @nospecialize(TT::Type))
mi = GPUCompiler.methodinstance(F, TT)
return GPUCompiler.CompilerJob(mi, CUDA.compiler_config(CUDA.device()))
end

function metaf(fn, args::Vararg{Any, N}) where N
EnzymeCore.autodiff_deferred(Forward, fn, Const, args...)
nothing
end

function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(cufunction)},
::Type{<:Duplicated}, f::Const{F},
tt::Const{TT}; kwargs...) where {F,TT}
res = ofn.val(f.val, tt.val; kwargs...)
return Duplicated(res, res)
end

function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(cufunction)},
::Type{BatchDuplicated{T,N}}, f::Const{F},
tt::Const{TT}; kwargs...) where {F,TT,T,N}
res = ofn.val(f.val, tt.val; kwargs...)
return BatchDuplicated(res, ntuple(Val(N)) do _
res
end)
end

function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(cudaconvert)},
::Type{RT}, x::IT) where {RT, IT}
if RT <: Duplicated
return Duplicated(ofn.val(x.val), ofn.val(x.dval))
elseif RT <: Const
return ofn.val(x.val)
elseif RT <: DuplicatedNoNeed
return ofn.val(x.val)
else
tup = ntuple(Val(EnzymeCore.batch_size(RT))) do i
Base.@_inline_meta
ofn.val(x.dval[i])
end
if RT <: BatchDuplicated
return BatchDuplicated(ofn.val(x.val), tup)
else
return tup
end
end
end

function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(synchronize)},
::Type{RT}, args::NTuple{N, EnzymeCore.Annotation}; kwargs...) where {RT, N}
pargs = ntuple(Val(N)) do i
Base.@_inline_meta
args[i].val
end
res = ofn.val(pargs...; kwargs...)

if RT <: Duplicated
return Duplicated(res, res)
elseif RT <: Const
return res
elseif RT <: DuplicatedNoNeed
return res
else
tup = ntuple(Val(EnzymeCore.batch_size(RT))) do i
Base.@_inline_meta
res
end
if RT <: BatchDuplicated
return BatchDuplicated(res, tup)
else
return tup
end
end
end

function EnzymeCore.EnzymeRules.forward(ofn::EnzymeCore.Annotation{CUDA.HostKernel{F,TT}},
::Type{Const{Nothing}}, args...;
kwargs...) where {F,TT}

GC.@preserve args begin
args = ((cudaconvert(a) for a in args)...,)
T2 = (F, (typeof(a) for a in args)...)
TT2 = Tuple{T2...}
cuf = cufunction(metaf, TT2)
res = cuf(ofn.val.f, args...; kwargs...)
end

return nothing
end

function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(Base.fill!)}, ::Type{RT}, A::EnzymeCore.Annotation{<:DenseCuArray{T}}, x) where {RT, T <: CUDA.MemsetCompatTypes}
if A isa Const || A isa Duplicated || A isa BatchDuplicated
ofn.val(A.val, x.val)
end

if A isa Duplicated || A isa DuplicatedNoNeed
ofn.val(A.dval, x isa Const ? zero(T) : x.dval)
elseif A isa BatchDuplicated || A isa BatchDuplicatedNoNeed
ntuple(Val(EnzymeRules.batch_width(A))) do i
Base.@_inline_meta
ofn.val(A.dval[i], x isa Const ? zero(T) : x.dval[i])
nothing
end
end

if RT <: Duplicated
return A
elseif RT <: Const
return A.val
elseif RT <: DuplicatedNoNeed
return A.dval
elseif RT <: BatchDuplicated
return A
else
return A.dval
end
end


function EnzymeCore.EnzymeRules.augmented_primal(config, ofn::Const{typeof(Base.fill!)}, ::Type{RT}, A::EnzymeCore.Annotation{<:DenseCuArray{T}}, x) where {RT, T <: CUDA.MemsetCompatTypes}
if A isa Const || A isa Duplicated || A isa BatchDuplicated
ofn.val(A.val, x.val)
end

if !(T <: AbstractFloat)
if A isa Duplicated || A isa DuplicatedNoNeed
ofn.val(A.dval, zero(T))
elseif A isa BatchDuplicated || A isa BatchDuplicatedNoNeed
ntuple(Val(EnzymeRules.batch_width(A))) do i
Base.@_inline_meta
ofn.val(A.dval[i], zero(T))
nothing
end
end
end

primal = if EnzymeRules.needs_primal(config)
A.val
else
nothing
end

shadow = if EnzymeRules.needs_shadow(config)
A.dval
else
nothing
end
return EnzymeRules.AugmentedReturn(primal, shadow, nothing)
end

function EnzymeCore.EnzymeRules.reverse(config, ofn::Const{typeof(Base.fill!)}, ::Type{RT}, tape, A::EnzymeCore.Annotation{<:DenseCuArray{T}}, x::EnzymeCore.Annotation{T2}) where {RT, T <: CUDA.MemsetCompatTypes, T2}
dx = if x isa Active
if A isa Duplicated || A isa DuplicatedNoNeed
T2(sum(A.dval))
elseif A isa BatchDuplicated || A isa BatchDuplicatedNoNeed
ntuple(Val(EnzymeRules.batch_width(A))) do i
Base.@_inline_meta
T2(sum(A.dval[i]))
end
end
else
nothing
end

# re-zero shadow
if (T <: AbstractFloat)
if A isa Duplicated || A isa DuplicatedNoNeed
ofn.val(A.dval, zero(T))
elseif A isa BatchDuplicated || A isa BatchDuplicatedNoNeed
ntuple(Val(EnzymeRules.batch_width(A))) do i
Base.@_inline_meta
ofn.val(A.dval[i], zero(T))
nothing
end
end
end

return (nothing, dx)
end

end # module
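
The reverse rule for fill! above follows from the fact that fill!(A, x) writes the scalar x into every element of A, so the adjoint of x is the sum of the incoming shadow dA; the shadow is then zeroed again because fill! overwrites A, leaving nothing upstream to propagate into. A CPU-only sketch of that adjoint logic (illustrative, assuming a floating-point shadow vector):

# pullback for y = fill!(A, x): dL/dx = sum_i dL/dA[i]
function fill_pullback!(dA::AbstractVector{T}) where {T<:AbstractFloat}
    dx = sum(dA)          # gradient with respect to the scalar fill value
    fill!(dA, zero(T))    # fill! overwrote A, so its shadow is reset
    return dx
end

dA = [1.0, 2.0, 3.0]
fill_pullback!(dA)        # 6.0, and dA is now all zeros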

3 changes: 3 additions & 0 deletions src/initialization.jl
@@ -155,6 +155,9 @@ function __init__()
@require SpecialFunctions="276daf66-3868-5448-9aa4-cd146d93841b" begin
include("../ext/SpecialFunctionsExt.jl")
end
@require EnzymeCore="f151be2c-9106-41f4-ab19-57ee4f262869" begin
include("../ext/EnzymeCoreExt.jl")
end
end

# ensure that operations executed by the REPL back-end finish before returning,
3 changes: 3 additions & 0 deletions test/Project.toml
@@ -7,8 +7,11 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
56 changes: 56 additions & 0 deletions test/libraries/enzyme.jl
@@ -0,0 +1,56 @@
using EnzymeCore
using GPUCompiler
using Enzyme

@testset "compiler_job_from_backend" begin
@test EnzymeCore.compiler_job_from_backend(CUDABackend(), typeof(()->nothing), Tuple{}) isa GPUCompiler.CompilerJob
end

function square_kernel!(x)
i = threadIdx().x
x[i] *= x[i]
sync_threads()
return nothing
end

# basic squaring on GPU
function square!(x)
@cuda blocks = 1 threads = length(x) square_kernel!(x)
return nothing
end

@testset "Forward Kernel" begin
A = CUDA.rand(64)
dA = CUDA.ones(64)
A .= (1:1:64)
dA .= 1
Enzyme.autodiff(Forward, square!, Duplicated(A, dA))
@test all(dA .≈ (2:2:128))

A = CUDA.rand(32)
dA = CUDA.ones(32)
dA2 = CUDA.ones(32)
A .= (1:1:32)
dA .= 1
dA2 .= 3
Enzyme.autodiff(Forward, square!, BatchDuplicated(A, (dA, dA2)))
@test all(dA .≈ (2:2:64))
@test all(dA2 .≈ 3*(2:2:64))
end

@testset "Forward Fill!" begin
A = CUDA.ones(64)
dA = CUDA.ones(64)
Enzyme.autodiff(Forward, fill!, Duplicated(A, dA), Duplicated(2.0, 3.0))
@test all(A .≈ 2.0)
@test all(dA .≈ 3.0)
end

@testset "Reverse Fill!" begin
A = CUDA.zeros(64)
dA = CUDA.ones(64)
res = Enzyme.autodiff(Reverse, fill!, Const, Duplicated(A, dA), Active(1.0))[1][2]
@test res ≈ 64
@test all(A .≈ 1)
@test all(dA .≈ 0)
end
3 changes: 3 additions & 0 deletions test/setup.jl
@@ -1,6 +1,9 @@
using Distributed, Test, CUDA
using CUDA: i32

# ensure CUDA.jl is functional
@assert CUDA.functional(true)

# GPUArrays has a testsuite that isn't part of the main package.
# Include it directly.
import GPUArrays