Skip to content

Commit

Permalink
Add Enzyme sum derivatives (#2471)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Sep 4, 2024
1 parent 9f93343 commit d72cdaa
Show file tree
Hide file tree
Showing 2 changed files with 173 additions and 0 deletions.
159 changes: 159 additions & 0 deletions ext/EnzymeCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ else
using ..EnzymeCore
using ..EnzymeCore.EnzymeRules
end
using GPUArrays

function EnzymeCore.EnzymeRules.inactive(::typeof(CUDA.CUBLAS.handle))
return nothing
Expand Down Expand Up @@ -516,5 +517,163 @@ function EnzymeCore.EnzymeRules.noalias(::Type{CT}, ::UndefInitializer, args...)
return nothing
end

# Forward-mode Enzyme rule for `GPUArrays.mapreducedim!(identity, Base.add_sum, R, A; init)`,
# i.e. an in-place sum of `A` into the CUDA array `R`.
#
# Arguments:
#   ofn    - the constant `mapreducedim!` function being differentiated
#   RT     - requested return annotation; selects what is returned at the end
#   f, op  - constant `identity` / `add_sum` (this rule only matches a plain sum)
#   R      - annotated output array
#   A      - annotated input
#   init   - initial value, forwarded to the primal reduction
#
# Returns `R`, `R.val` or `R.dval` depending on `RT`.
function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(GPUArrays.mapreducedim!)},
                                        ::Type{RT},
                                        f::EnzymeCore.Const{typeof(Base.identity)},
                                        op::EnzymeCore.Const{typeof(Base.add_sum)},
                                        R::EnzymeCore.Annotation{<:AnyCuArray{T}}, A; init) where {RT, T}
    # Run the primal reduction unless `R` is one of the "NoNeed" annotations.
    if R isa Const || R isa Duplicated || R isa BatchDuplicated
        ofn.val(f.val, op.val, R.val, A.val; init)
    end

    # Shadow propagation is keyed on the annotation of the *output* `R`, matching the
    # batched branch below. The previous check on `A` made the `A isa Const` case
    # unreachable, so a duplicated `R` with a constant `A` never had its shadow zeroed.
    if R isa Duplicated || R isa DuplicatedNoNeed
        if A isa Const
            # A constant input contributes no derivative.
            Base.fill!(R.dval, zero(T))
        else
            # The shadow reduction is invoked without `init`, as in the original rule.
            ofn.val(f.val, op.val, R.dval, A.dval)
        end
    elseif R isa BatchDuplicated || R isa BatchDuplicatedNoNeed
        ntuple(Val(EnzymeRules.batch_width(R))) do i
            Base.@_inline_meta
            if A isa Const
                Base.fill!(R.dval[i], zero(T))
            else
                ofn.val(f.val, op.val, R.dval[i], A.dval[i])
            end
            nothing
        end
    end

    if RT <: Duplicated
        return R
    elseif RT <: Const
        return R.val
    elseif RT <: DuplicatedNoNeed
        return R.dval
    elseif RT <: BatchDuplicated
        return R
    else
        # Remaining case: only the (batched) shadow is requested.
        return R.dval
    end
end


# Reverse-mode augmented primal for `GPUArrays.mapreducedim!(identity, Base.add_sum, R, A; init)`.
#
# Runs the primal in-place sum of `A` into `R` and returns an `AugmentedReturn` holding
# the primal (`R.val`) and shadow (`R.dval`) only when `config` requests them. Nothing
# is stored on the tape.
function EnzymeCore.EnzymeRules.augmented_primal(config, ofn::Const{typeof(GPUArrays.mapreducedim!)},
                                                 ::Type{RT},
                                                 f::EnzymeCore.Const{typeof(Base.identity)},
                                                 op::EnzymeCore.Const{typeof(Base.add_sum)},
                                                 R::EnzymeCore.Annotation{<:AnyCuArray{T}}, A; init) where {RT, T<:AbstractFloat}
    # NOTE(review): the guard inspects `A`, whereas the forward rule inspects `R`;
    # kept as in the original — confirm which annotation should gate the primal.
    if A isa Const || A isa Duplicated || A isa BatchDuplicated
        # Forward `init` so the primal sees the caller's initial value, as the forward
        # rule does; previously the keyword was silently dropped here.
        ofn.val(f.val, op.val, R.val, A.val; init)
    end

    primal = if EnzymeRules.needs_primal(config)
        R.val
    else
        nothing
    end

    shadow = if EnzymeRules.needs_shadow(config)
        R.dval
    else
        nothing
    end
    return EnzymeRules.AugmentedReturn(primal, shadow, nothing)
