From d72cdaa324cfadc1e67a7aa7fb9c4b035d2ec07c Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 4 Sep 2024 11:30:32 -0500 Subject: [PATCH] Add Enzyme `sum` derivatives (#2471) --- ext/EnzymeCoreExt.jl | 159 ++++++++++++++++++++++++++++++++++++++ test/extensions/enzyme.jl | 14 ++++ 2 files changed, 173 insertions(+) diff --git a/ext/EnzymeCoreExt.jl b/ext/EnzymeCoreExt.jl index ea52c27ca3..f8c8fe2c7d 100644 --- a/ext/EnzymeCoreExt.jl +++ b/ext/EnzymeCoreExt.jl @@ -12,6 +12,7 @@ else using ..EnzymeCore using ..EnzymeCore.EnzymeRules end +using GPUArrays function EnzymeCore.EnzymeRules.inactive(::typeof(CUDA.CUBLAS.handle)) return nothing @@ -516,5 +517,163 @@ function EnzymeCore.EnzymeRules.noalias(::Type{CT}, ::UndefInitializer, args...) return nothing end +function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(GPUArrays.mapreducedim!)}, + ::Type{RT}, + f::EnzymeCore.Const{typeof(Base.identity)}, + op::EnzymeCore.Const{typeof(Base.add_sum)}, + R::EnzymeCore.Annotation{<:AnyCuArray{T}}, A; init) where {RT, T} + if R isa Const || R isa Duplicated || R isa BatchDuplicated + ofn.val(f.val, op.val, R.val, A.val; init) + end + + if A isa Duplicated || A isa DuplicatedNoNeed + if A isa Const + Base.fill!(R.dval, zero(T)) + else + ofn.val(f.val, op.val, R.dval, A.dval) + end + elseif R isa BatchDuplicated || R isa BatchDuplicatedNoNeed + ntuple(Val(EnzymeRules.batch_width(R))) do i + Base.@_inline_meta + if A isa Const + Base.fill!(R.dval[i], zero(T)) + else + ofn.val(f.val, op.val, R.dval[i], A.dval[i]) + end + nothing + end + end + + if RT <: Duplicated + return R + elseif RT <: Const + return R.val + elseif RT <: DuplicatedNoNeed + return R.dval + elseif RT <: BatchDuplicated + return R + else + return R.dval + end +end + + +function EnzymeCore.EnzymeRules.augmented_primal(config, ofn::Const{typeof(GPUArrays.mapreducedim!)}, + ::Type{RT}, + f::EnzymeCore.Const{typeof(Base.identity)}, + op::EnzymeCore.Const{typeof(Base.add_sum)}, + R::EnzymeCore.Annotation{<:AnyCuArray{T}}, A; init) where {RT, T<:AbstractFloat} + if A isa Const || A isa Duplicated || A isa BatchDuplicated + ofn.val(f.val, op.val, R.val, A.val) + end + + primal = if EnzymeRules.needs_primal(config) + R.val + else + nothing + end + + shadow = if EnzymeRules.needs_shadow(config) + R.dval + else + nothing + end + return EnzymeRules.AugmentedReturn(primal, shadow, nothing) +end + +function EnzymeCore.EnzymeRules.reverse(config, ofn::Const{typeof(GPUArrays.mapreducedim!)}, + ::Type{RT}, + tape, + f::EnzymeCore.Const{typeof(Base.identity)}, + op::EnzymeCore.Const{typeof(Base.add_sum)}, + R::EnzymeCore.Annotation{<:AnyCuArray{T}}, A; init) where {RT, T<:AbstractFloat} + + if !(A isa Const) && !(R isa Const) + if A isa Duplicated || A isa DuplicatedNoNeed + A.dval .+= R.dval + Base.fill!(R.dval, zero(T)) + elseif A isa BatchDuplicated || A isa BatchDuplicatedNoNeed + ntuple(Val(EnzymeRules.batch_width(A))) do i + Base.@_inline_meta + A.dval[i] .+= R.dval[i] + Base.fill!(R.dval[i], zero(T)) + nothing + end + end + end + + return (nothing, nothing, nothing, nothing) +end + +function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(GPUArrays._mapreduce)}, + ::Type{RT}, + f::EnzymeCore.Const{typeof(Base.identity)}, + op::EnzymeCore.Const{typeof(Base.add_sum)}, + A::EnzymeCore.Annotation{<:AnyCuArray{T}}; dims::D, init) where {RT, T, D} + if RT <: Const + ofn.val(f.val, op.val, A.val; dims, init) + elseif RT <: Duplicated + ( + ofn.val(f.val, op.val, A.val; dims, init), + ofn.val(f.val, op.val, A.dval; dims, init) + ) + elseif RT <: DuplicatedNoNeed + ofn.val(f.val, op.val, A.dval; dims, init) + elseif RT <: BatchDuplicated + ( + ofn.val(f.val, op.val, A.val; dims, init), + ntuple(Val(EnzymeRules.batch_width(RT))) do i + Base.@_inline_meta + ofn.val(f.val, op.val, A.dval[i]; dims, init) + end + ) + else + @assert RT <: BatchDuplicatedNoNeed + ntuple(Val(EnzymeRules.batch_width(RT))) do i + Base.@_inline_meta + ofn.val(f.val, op.val, A.dval[i]; dims, init) + end + end +end + +function EnzymeCore.EnzymeRules.augmented_primal(config, ofn::Const{typeof(GPUArrays._mapreduce)}, + ::Type{Active{RT}}, + f::EnzymeCore.Const{typeof(Base.identity)}, + op::EnzymeCore.Const{typeof(Base.add_sum)}, + A::EnzymeCore.Annotation{<:AnyCuArray{T}}; dims::D, init) where {RT, T<:AbstractFloat, D} + primal = if EnzymeRules.needs_primal(config) + ofn.val(f.val, op.val, A.val; dims, init) + else + nothing + end + + shadow = if EnzymeRules.needs_shadow(config) + A.dval + else + nothing + end + return EnzymeRules.AugmentedReturn(primal, shadow, nothing) +end + +function EnzymeCore.EnzymeRules.reverse(config, ofn::Const{typeof(GPUArrays._mapreduce)}, + dres::Active{RT}, + tape, + f::EnzymeCore.Const{typeof(Base.identity)}, + op::EnzymeCore.Const{typeof(Base.add_sum)}, + A::EnzymeCore.Annotation{<:AnyCuArray{T}}; dims::D, init) where {RT, T<:AbstractFloat, D} + + if A isa Duplicated || A isa DuplicatedNoNeed + A.dval .+= dres.val + elseif A isa BatchDuplicated || A isa BatchDuplicatedNoNeed + ntuple(Val(EnzymeRules.batch_width(A))) do i + Base.@_inline_meta + A.dval[i] .+= dres.val + nothing + end + end + + return (nothing, nothing, nothing) +end + end # module diff --git a/test/extensions/enzyme.jl b/test/extensions/enzyme.jl index 476448d4f6..0ad29c91eb 100644 --- a/test/extensions/enzyme.jl +++ b/test/extensions/enzyme.jl @@ -103,6 +103,20 @@ firstsum(x, y) = first(x .+ y) #@test res[2] ≈ 1.2 end +@testset "Forward sum" begin + x = CuArray([1.0, 2.0, 3.0, 4.0]) + dx = CuArray([100., 300.0, 500.0, 700.0]) + res = Enzyme.autodiff(Forward, sum, Duplicated(x, dx)) + @test res[1] ≈ 100+300+500+700. +end + +@testset "Reverse sum" begin + x = CuArray([1.0, 2.0, 3.0, 4.0]) + dx = CuArray([0., 0.0, 0.0, 0.0]) + Enzyme.autodiff(Reverse, sum, Duplicated(x, dx)) + @test all(dx .≈ 1.0) +end + # TODO once reverse kernels are in # function togpu(x) # x = CuArray(x)