Skip to content

Commit

Permalink
Enzyme: Mark CuArray as noalias (#2395)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored May 28, 2024
1 parent c927463 commit beff592
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ CUDA_Runtime_jll = "0.14"
ChainRulesCore = "1"
Crayons = "4"
DataFrames = "1"
EnzymeCore = "0.7.1"
EnzymeCore = "0.7.3"
ExprTools = "0.1"
GPUArrays = "10.0.1"
GPUCompiler = "0.24, 0.25, 0.26"
Expand Down
65 changes: 65 additions & 0 deletions ext/EnzymeCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,32 @@ function EnzymeCore.EnzymeRules.forward(ofn::Const{Type{CT}},
end
end

# Forward-mode rule for calling a `CuArray` constructor `ofn.val` on a
# `CUDA.DataRef` carried by the annotation `uval`.
#
# The remaining positional arguments are unwrapped to their primal `.val` and
# forwarded, together with `kwargs`, to every constructor call. The primal array
# is always built from `uval.val` and each shadow array from `uval.dval`
# (or `uval.dval[i]` in batch mode).
#
# Result by requested return annotation `RT`:
#   * `Duplicated`       -> `Duplicated(primal, shadow)`
#   * `Const`            -> primal only
#   * `DuplicatedNoNeed` -> shadow only
#   * `BatchDuplicated`  -> `BatchDuplicated(primal, shadows)`
#   * anything else      -> tuple of `EnzymeCore.batch_size(RT)` shadows
function EnzymeCore.EnzymeRules.forward(ofn::Const{Type{CT}},
        ::Type{RT}, uval::EnzymeCore.Annotation{DR}, args...; kwargs...) where {CT <: CuArray, DR <: CUDA.DataRef, RT}
    # Strip the activity annotation from the trailing arguments.
    primargs = ntuple(Val(length(args))) do i
        Base.@_inline_meta
        args[i].val
    end
    if RT <: Duplicated
        # Primal comes from `val`, shadow from `dval` (matching the `Const`
        # and `DuplicatedNoNeed` branches below).
        shadow = ofn.val(uval.dval, primargs...; kwargs...)
        Duplicated(ofn.val(uval.val, primargs...; kwargs...), shadow)
    elseif RT <: Const
        ofn.val(uval.val, primargs...; kwargs...)
    elseif RT <: DuplicatedNoNeed
        ofn.val(uval.dval, primargs...; kwargs...)
    else
        # Batch mode: one shadow array per batch lane.
        tup = ntuple(Val(EnzymeCore.batch_size(RT))) do i
            Base.@_inline_meta
            ofn.val(uval.dval[i], primargs...; kwargs...)
        end
        if RT <: BatchDuplicated
            # Build the primal with the same arguments as every other branch.
            BatchDuplicated(ofn.val(uval.val, primargs...; kwargs...), tup)
        else
            tup
        end
    end
end

function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(synchronize)},
::Type{RT}, args::Vararg{EnzymeCore.Annotation, N}; kwargs...) where {RT, N}
pargs = ntuple(Val(N)) do i
Expand Down Expand Up @@ -269,5 +295,44 @@ function EnzymeCore.EnzymeRules.reverse(config, ofn::Const{Type{CT}}, ::Type{RT}
nothing
end
end

# Reverse-mode augmented forward pass for calling a `CuArray` constructor
# `ofn.val` on a `CUDA.DataRef` carried by the annotation `uval`.
#
# Builds the primal array from `uval.val` when `config` asks for a primal, and
# the shadow array(s) from `uval.dval` when it asks for a shadow (a single array
# for width 1, otherwise a tuple with one array per lane). Anything not
# requested is returned as `nothing`. No tape is recorded.
function EnzymeCore.EnzymeRules.augmented_primal(config, ofn::Const{Type{CT}}, ::Type{RT}, uval::EnzymeCore.Annotation{DR}, args...; kwargs...) where {CT <: CuArray, DR <: CUDA.DataRef, RT}
    # Unwrap the trailing annotated arguments to their primal values.
    rest = ntuple(Val(length(args))) do k
        Base.@_inline_meta
        args[k].val
    end

    primal = EnzymeRules.needs_primal(config) ? ofn.val(uval.val, rest...; kwargs...) : nothing

    shadow = if !EnzymeRules.needs_shadow(config)
        nothing
    elseif EnzymeRules.width(config) == 1
        ofn.val(uval.dval, rest...; kwargs...)
    else
        ntuple(Val(EnzymeRules.width(config))) do lane
            Base.@_inline_meta
            ofn.val(uval.dval[lane], rest...; kwargs...)
        end
    end

    # Type parameters of the result mirror what was actually produced above.
    PrimalT = EnzymeRules.needs_primal(config) ? CT : Nothing
    ShadowT = if !EnzymeRules.needs_shadow(config)
        Nothing
    elseif EnzymeRules.width(config) == 1
        CT
    else
        NTuple{EnzymeRules.width(config), CT}
    end
    return EnzymeRules.AugmentedReturn{PrimalT, ShadowT, Nothing}(primal, shadow, nothing)
end

# Reverse pass for the `CuArray`-from-`DataRef` constructor rule.
# Returns a tuple of `N + 1` `nothing`s: one entry for `A` and one for each of
# the `N` trailing annotated arguments.
function EnzymeCore.EnzymeRules.reverse(config, ofn::Const{Type{CT}}, ::Type{RT}, tape, A::EnzymeCore.Annotation{DR}, args::Vararg{EnzymeCore.Annotation, N}; kwargs...) where {CT <: CuArray, DR <: CUDA.DataRef, RT, N}
    nones = ntuple(Val(N + 1)) do _
        Base.@_inline_meta
        nothing
    end
    return nones
end

# `noalias` marker method for `CuArray` construction with an `UndefInitializer`
# (any further arguments are accepted and ignored); always returns `nothing`.
EnzymeCore.EnzymeRules.noalias(::Type{CT}, ::UndefInitializer, args...) where {CT <: CuArray} = nothing

end # module

12 changes: 12 additions & 0 deletions test/extensions/enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,18 @@ end
@test all(shad .≈ 0.0)
end

# Return the first element of the elementwise sum of `x` and `y`.
function firstsum(x, y)
    total = x .+ y
    return first(total)
end
# Forward-mode differentiation of `firstsum` over two `CuArray`s with
# `Duplicated` inputs and a `Duplicated` return.
@testset "Forward broadcast" begin
    x = CuArray(5*ones(5))
    y = CuArray(3*ones(5))
    dx = CuArray([1.0, 0.0, 0.0, 0.0, 0.0])
    dy = CuArray([0.2, 0.0, 0.1, 0.0, 0.0])
    # NOTE(review): `f` is not defined in the visible part of this file —
    # confirm it exists earlier in the file, or whether this call is a leftover.
    f(x, y)
    res = autodiff(Forward, firstsum, Duplicated, Duplicated(x, dx), Duplicated(y, dy))
    # Expected value: x[1] + y[1] = 5 + 3.
    @test res[1] ≈ 8
    # Expected tangent: dx[1] + dy[1] = 1.0 + 0.2.
    @test res[2] ≈ 1.2
end

# TODO once reverse kernels are in
# function togpu(x)
# x = CuArray(x)
Expand Down

0 comments on commit beff592

Please sign in to comment.