end

# Reverse pass for `GPUArrays.mapreducedim!(identity, Base.add_sum, R, A; init)`.
#
# When both `A` and `R` are non-constant, the shadow of `R` is broadcast-added into the
# shadow of `A` and the shadow of `R` is then reset to zero (once per batch lane for
# batched annotations). All four argument slots (`f`, `op`, `R`, `A`) return `nothing`.
function EnzymeCore.EnzymeRules.reverse(config, ofn::Const{typeof(GPUArrays.mapreducedim!)},
                                        ::Type{RT},
                                        tape,
                                        f::EnzymeCore.Const{typeof(Base.identity)},
                                        op::EnzymeCore.Const{typeof(Base.add_sum)},
                                        R::EnzymeCore.Annotation{<:AnyCuArray{T}}, A; init) where {RT, T<:AbstractFloat}
    no_grads = (nothing, nothing, nothing, nothing)

    # Nothing to accumulate if either side is constant.
    (A isa Const || R isa Const) && return no_grads

    if A isa Union{Duplicated, DuplicatedNoNeed}
        A.dval .+= R.dval
        Base.fill!(R.dval, zero(T))
    elseif A isa Union{BatchDuplicated, BatchDuplicatedNoNeed}
        ntuple(Val(EnzymeRules.batch_width(A))) do lane
            Base.@_inline_meta
            A.dval[lane] .+= R.dval[lane]
            Base.fill!(R.dval[lane], zero(T))
            nothing
        end
    end

    return no_grads
end

# Forward-mode Enzyme rule for `GPUArrays._mapreduce(identity, Base.add_sum, A; dims, init)`,
# i.e. an out-of-place sum of the CUDA array `A`.
#
# The derivative of a sum is the sum of the input shadow, so each shadow result is
# obtained by running the same reduction on `A.dval` (or on every `A.dval[i]` when
# batched). What is returned depends on the requested annotation `RT`:
#   Const                  -> primal result only
#   Duplicated             -> `Duplicated(primal, shadow)`
#   DuplicatedNoNeed       -> shadow only
#   BatchDuplicated        -> `BatchDuplicated(primal, shadows)`
#   BatchDuplicatedNoNeed  -> tuple of shadows only
function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(GPUArrays._mapreduce)},
                                        ::Type{RT},
                                        f::EnzymeCore.Const{typeof(Base.identity)},
                                        op::EnzymeCore.Const{typeof(Base.add_sum)},
                                        A::EnzymeCore.Annotation{<:AnyCuArray{T}}; dims::D, init) where {RT, T, D}
    # NOTE(review): `init` is also passed to the shadow reductions, as in the original;
    # confirm that callers only reach this rule with a zero/`nothing` initial value.
    if RT <: Const
        return ofn.val(f.val, op.val, A.val; dims, init)
    elseif RT <: Duplicated
        primal = ofn.val(f.val, op.val, A.val; dims, init)
        shadow = ofn.val(f.val, op.val, A.dval; dims, init)
        # Wrap in the annotation type rather than a bare tuple, following the
        # EnzymeRules forward-rule return convention.
        return Duplicated(primal, shadow)
    elseif RT <: DuplicatedNoNeed
        return ofn.val(f.val, op.val, A.dval; dims, init)
    elseif RT <: BatchDuplicated
        primal = ofn.val(f.val, op.val, A.val; dims, init)
        shadows = ntuple(Val(EnzymeRules.batch_width(RT))) do i
            Base.@_inline_meta
            ofn.val(f.val, op.val, A.dval[i]; dims, init)
        end
        return BatchDuplicated(primal, shadows)
    else
        @assert RT <: BatchDuplicatedNoNeed
        return ntuple(Val(EnzymeRules.batch_width(RT))) do i
            Base.@_inline_meta
            ofn.val(f.val, op.val, A.dval[i]; dims, init)
        end
    end
end

# Reverse-mode augmented primal for `GPUArrays._mapreduce(identity, Base.add_sum, A; dims, init)`
# with an `Active` return.
#
# Evaluates the reduction only when `config` asks for the primal, hands back `A.dval`
# as the shadow only when `config` asks for one, and stores nothing on the tape.
function EnzymeCore.EnzymeRules.augmented_primal(config, ofn::Const{typeof(GPUArrays._mapreduce)},
                                                 ::Type{Active{RT}},
                                                 f::EnzymeCore.Const{typeof(Base.identity)},
                                                 op::EnzymeCore.Const{typeof(Base.add_sum)},
                                                 A::EnzymeCore.Annotation{<:AnyCuArray{T}}; dims::D, init) where {RT, T<:AbstractFloat, D}
    result = EnzymeRules.needs_primal(config) ?
        ofn.val(f.val, op.val, A.val; dims, init) : nothing
    result_shadow = EnzymeRules.needs_shadow(config) ? A.dval : nothing
    return EnzymeRules.AugmentedReturn(result, result_shadow, nothing)
end

# Reverse pass for `GPUArrays._mapreduce(identity, Base.add_sum, A; dims, init)` with an
# `Active` return.
#
# The incoming adjoint `dres.val` is broadcast-added into the shadow of `A` (into every
# batch lane for batched annotations). Other annotations of `A` are left untouched.
# The three argument slots (`f`, `op`, `A`) all return `nothing`.
function EnzymeCore.EnzymeRules.reverse(config, ofn::Const{typeof(GPUArrays._mapreduce)},
                                        dres::Active{RT},
                                        tape,
                                        f::EnzymeCore.Const{typeof(Base.identity)},
                                        op::EnzymeCore.Const{typeof(Base.add_sum)},
                                        A::EnzymeCore.Annotation{<:AnyCuArray{T}}; dims::D, init) where {RT, T<:AbstractFloat, D}
    if A isa Union{Duplicated, DuplicatedNoNeed}
        A.dval .+= dres.val
    elseif A isa Union{BatchDuplicated, BatchDuplicatedNoNeed}
        ntuple(Val(EnzymeRules.batch_width(A))) do lane
            Base.@_inline_meta
            A.dval[lane] .+= dres.val
            nothing
        end
    end

    return (nothing, nothing, nothing)
end

end # module

14 changes: 14 additions & 0 deletions test/extensions/enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,20 @@ firstsum(x, y) = first(x .+ y)
#@test res[2] ≈ 1.2
end

# Forward-mode derivative of `sum` over a CuArray: the result should equal the sum of
# the seeded tangent `dx`.
@testset "Forward sum" begin
    x = CuArray([1.0, 2.0, 3.0, 4.0])
    dx = CuArray([100., 300.0, 500.0, 700.0])
    res = Enzyme.autodiff(Forward, sum, Duplicated(x, dx))
    # The comparison operator was missing, which left `@test` with two bare arguments.
    @test res[1] ≈ 100+300+500+700.
end

# Reverse-mode derivative of `sum` over a CuArray: every entry of the gradient buffer
# should end up approximately 1.
@testset "Reverse sum" begin
    input = CuArray([1.0, 2.0, 3.0, 4.0])
    grad = CuArray(zeros(Float64, 4))
    Enzyme.autodiff(Reverse, sum, Duplicated(input, grad))
    @test all(grad .≈ 1.0)
end

# TODO once reverse kernels are in
# function togpu(x)
# x = CuArray(x)
Expand Down

0 comments on commit d72cdaa

Please sign in to comment.